Browse Source

Fold ONNX Constant nodes with single output into graph

Lutz Roeder 8 years ago
parent
commit
7326da8c99
8 changed files with 267 additions and 37 deletions
  1. 1 3
      package.json
  2. 1 0
      setup.py
  3. 1 1
      src/app.js
  4. 119 0
      src/tflite-operator.json
  5. 35 9
      src/view-onnx.js
  6. 30 17
      src/view-tf.js
  7. 67 4
      src/view-tflite.js
  8. 13 3
      src/view.js

+ 1 - 3
package.json

@@ -41,15 +41,13 @@
         "fileAssociations": [
             {
                 "ext": [
-                    "saved_model.pb",
-                    "saved_model.pbtxt"
+                    "saved_model.pb"
                 ],
                 "name": "TensorFlow Saved Model"
             },
             {
                 "ext": [
                     "pb",
-                    "pbtxt",
                     "onnx"
                 ],
                 "name": "ONNX Model"

+ 1 - 0
setup.py

@@ -25,6 +25,7 @@ package_data={
         'onnx-operator.json',
         'tf.js',
         'tflite.js',
+        'tflite-operator.json',
         'favicon.ico',
         'view-browser.html',
         'view-browser.js',

+ 1 - 1
src/app.js

@@ -29,7 +29,7 @@ if (quit) {
 
 function openFileDialog() {
     var showOpenDialogOptions = { 
-        properties: [ 'openFile'], 
+        properties: [ 'openFile' ], 
         filters: [
             { name: 'ONNX Model', extensions: [ 'onnx', 'pb' ] },
             { name: 'TensorFlow Saved Model', extensions: [ 'saved_model.pb' ] },

+ 119 - 0
src/tflite-operator.json

@@ -0,0 +1,119 @@
+[
+  {
+    "name": "Conv2D",
+    "schema": {
+      "inputs": [
+        { "name": "X", "typeStr": "T" },
+        { "name": "weights", "typeStr": "T" },
+        { "name": "bias", "typeStr": "T" }
+      ],
+      "outputs": [
+        { "name": "Y", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "DepthwiseConv2D",
+    "schema": {
+      "inputs": [
+        { "name": "X", "typeStr": "T" },
+        { "name": "weights", "typeStr": "T" },
+        { "name": "bias", "typeStr": "T" }
+      ],
+      "outputs": [
+        { "name": "Y", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "AveragePool2D",
+    "schema": {
+      "inputs": [
+        { "name": "X", "typeStr": "T" }
+      ],
+      "outputs": [
+        { "name": "Y", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "Softmax",
+    "schema": {
+      "inputs": [
+        { "name": "input", "typeStr": "T" }
+      ],
+      "outputs": [
+        { "name": "output", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "Reshape",
+    "schema": {
+      "outputs": [
+        { "name": "reshaped", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "MaxPool2D",
+    "schema": {
+      "inputs": [
+        { "name": "X", "typeStr": "T" }
+      ]
+    }
+  },
+  {
+    "name": "LSHProjection",
+    "schema": {
+      "inputs": [
+        { "name": "hash" },
+        { "name": "input" },
+        { "name": "weight" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
+  {
+    "name": "Normalize",
+    "schema": {
+      "inputs": [
+        { "name": "input" }
+      ],
+      "outputs": [
+        { "name": "output" }
+      ]
+    }
+  },
+  {
+    "name": "Predict",
+    "schema": {
+      "inputs": [
+        { "name": "hashes" },
+        { "name": "keys" },
+        { "name": "labels" },
+        { "name": "weights" }
+      ],
+      "outputs": [
+        { "name": "label" },
+        { "name": "weight" }
+      ]
+    }
+  },
+  {
+    "name": "HashtableLookup",
+    "schema": {
+      "inputs": [
+        { "name": "key" },
+        { "name": "keys" },
+        { "name": "values" }
+      ],
+      "outputs": [
+        { "name": "value" },
+        { "name": "hits" }
+      ]
+    }
+  }
+]

+ 35 - 9
src/view-onnx.js

@@ -5,13 +5,17 @@ var onnx = protobuf.roots.onnx.onnx;
 class OnnxModel {
 
     constructor(hostService) {
-        this.operatorMetadata = new OnnxOperatorMetadata(hostService);
+        this.hostService = hostService;
     }
 
     openBuffer(buffer, identifier) { 
         try {
             this.model = onnx.ModelProto.decode(buffer);
             this.activeGraph = this.model.graph;
+
+            if (!this.operatorMetadata) {
+                this.operatorMetadata = new OnnxOperatorMetadata(this.hostService);
+            }
         }
         catch (err) {
             return err;
@@ -108,23 +112,45 @@ class OnnxModel {
         graph.initializer.forEach((tensor) => {
             var result = this.formatTensor(tensor);
             result.id = tensor.name;
+            result.title = 'Initializer';
             results.push(result);
         });
-    /*    graph.node.forEach((node) => {
-            if (node.opType == 'Constant') {
+        graph.node.forEach((node) => {
+            if (node.opType == 'Constant' && node.output && node.output.length == 1) {
+                var result = null;
                 node.attribute.forEach((attribute) => {
-                    if (attribute.name == 'value') {
-                        result[node.output[0]] = attribute.value;
-                    }
+                    if (attribute.name == 'value' && attribute.t) {
+                        result = this.formatTensor(attribute.t);
+                    }                    
                 });
+                if (result) {
+                    result.id = node.output[0];
+                    if (!result.name) {
+                        result.name = result.id;
+                    }
+                    result.title = 'Constant';
+                    results.push(result);
+                }
             }
-        }); */
+        });
         return results;
     }
 
     getNodes(graph) {
-        return graph.node;
-        // return graph.node.filter(node => node.opType != 'Constant');
+        var results = [];
+        var initializerMap = {}
+        this.getGraphInitializers(graph).forEach((initializer) => {
+            initializerMap[initializer.id] = true;
+        });
+        graph.node.forEach((node) => {
+            if (node.opType == 'Constant' && node.output.length == 1 && initializerMap[node.output[0]]) {
+
+            }
+            else {
+                results.push(node);
+            }
+        });
+        return results;
     }
 
     getNodeOperator(node) {

+ 30 - 17
src/view-tf.js

@@ -12,8 +12,19 @@ class TensorFlowModel {
 
     openBuffer(buffer, identifier) { 
         try {
-            this.model = tensorflow.SavedModel.decode(buffer);
-            this.activeGraph = (this.model.metaGraphs.length > 0) ? this.model.metaGraphs[0] : null;
+            if (identifier == 'saved_model.pb') {
+                this.model = tensorflow.SavedModel.decode(buffer);
+                this.activeGraph = (this.model.metaGraphs.length > 0) ? this.model.metaGraphs[0] : null;
+            }
+            else {
+                var graphDef = tensorflow.GraphDef.decode(buffer);
+                var metaGraph = new tensorflow.MetaGraphDef();
+                metaGraph.graphDef = graphDef;
+                var savedModel = new tensorflow.SavedModel();
+                savedModel.metaGraphs.push(metaGraph);
+                this.model = savedModel;
+                this.activeGraph = metaGraph;
+            }
         }
         catch (err) {
             return err;
@@ -89,9 +100,9 @@ class TensorFlowModel {
     }
 
     getNodes(graph) {
-    // graph.graphDef.node.forEach(function (node) {
-    //     console.log(node.name + ' [' + (!node.input ? "" : node.input.map(s => s).join(',')) + ']');
-    // });
+        // graph.graphDef.node.forEach(function (node) {
+        //     console.log(node.name + ' [' + (!node.input ? "" : node.input.map(s => s).join(',')) + ']');
+        // });
         var result = [];
         graph.graphDef.node.forEach(function (node) {
             if (node.op != 'Const') {
@@ -193,19 +204,21 @@ class TensorFlowGraphMetadata {
     constructor(metaInfoDef) {
         var self = this;
         self.schemaMap = {};
-        metaInfoDef.strippedOpList.op.forEach(function (opDef) {
-            var schema = { inputs: [], outputs: [], attributes: [] };
-            opDef.inputArg.forEach(function (inputArg) {
-                schema.inputs.push({ name: inputArg.name, typeStr: inputArg.typeAttr });
-            });
-            opDef.outputArg.forEach(function (outputArg) {
-                schema.outputs.push({ name: outputArg.name, typeStr: outputArg.typeAttr });
-            });
-            opDef.attr.forEach(function (attr) {
-                schema.attributes.push({ name: attr.name, type: attr.type });
+        if (metaInfoDef && metaInfoDef.strippedOpList && metaInfoDef.strippedOpList.op) {
+            metaInfoDef.strippedOpList.op.forEach(function (opDef) {
+                var schema = { inputs: [], outputs: [], attributes: [] };
+                opDef.inputArg.forEach(function (inputArg) {
+                    schema.inputs.push({ name: inputArg.name, typeStr: inputArg.typeAttr });
+                });
+                opDef.outputArg.forEach(function (outputArg) {
+                    schema.outputs.push({ name: outputArg.name, typeStr: outputArg.typeAttr });
+                });
+                opDef.attr.forEach(function (attr) {
+                    schema.attributes.push({ name: attr.name, type: attr.type });
+                });
+                self.schemaMap[opDef.name] = schema;
             });
-            self.schemaMap[opDef.name] = schema;
-        });
+        }
     }
 
     getInputName(operator, index) {

+ 67 - 4
src/view-tflite.js

@@ -4,9 +4,8 @@
 
 class TensorFlowLiteModel {
     
-    
     constructor(hostService) {
-        this.operatorService = null;
+        this.operatorMetadata = new TensorFlowLiteOperatorMetadata(hostService);
     }
 
     openBuffer(buffer, identifier) { 
@@ -174,9 +173,10 @@ class TensorFlowLiteModel {
         for (var i = 0; i < node.inputsLength(); i++) {
             var tensorIndex = node.inputs(i);
             var tensor = graph.tensors(tensorIndex);
+            var operator = this.getNodeOperator(node);
             result.push({
                 id: tensorIndex.toString(),
-                name: '(' + i.toString() + ')',
+                name: this.operatorMetadata.getInputName(operator, i),
                 type: this.formatTensorType(tensor)
             });
         }
@@ -188,9 +188,10 @@ class TensorFlowLiteModel {
         for (var i = 0; i < node.outputsLength(); i++) {
             var tensorIndex = node.outputs(i);
             var tensor = graph.tensors(tensorIndex);
+            var operator = this.getNodeOperator(node);
             result.push({
                 id: tensorIndex.toString(),
-                name: '(' + i.toString() + ')',
+                name: this.operatorMetadata.getOutputName(operator, i),
                 type: this.formatTensorType(tensor)
             });
         }
@@ -453,4 +454,66 @@ class TensorFlowLiteTensorFormatter {
         }
         return (s ? -1 : 1) * Math.pow(2, e-15) * (1 + (f / Math.pow(2, 10)));
     }
+}
+
+class TensorFlowLiteOperatorMetadata {
+    constructor() {
+        this.map = {};
+        hostService.request('/tflite-operator.json', (err, data) => {
+            if (err != null) {
+                // TODO error
+            }
+            else {
+                var items = JSON.parse(data);
+                if (items) {
+                    items.forEach((item) => {
+                        if (item.name && item.schema)
+                        {
+                            var name = item.name;
+                            var schema = item.schema;
+                            this.map[name] = schema;
+                        }
+                    });
+                }
+            }
+        });
+    }
+
+    getInputName(operator, index) {
+        var schema = this.map[operator];
+        if (schema) {
+            var inputs = schema.inputs;
+            if (inputs && index < inputs.length) {
+                var input = inputs[index];
+                if (input) {
+                    if (!input.option || input.option != 'variadic') {
+                        var name = input.name;
+                        if (name) {
+                            return name;
+                        }
+                    }
+                } 
+            }
+        }
+        return "(" + index.toString() + ")";
+    }
+
+    getOutputName(operator, index) {
+        var schema = this.map[operator];
+        if (schema) {
+            var outputs = schema.outputs;
+            if (outputs && index < outputs.length) {
+                var output = outputs[index];
+                if (output) {
+                    if (!output.option || output.option != 'variadic') {
+                        var name = output.name;
+                        if (name) {
+                            return name;
+                        }
+                    }
+                } 
+            }
+        }
+        return "(" + index.toString() + ")";
+    }
 }

+ 13 - 3
src/view.js

@@ -311,7 +311,7 @@ function showTensor(model, tensor) {
         var view = { 'items': [ tensor ] };
         var template = Handlebars.compile(itemsTemplate, 'utf-8');
         var data = template(view);
-        sidebar.open(data, 'Tensor');
+        sidebar.open(data, tensor.title ? tensor.title : 'Tensor');
     }
 }
 
@@ -397,7 +397,9 @@ ModelService.prototype.openBuffer = function(buffer, identifier, callback) {
     var model = null;
     var err = null;
 
-    if (identifier != null && identifier.split('.').pop() == 'tflite')
+    var extension = identifier.split('.').pop();
+
+    if (identifier != null && extension == 'tflite')
     {
         model = new TensorFlowLiteModel(hostService); 
         err = model.openBuffer(buffer, identifier);
@@ -406,9 +408,17 @@ ModelService.prototype.openBuffer = function(buffer, identifier, callback) {
         model = new TensorFlowModel(hostService);
         err = model.openBuffer(buffer, identifier);
     }
-    else {
+    else if (extension == 'onnx') {
+        model = new OnnxModel(hostService);
+        err = model.openBuffer(buffer, identifier);
+    }
+    else if (extension == 'pb') {
         model = new OnnxModel(hostService);
         err = model.openBuffer(buffer, identifier);
+        if (err) {
+            model = new TensorFlowModel(hostService);
+            err = model.openBuffer(buffer, identifier);
+        }
     }
 
     if (err) {