Lutz Roeder 3 лет назад
Родитель
Сommit
a8a27f36f2
2 измененных файлов с 33 добавлено и 0 удалено
  1. 19 0
      source/nnabla-metadata.json
  2. 14 0
      source/nnabla-proto.js

+ 19 - 0
source/nnabla-metadata.json

@@ -5855,6 +5855,20 @@
         "required": false,
         "default": 1,
         "description": "First dimension of the sample shape."
+      },
+      {
+        "name": "largest",
+        "type": "bool",
+        "required": false,
+        "default": true,
+        "description": "Whether to select the `k` largest or smallest values."
+      },
+      {
+        "name": "with_index",
+        "type": "bool",
+        "required": false,
+        "default": false,
+        "description": "Return top-k values and indices."
       }
     ],
     "outputs": [
@@ -5862,6 +5876,11 @@
         "name": "y",
         "type": "nnabla.Variable",
         "description": "N-D array."
+      },
+      {
+        "name": "indices",
+        "type": "nnabla.Variable",
+        "description": "N-D array of top-k indices."
       }
     ]
   },

+ 14 - 0
source/nnabla-proto.js

@@ -9913,6 +9913,12 @@ $root.nnabla.TopKDataParameter = class TopKDataParameter {
                 case 4:
                     message.base_axis = reader.int64();
                     break;
+                case 5:
+                    message.largest = reader.bool();
+                    break;
+                case 6:
+                    message.with_index = reader.bool();
+                    break;
                 default:
                     reader.skipType(tag & 7);
                     break;
@@ -9939,6 +9945,12 @@ $root.nnabla.TopKDataParameter = class TopKDataParameter {
                 case "base_axis":
                     message.base_axis = reader.int64();
                     break;
+                case "largest":
+                    message.largest = reader.bool();
+                    break;
+                case "with_index":
+                    message.with_index = reader.bool();
+                    break;
                 default:
                     reader.field(tag, message);
                     break;
@@ -9952,6 +9964,8 @@ $root.nnabla.TopKDataParameter.prototype.k = protobuf.Int64.create(0);
 $root.nnabla.TopKDataParameter.prototype.abs = false;
 $root.nnabla.TopKDataParameter.prototype.reduce = false;
 $root.nnabla.TopKDataParameter.prototype.base_axis = protobuf.Int64.create(0);
+$root.nnabla.TopKDataParameter.prototype.largest = false;
+$root.nnabla.TopKDataParameter.prototype.with_index = false;
 
 $root.nnabla.TopKGradParameter = class TopKGradParameter {