浏览代码

MXNet symbol.json support

Lutz Roeder 8 年之前
父节点
当前提交
04e5563839
共有 12 个文件被更改,包括 658 次插入20 次删除
  1. 1 1
      README.md
  2. 3 0
      electron-builder.yml
  3. 1 0
      setup.py
  4. 1 0
      src/app.js
  5. 2 2
      src/caffe-model.js
  6. 1 5
      src/keras-model.js
  7. 455 0
      src/mxnet-model.js
  8. 156 0
      src/mxnet-operator.json
  9. 1 0
      src/view-browser.html
  10. 2 1
      src/view-electron.html
  11. 16 11
      src/view.js
  12. 19 0
      tools/mxnet-generate

+ 1 - 1
README.md

@@ -5,7 +5,7 @@ Netron is a viewer for neural network and machine learning models.
 
 Netron supports **[ONNX](http://onnx.ai)** (`.onnx`, `.pb`), **Keras** (`.h5`, `.keras`), **CoreML** (`.mlmodel`) and **TensorFlow Lite** (`.tflite`). 
 
-Netron has experimental support for **Caffe** (`.caffemodel`) and **TensorFlow** (`.pb`, `.meta`).
+Netron has experimental support for **Caffe** (`.caffemodel`), **MXNet** (`-symbol.json`) and **TensorFlow** (`.pb`, `.meta`).
 
 <p align='center'><a href='https://www.lutzroeder.com/ai'><img src='media/screenshot.png' width='800'></a></p>
 

+ 3 - 0
electron-builder.yml

@@ -31,6 +31,9 @@ fileAssociations:
   - name: "Caffe Model"
     ext:
     - caffemodel
+  - name: "MXNet Model"
+    ext:
+    - json
 publish:
   - provider: github
     releaseType: release

+ 1 - 0
setup.py

@@ -77,6 +77,7 @@ setuptools.setup(
             'keras-model.js', 'keras-operator.json', 'hdf5.js',
             'coreml-model.js', 'coreml-operator.json', 'coreml.js',
             'caffe-model.js', 'caffe-operator.json', 'caffe.js',
+            'mxnet-model.js', 'mxnet-operator.json',
             'view-browser.html', 'view-browser.js',
             'view.js', 'view.css', 'view-render.css', 'view-render.js', 'view-template.js'
         ]

+ 1 - 0
src/app.js

@@ -111,6 +111,7 @@ class Application {
                 { name: 'Keras Model', extension: [ 'json', 'keras', 'h5' ] },
                 { name: 'CoreML Model', extension: [ 'mlmodel' ] },
                 { name: 'Caffe Model', extension: [ 'caffemodel' ] },
+                { name: 'MXNet Model', extension: [ 'json' ] },
                 { name: 'TensorFlow Graph', extensions: [ 'pb', 'meta' ] },
                 { name: 'TensorFlow Saved Model', extensions: [ 'saved_model.pb' ] },
                 { name: 'TensorFlow Lite Model', extensions: [ 'tflite' ] }

+ 2 - 2
src/caffe-model.js

@@ -556,7 +556,7 @@ class CaffeOperatorMetadata
                     return false;
                 }
                 while (length--) {
-                    if (!KerasOperatorMetadata.isEquivalent(a[length], b[length])) {
+                    if (!CaffeOperatorMetadata.isEquivalent(a[length], b[length])) {
                         return false;
                     }
                 }
@@ -570,7 +570,7 @@ class CaffeOperatorMetadata
         } 
         while (size--) {
             var key = keys[size];
-            if (!(b.hasOwnProperty(key) && KerasOperatorMetadata.isEquivalent(a[key], b[key]))) {
+            if (!(b.hasOwnProperty(key) && CaffeOperatorMetadata.isEquivalent(a[key], b[key]))) {
                 return false;
             }
         }

+ 1 - 5
src/keras-model.js

@@ -32,11 +32,7 @@ class KerasModel {
                 }
             }
             else if (extension == 'json') {
-                if (!window.TextDecoder) {
-                    throw new KerasError('TextDecoder not avaialble.');
-                }
-
-                var decoder = new TextDecoder('utf-8');
+                var decoder = new window.TextDecoder('utf-8');
                 json = decoder.decode(buffer);
             }
 

+ 455 - 0
src/mxnet-model.js

@@ -0,0 +1,455 @@
+/*jshint esversion: 6 */
+
+// Experimental
+
+class MXNetModel {
+
+    static open(buffer, identifier, host, callback) { 
+        MXNetModel.create(buffer, identifier, host, (err, model) => {
+            callback(err, model);
+        });
+    }
+
+    static create(buffer, identifier, host, callback) {
+        try {
+            var decoder = new TextDecoder('utf-8');
+            var json = decoder.decode(buffer);
+
+            var model = new MXNetModel(json);
+            MXNetOperatorMetadata.open(host, (err, metadata) => {
+                callback(null, model);
+            });
+        }
+        catch (err) {
+            callback(err, null);
+        }
+    }
+
+    constructor(json) {
+        var model = JSON.parse(json);
+        if (!model) {
+            throw new MXNetError('JSON file does not contain MXNet data.');
+        }
+        if (!model.hasOwnProperty('nodes')) {
+            throw new MXNetError('JSON file does not contain an MXNet \'nodes\' property.');
+        }
+        if (!model.hasOwnProperty('arg_nodes')) {
+            throw new MXNetError('JSON file does not contain an MXNet \'arg_nodes\' property.');
+        }
+        if (!model.hasOwnProperty('heads')) {
+            throw new MXNetError('JSON file does not contain an MXNet \'heads\' property.');
+        }
+
+        if (model.attrs && model.attrs.mxnet_version && model.attrs.mxnet_version.length == 2 && model.attrs.mxnet_version[0] == 'int') {
+            var version = model.attrs.mxnet_version[1];
+            var revision = version % 100;
+            var minor = Math.floor(version / 100) % 100;
+            var major = Math.floor(version / 10000) % 100;
+            this._version = major.toString() + '.' + minor.toString() + '.' + revision.toString(); 
+        }
+
+        this._graphs = [ new MXNetGraph(model) ];
+    }
+
+    get properties() {
+        var results = [];
+        results.push({ name: 'Format', value: 'MXNet' + (this._version ? (' v' + this._version) : '') });
+        return results;
+    }
+
+    get graphs() {
+        return this._graphs;
+    }
+
+}
+
+class MXNetGraph {
+
+    constructor(json)
+    {
+        var nodes = json.nodes;
+
+        this._nodes = [];
+        json.nodes.forEach((node) => {
+            node.outputs = [];
+        });
+
+        nodes.forEach((node) => {
+            node.inputs = node.inputs.map((input) => {
+                return MXNetGraph.updateOutput(nodes, input);
+            });
+        });
+
+        var argumentMap = {};
+        json.arg_nodes.forEach((index) => {
+            argumentMap[index] = (index < nodes.length) ? nodes[index] : null;
+        });
+
+        this._outputs = [];
+        var headMap = {};
+        json.heads.forEach((head, index) => {
+            var id = MXNetGraph.updateOutput(nodes, head);
+            var name = 'output' + ((index == 0) ? '' : (index + 1).toString());
+            this._outputs.push({ id: id, name: name });
+        });
+
+        nodes.forEach((node, index) => {
+            if (!argumentMap[index]) {
+                this._nodes.push(new MXNetNode(node, argumentMap));
+            }
+        });
+
+        this._inputs = [];
+        Object.keys(argumentMap).forEach((key) => {
+            var argument = argumentMap[key];
+            if ((!argument.inputs || argument.inputs.length == 0) &&
+                (argument.outputs && argument.outputs.length == 1)) {
+                this._inputs.push( { id: argument.outputs[0], name: argument.name });
+            }
+        });
+    }
+
+    get name() {
+        return '';
+    }
+
+    get inputs() {
+        return this._inputs.map((input) => {
+            return { 
+                name: input.name,
+                id: '[' + input.id.join(',') + ']' 
+            };
+        });
+    }
+
+    get outputs() {
+        return this._outputs.map((output) => {
+            return { 
+                name: output.name,
+                id: '[' + output.id.join(',') + ']' 
+            };
+        });
+    }
+
+    get nodes() {
+        return this._nodes;
+    }
+
+    static updateOutput(nodes, input) {
+        var sourceNodeIndex = input[0];
+        var sourceNode = nodes[sourceNodeIndex];
+        var sourceOutputIndex = input[1];
+        while (sourceOutputIndex >= sourceNode.outputs.length) {
+            sourceNode.outputs.push([ sourceNodeIndex, sourceNode.outputs.length ]);
+        }
+        return [ sourceNodeIndex, sourceOutputIndex ];
+    }
+}
+
+class MXNetNode {
+
+    constructor(json, argumentMap) {
+        this._operator = json.op;
+        this._name = json.name;
+        this._inputs = json.inputs;
+        this._outputs = json.outputs;
+        this._attributes = [];
+        var attrs = json.attrs;
+        if (!attrs) {
+            attrs = json.attr;
+        }
+        if (!attrs) {
+            attrs = json.param;
+        }
+        if (attrs) {
+            Object.keys(attrs).forEach((key) => {
+                var value = attrs[key];
+                this._attributes.push(new MXNetAttribute(this, key, value));
+            });
+        }
+        this._initializers = {};
+        this._inputs.forEach((input) => {
+            var argumentNodeIndex = input[0];
+            var argument = argumentMap[argumentNodeIndex];
+            if (argument) {
+                if ((!argument.inputs || argument.inputs.length == 0) &&
+                    (argument.outputs && argument.outputs.length == 1)) {
+                    var prefix = this._name + '_';
+                    if (prefix.endsWith('_fwd_')) {
+                        prefix = prefix.slice(0, -4);
+                    }
+                    if (argument.name && argument.name.startsWith(prefix)) {
+                        var id = '[' + input.join(',') + ']';
+                        this._initializers[id] = new MXNetTensor(argument);
+                        delete argumentMap[argumentNodeIndex];
+                    }
+                }
+            }
+        });
+    }
+
+    get operator() {
+        return this._operator;
+    }
+
+    get category() {
+        return MXNetOperatorMetadata.operatorMetadata.getOperatorCategory(this._operator);
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get inputs() {
+        var inputs = this._inputs.map((inputs) => {
+            return '[' + inputs.join(',') + ']'; 
+        });        
+        var results = MXNetOperatorMetadata.operatorMetadata.getInputs(this._operator, inputs);
+        results.forEach((input) => {
+            input.connections.forEach((connection) => {
+                var initializer = this._initializers[connection.id];
+                if (initializer) {
+                    connection.initializer = initializer;
+                }
+            });
+        });
+        return results;
+    }
+
+    get outputs() {
+        var outputs = this._outputs.map((output) => {
+            return '[' + output.join(',') + ']'; 
+        });
+        return MXNetOperatorMetadata.operatorMetadata.getOutputs(this._type, outputs);
+    }
+
+    get attributes() {
+        return this._attributes;
+    }
+}
+
+class MXNetAttribute {
+
+    constructor(owner, name, value) {
+        this._owner = owner;
+        this._name = name;
+        this._value = value;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get value() {
+        return this._value;
+    }
+
+    get hidden() {
+        return MXNetOperatorMetadata.operatorMetadata.getAttributeHidden(this._owner.operator, this._name, this._value);
+    }
+}
+
+class MXNetTensor {
+    
+    constructor(json) {
+        this._json = json;
+    }
+
+    get name() {
+        return this._json.name;
+    }
+}
+
+class MXNetOperatorMetadata {
+
+    static open(host, callback) {
+        if (MXNetOperatorMetadata.operatorMetadata) {
+            callback(null, MXNetOperatorMetadata.operatorMetadata);
+        }
+        else {
+            host.request('/mxnet-operator.json', (err, data) => {
+                MXNetOperatorMetadata.operatorMetadata = new MXNetOperatorMetadata(data);
+                callback(null, MXNetOperatorMetadata.operatorMetadata);
+            });
+        }    
+    }
+
+    constructor(data) {
+        this._map = {};
+        if (data) {
+            var items = JSON.parse(data);
+            if (items) {
+                items.forEach((item) => {
+                    if (item.name && item.schema)
+                    {
+                        var name = item.name;
+                        var schema = item.schema;
+                        this._map[name] = schema;
+                    }
+                });
+            }
+        }
+    }
+
+    getOperatorCategory(operator) {
+        var schema = this._map[operator];
+        if (schema && schema.category) {
+            return schema.category;
+        }
+        return null;
+    }
+
+    getInputs(type, inputs) {
+        var results = [];
+        var index = 0;
+        var schema = this._map[type];
+        if (schema && schema.inputs) {
+            schema.inputs.forEach((inputDef) => {
+                if (index < inputs.length || inputDef.option != 'optional') {
+                    var input = {};
+                    input.name = inputDef.name;
+                    input.type = inputDef.type;
+                    var count = (inputDef.option == 'variadic') ? (inputs.length - index) : 1;
+                    input.connections = [];
+                    inputs.slice(index, index + count).forEach((id) => {
+                        if (id != '' || inputDef.option != 'optional') {
+                            input.connections.push({ id: id});
+                        }
+                    });
+                    index += count;
+                    results.push(input);
+                }
+            });
+        }
+        else {
+            inputs.slice(index).forEach((input) => {
+                var name = (index == 0) ? 'input' : ('(' + index.toString() + ')');
+                results.push({
+                    name: name,
+                    connections: [ { id: input } ]
+                });
+                index++;
+            });
+
+        }
+        return results;
+    }
+
+    getOutputs(type, outputs) {
+        var results = [];
+        var index = 0;
+        var schema = this._map[type];
+        if (schema && schema.outputs) {
+            schema.outputs.forEach((outputDef) => {
+                if (index < outputs.length || outputDef.option != 'optional') {
+                    var output = {};
+                    output.name = outputDef.name;
+                    var count = (outputDef.option == 'variadic') ? (outputs.length - index) : 1;
+                    output.connections = outputs.slice(index, index + count).map((id) => {
+                        return { id: id };
+                    });
+                    index += count;
+                    results.push(output);
+                }
+            });
+        }
+        else {
+            outputs.slice(index).forEach((output) => {
+                var name = (index == 0) ? 'output' : ('(' + index.toString() + ')');
+                results.push({
+                    name: name,
+                    connections: [ { id: output } ]
+                });
+                index++;
+            });
+
+        }
+        return results;
+    }
+
+    getAttributeHidden(operator, name, value) {
+        var schema = this._map[operator];
+        if (schema && schema.attributes && schema.attributes.length > 0) {
+            if (!schema.attributesMap) {
+                schema.attributesMap = {};
+                schema.attributes.forEach((attribute) => {
+                    schema.attributesMap[attribute.name] = attribute;
+                });
+            }
+            var attribute = schema.attributesMap[name];
+            if (attribute) {
+                if (attribute.hasOwnProperty('hidden')) {
+                    return attribute.hidden;
+                }
+                if (attribute.hasOwnProperty('default')) {
+                    return MXNetOperatorMetadata.isEquivalent(attribute.default, value);
+                }
+            }
+        }
+        return false;
+    }
+
+    static isEquivalent(a, b) {
+        if (a === b) {
+            return a !== 0 || 1 / a === 1 / b;
+        }
+        if (a == null || b == null) {
+            return false;
+        }
+        if (a !== a) {
+            return b !== b;
+        }
+        var type = typeof a;
+        if (type !== 'function' && type !== 'object' && typeof b != 'object') {
+            return false;
+        }
+        var className = toString.call(a);
+        if (className !== toString.call(b)) {
+            return false;
+        }
+        switch (className) {
+            case '[object RegExp]':
+            case '[object String]':
+                return '' + a === '' + b;
+            case '[object Number]':
+                if (+a !== +a) {
+                    return +b !== +b;
+                }
+                return +a === 0 ? 1 / +a === 1 / b : +a === +b;
+            case '[object Date]':
+            case '[object Boolean]':
+                return +a === +b;
+            case '[object Array]':
+                var length = a.length;
+                if (length !== b.length) {
+                    return false;
+                }
+                while (length--) {
+                    if (!KerasOperatorMetadata.isEquivalent(a[length], b[length])) {
+                        return false;
+                    }
+                }
+                return true;
+        }
+
+        var keys = Object.keys(a);
+        var size = keys.length;
+        if (Object.keys(b).length != size) {
+            return false;
+        } 
+        while (size--) {
+            var key = keys[size];
+            if (!(b.hasOwnProperty(key) && KerasOperatorMetadata.isEquivalent(a[key], b[key]))) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
+
+class MXNetError extends Error {
+    constructor(message) {
+        super(message);
+        this.name = 'MXNet Error';
+    }
+}

+ 156 - 0
src/mxnet-operator.json

@@ -0,0 +1,156 @@
+[
+  {
+    "name": "Convolution",
+    "schema": {
+      "category": "Layer",
+      "inputs": [
+        { "name": "input" },
+        { "name": "weight" },
+        { "name": "bias" }
+      ],
+      "attributes": [
+        { "name": "no_bias", "hidden": true },
+        { "name": "cudnn_off", "default": "False" },
+        { "name": "cudnn_tune", "default": "off" },
+        { "name": "num_group", "default": "1" },
+        { "name": "workspace", "default": "1024" }
+      ]
+    }
+  },
+  {
+    "name": "Deconvolution",
+    "schema": {
+      "category": "Layer",
+      "inputs": [
+        { "name": "input" },
+        { "name": "weight" },
+        { "name": "bias" }
+      ],
+      "attributes": [
+        { "name": "no_bias", "hidden": true },
+        { "name": "num_group", "default": "1" },
+        { "name": "workspace", "default": "1024" }
+      ]
+    }
+  },
+  {
+    "name": "FullyConnected",
+    "schema": {
+      "category": "Layer",
+      "inputs": [
+        { "name": "input" },
+        { "name": "weight" },
+        { "name": "bias" }
+      ],
+      "attributes": [
+        { "name": "no_bias", "hidden": true }
+      ]
+    }
+  },
+  {
+    "name": "Dropout",
+    "schema": {
+      "category": "Dropout"
+    }
+  },
+  {
+    "name": "LRN",
+    "schema": {
+      "category": "Normalization"
+    }
+  },
+  {
+    "name": "SoftmaxOutput",
+    "schema": {
+      "category": "Activation",
+      "inputs": [
+        { "name": "input" },
+        { "name": "label" }
+      ]
+    }
+  },
+  {
+    "name": "SoftmaxActivation",
+    "schema": {
+      "category": "Activation",
+      "inputs": [
+        { "name": "input" }
+      ]
+    }
+  },
+  {
+    "name": "Activation",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "Pooling",
+    "schema": {
+      "category": "Pool"
+    }
+  },
+  {
+    "name": "Flatten",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "Concat",
+    "schema": {
+      "category": "Tensor",
+      "inputs": [
+        { "name": "inputs", "option": "variadic" }
+      ]
+    }
+  },
+  {
+    "name": "_Plus",
+    "schema": {
+      "inputs": [
+        { "name": "inputs", "option": "variadic" }
+      ]
+    }
+  },
+  {
+    "name": "elemwise_add",
+    "schema": {
+      "inputs": [
+        { "name": "inputs", "option": "variadic" }
+      ]
+    }
+  },  
+  {
+    "name": "BatchNorm",
+    "schema": {
+      "category": "Normalization",
+      "inputs": [
+        { "name": "input" },
+        { "name": "gamma" },
+        { "name": "beta" }
+      ]
+    }
+  },
+  {
+    "name": "CuDNNBatchNorm",
+    "schema": {
+      "category": "Normalization",
+      "inputs": [
+        { "name": "input" },
+        { "name": "gamma" },
+        { "name": "beta" }
+      ]
+    }
+  },
+  {
+    "name": "ElementWiseSum",
+    "schema": {
+      "category": "Normalization",
+      "inputs": [
+        { "name": "inputs", "option": "variadic" }
+      ]
+    }
+  }
+
+]

+ 1 - 0
src/view-browser.html

@@ -45,6 +45,7 @@
 <script type='text/javascript' src='keras-model.js'></script>
 <script type='text/javascript' src='coreml-model.js'></script>
 <script type='text/javascript' src='caffe-model.js'></script>
+<script type='text/javascript' src='mxnet-model.js'></script>
 <script type='text/javascript' src='view-template.js'></script>
 <script type='text/javascript' src='view-browser.js'></script>
 <script type='text/javascript' src='view-render.js'></script>

+ 2 - 1
src/view-electron.html

@@ -30,11 +30,12 @@
 <script type='text/javascript' src='../node_modules/handlebars/dist/handlebars.min.js'></script>
 <script type='text/javascript' src='../node_modules/marked/marked.min.js'></script>
 <script type='text/javascript' src='onnx-model.js'></script>
-<script type='text/javascript' src='caffe-model.js'></script>
 <script type='text/javascript' src='tf-model.js'></script>
 <script type='text/javascript' src='tflite-model.js'></script>
 <script type='text/javascript' src='keras-model.js'></script>
 <script type='text/javascript' src='coreml-model.js'></script>
+<script type='text/javascript' src='caffe-model.js'></script>
+<script type='text/javascript' src='mxnet-model.js'></script>
 <script type='text/javascript' src='view-template.js'></script>
 <script type='text/javascript' src='view-electron.js'></script>
 <script type='text/javascript' src='view-render.js'></script>

+ 16 - 11
src/view.js

@@ -83,21 +83,11 @@ class View {
                 callback(err, model);
            });
         }
-        else if (identifier == 'saved_model.pb' || extension == 'meta') {
-            TensorFlowModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
         else if (extension == 'onnx') {
             OnnxModel.open(buffer, identifier, this._host, (err, model) => {
                 callback(err, model);
             });
         }
-        else if (extension == 'json' || extension == 'keras' || extension == 'h5') {
-            KerasModel.open(buffer, identifier, this._host, (err, model) => {
-                callback(err, model);
-            });
-        }
         else if (extension == 'mlmodel') {
             CoreMLModel.open(buffer, identifier, this._host, (err, model) => {
                 callback(err, model);
@@ -108,10 +98,25 @@ class View {
                 callback(err, model);
             });
         }
+        else if (identifier.endsWith('-symbol.json')) {
+            MXNetModel.open(buffer, identifier, this._host, (err, model) => {
+                callback(err, model);
+            });
+        }
+        else if (extension == 'keras' || extension == 'h5' || extension == 'json') {
+            KerasModel.open(buffer, identifier, this._host, (err, model) => {
+                callback(err, model);
+            });
+        }
+        else if (identifier == 'saved_model.pb' || extension == 'meta') {
+            TensorFlowModel.open(buffer, identifier, this._host, (err, model) => {
+                callback(err, model);
+            });
+        }
         else if (extension == 'pb') {
             OnnxModel.open(buffer, identifier, this._host, (err, model) => {
                 if (!err) {
-                    callback(err, model);    
+                    callback(err, model);
                 }
                 else {
                     TensorFlowModel.open(buffer, identifier, this._host, (err, model) => {

+ 19 - 0
tools/mxnet-generate

@@ -0,0 +1,19 @@
+#!/bin/bash
+
+mkdir -p ../third_party
+
+repository=https://github.com/apache/incubator-mxnet.git
+
+if [ -d "../third_party/incubator-mxnet" ]; then
+    pushd "../third_party/${identifier}" > /dev/null
+    echo "Fetch ${repository}..."
+    git fetch -p
+    echo "Reset ${repository}..."
+    git reset --hard origin/master
+    popd > /dev/null
+else
+    echo "Clone ${repository}..."
+    pushd "../third_party" > /dev/null
+    git clone --recursive ${repository}
+    popd > /dev/null
+fi