Просмотр исходного кода

Fix TensorFlow Lite chain operators (#416)

Lutz Roeder 6 лет назад
Родитель
Сommit
b652161080
5 измененных файлов с 83 добавлено и 23 удалено
  1. 48 0
      src/tflite-metadata.json
  2. 18 17
      src/tflite.js
  3. 9 5
      src/view.js
  4. 7 0
      test/models.json
  5. 1 1
      test/test.js

+ 48 - 0
src/tflite-metadata.json

@@ -480,6 +480,51 @@
       ]
     }
   },
+  {
+    "name": "Sum",
+    "schema": {
+      "inputs": [
+        { "name": "input", "type": "T" },
+        { "name": "axis", "type": "T" }
+      ],
+      "outputs": [
+        { "name": "output", "type": "T" }
+      ],
+      "attributes": [
+        { "name": "keep_dims", "type": "boolean" }
+      ]
+    }
+  },
+  {
+    "name": "ReduceMax",
+    "schema": {
+      "inputs": [
+        { "name": "input", "type": "T" },
+        { "name": "axis", "type": "T" }
+      ],
+      "outputs": [
+        { "name": "output", "type": "T" }
+      ],
+      "attributes": [
+        { "name": "keep_dims", "type": "boolean" }
+      ]
+    }
+  },
+  {
+    "name": "ReduceMin",
+    "schema": {
+      "inputs": [
+        { "name": "input", "type": "T" },
+        { "name": "axis", "type": "T" }
+      ],
+      "outputs": [
+        { "name": "output", "type": "T" }
+      ],
+      "attributes": [
+        { "name": "keep_dims", "type": "boolean" }
+      ]
+    }
+  },
   {
     "name": "Mean",
     "schema": {
@@ -489,6 +534,9 @@
       ],
       "outputs": [
         { "name": "output", "type": "T" }
+      ],
+      "attributes": [
+        { "name": "keep_dims", "type": "boolean" }
       ]
     }
   },

+ 18 - 17
src/tflite.js

