Forráskód Böngészése

TensorFlow Lite variadic inputs

Lutz Roeder 8 éve
szülő
commit
e74e599434
2 módosított fájl, 38 hozzáadás és 32 törlés
  1. 35 32
      src/tflite-model.js
  2. 3 0
      src/tflite-operator.json

+ 35 - 32
src/tflite-model.js

@@ -206,29 +206,20 @@ class TensorFlowLiteNode {
     }
 
     get inputs() {
-        var results = [];
-        var graph = this._graph._graph;
-        var node = this._node;
-        for (var i = 0; i < node.inputsLength(); i++) {
-            var input = {
-                name: TensorFlowLiteOperatorMetadata.operatorMetadata.getInputName(this.operator, i),
-                connections: []
-            };
-            var tensorIndex = node.inputs(i);
-            if (tensorIndex != -1) {
-                var tensor = graph.tensors(tensorIndex);
-                var connection = {};
-                connection.id = tensorIndex.toString();
+        var inputs = TensorFlowLiteOperatorMetadata.operatorMetadata.getInputs(this._node, this.operator);
+        inputs.forEach((input) => {
+            input.connections.forEach((connection) => {
+                var tensorIndex = connection.id;
+                var tensor = this._graph._graph.tensors(tensorIndex);
                 connection.type = TensorFlowLiteTensor.formatTensorType(tensor);
                 var initializer = this._graph.getInitializer(tensorIndex);
                 if (initializer) {
                     connection.initializer = initializer;
                 }
-                input.connections.push(connection);
-            }
-            results.push(input);
-        }
-        return results;
+                connection.id = connection.id.toString();
+            });
+        });
+        return inputs;
     }
 
     get outputs() {
@@ -585,23 +576,35 @@ class TensorFlowLiteOperatorMetadata {
         }
     }
 
-    getInputName(operator, index) {
+    getInputs(node, operator) {
+        var results = [];
+        var connections = [];
+        for (var i = 0; i < node.inputsLength(); i++) {
+            connections.push(node.inputs(i));
+        }
         var schema = this._map[operator];
-        if (schema) {
-            var inputs = schema.inputs;
-            if (inputs && index < inputs.length) {
-                var input = inputs[index];
-                if (input) {
-                    if (!input.option || input.option != 'variadic') {
-                        var name = input.name;
-                        if (name) {
-                            return name;
-                        }
-                    }
-                } 
+        var index = 0;
+        while (index < connections.length) {
+            var result = { connections: [] };
+            var count = 1;
+            var name = null;
+            if (schema && schema.inputs && index < schema.inputs.length) {
+                name = schema.inputs[index].name;
+                if (schema.inputs[index].option == 'variadic') {
+                    count = connections.length - index;
+                }
+            }
+            result.name = name ? name : '(' + index.toString() + ')';
+            var array = connections.slice(index, index + count);
+            for (var j = 0; j < array.length; j++) {
+                if (array[j] != -1) {
+                    result.connections.push({ id: array[j] });
+                }
             }
+            index += count;
+            results.push(result);
         }
-        return "(" + index.toString() + ")";
+        return results;
     }
 
     getOutputName(operator, index) {

+ 3 - 0
src/tflite-operator.json

@@ -232,6 +232,9 @@
     "name": "Concatenation",
     "schema": {
       "category": "Tensor",
+      "inputs": [
+        { "name": "inputs", "option": "variadic" }
+      ],
       "outputs": [
         { "name": "output" }
       ],