@@ -64,7 +64,7 @@ tflite.Model = class {
         this._graphs = [];
         this._format = 'TensorFlow Lite v' + model.version().toString();
         this._description = model.description() || '';
-        let operatorCodeList = [];
+        let operators = [];
         let builtinOperatorMap = {};
         for (const key of Object.keys(tflite.schema.BuiltinOperator)) {
             const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
@@ -76,14 +76,14 @@ tflite.Model = class {
         for (let operatorIndex = 0; operatorIndex < model.operatorCodesLength(); operatorIndex++) {
             const operatorCode = model.operatorCodes(operatorIndex);
             const builtinCode = operatorCode.builtinCode();
-            operatorCodeList.push(builtinCode === tflite.schema.BuiltinOperator.CUSTOM ?
+            operators.push(builtinCode === tflite.schema.BuiltinOperator.CUSTOM ?
                 { name: operatorCode.customCode(), custom: true } :
                 { name: builtinOperatorMap[builtinCode] });
         }
         const subgraphsLength = model.subgraphsLength();
         for (let subgraph = 0; subgraph < subgraphsLength; subgraph++) {
             const name = (subgraphsLength > 1) ? subgraph.toString() : '';
-            this._graphs.push(new tflite.Graph(metadata, model.subgraphs(subgraph), name, operatorCodeList, model));
+            this._graphs.push(new tflite.Graph(metadata, model.subgraphs(subgraph), name, operators, model));
         }
     }
 
@@ -102,7 +102,7 @@ tflite.Model = class {
 
 tflite.Graph = class {
 
-    constructor(metadata, graph, name, operatorCodeList, model) {
+    constructor(metadata, graph, name, operators, model) {
         this._name = graph.name() || name;
         this._nodes = [];
         this._inputs = [];
@@ -123,7 +123,7 @@ tflite.Graph = class {
         for (let j = 0; j < graph.operatorsLength(); j++) {
             const node = graph.operators(j);
             const opcodeIndex = node.opcodeIndex();
-            const operator = (opcodeIndex < operatorCodeList.length) ? operatorCodeList[opcodeIndex] : { name: '(' + opcodeIndex.toString() + ')' };
+            const operator = (opcodeIndex < operators.length) ? operators[opcodeIndex] : { name: '(' + opcodeIndex.toString() + ')' };
             this._nodes.push(new tflite.Node(metadata, node, operator, args));
         }
         for (let k = 0; k < graph.inputsLength(); k++) {
@@ -219,10 +219,16 @@ tflite.Node = class {
             }
             let optionsTypeName = this.operator + 'Options';
             switch (this.operator) {
-                case 'MaxPool2D':
                 case 'AveragePool2D':
+                case 'MaxPool2D':
                     optionsTypeName = 'Pool2DOptions';
                     break;
+                case 'Mean':
+                case 'ReduceMax':
+                case 'ReduceMin':
+                case 'Sum':
+                    optionsTypeName = 'ReducerOptions';
+                    break;
             }
             const optionsType = tflite.Node._getType(optionsTypeName);
             if (typeof optionsType === 'function') {
@@ -260,19 +266,14 @@ tflite.Node = class {
                             else {
                                 value = options[attributeName]();
                             }
-                            const attribute = new tflite.Attribute(this._metadata, this.operator, attributeName, value);
-                            if (attribute.name == 'fused_activation_function') {
-                                value = attribute.value;
-                                if (attribute.value != 'NONE') {
-                                    const activationFunctionMap = { 'RELU': 'Relu', 'RELU_N1_TO_1': "ReluN1To1", "RELU6": "Relu6", "TANH": "Tanh", "SIGN_BIT": "SignBit" };
-                                    if (activationFunctionMap[value]) {
-                                        value = activationFunctionMap[value];
-                                    }
-                                    this._chain = [];
-                                    this._chain.push(new tflite.Node(metadata, null, { name: value }, []));
+                            if (attributeName === 'fusedActivationFunction' && value !== 0) {
+                                const activationFunctionMap = { 1: 'Relu', 2: "ReluN1To1", 3: "Relu6", 4: "Tanh", 5: "SignBit" };
+                                if (activationFunctionMap[value]) {
+                                    value = activationFunctionMap[value];
                                 }
+                                this._chain = [ new tflite.Node(metadata, null, { name: value }, []) ];
                             }
-                            this._attributes.push(attribute);
+                            this._attributes.push(new tflite.Attribute(this._metadata, this.operator, attributeName, value));
                         }
                     }
                 }

+ 9 - 5
src/view.js

@@ -380,13 +380,13 @@ view.View = class {
                     }  
                 }
             }
-            return this.renderGraph(graph).then(() => {
+            return this.renderGraph(model, graph).then(() => {
                 this._model = model;
                 this._activeGraph = graph;
                 this.show('Graph');
                 return this._model;
             }).catch((error) => {
-                return this.renderGraph(this._activeGraph).then(() => {
+                return this.renderGraph(this._model, this._activeGraph).then(() => {
                     this.show('Graph');
                     throw error;
                 }).catch(() => {
@@ -396,7 +396,7 @@ view.View = class {
         });
     }
 
-    renderGraph(graph) {
+    renderGraph(model, graph) {
         try {
             if (!graph) {
                 return Promise.resolve();
@@ -471,8 +471,12 @@ view.View = class {
                         if (category) {
                             styles.push('node-item-operator-' + category.toLowerCase());
                         }
-                        const content = self.showNames && node.name ? node.name : node.operator.split('.').pop();
-                        const tooltip = self.showNames && node.name ? node.operator : node.name;
+                        const operator = node.operator;
+                        if (typeof operator !== 'string' || !operator.split) { // #416
+                            throw new ModelError("Unknown node operator '" + JSON.stringify(operator) + "' in '" + model.format + "'.");
+                        }
+                        const content = self.showNames && node.name ? node.name : operator.split('.').pop();
+                        const tooltip = self.showNames && node.name ? operator : node.name;
                         header.add(null, styles, content, tooltip, () => { 
                             self.showNodeProperties(node, null);
                         });

+ 7 - 0
test/models.json

@@ -5241,6 +5241,13 @@
     "format": "TensorFlow Lite v3",
     "link":   "https://github.com/lutzroeder/netron/issues/386"
   },
+  {
+    "type":   "tflite",
+    "target": "netron_issue_416.tflite",
+    "source": "https://github.com/lutzroeder/netron/files/4089319/netron_issue_416.zip[netron_issue_416.tflite]",
+    "format": "TensorFlow Lite v3",
+    "link":   "https://github.com/lutzroeder/netron/issues/416"
+  },
   {
     "type":   "tflite",
     "target": "pose_estimation_for_mobile.tflite",

+ 1 - 1
test/test.js

@@ -563,7 +563,7 @@ function render(model) {
         if (!currentView.showInitializers) {
             currentView.toggleInitializers();
         }
-        return currentView.renderGraph(model.graphs[0]);
+        return currentView.renderGraph(model, model.graphs[0]);
     }
     catch (error) {
         return Promise.reject(error);