Browse Source

CoreML categorization and updateActiveGraph changes

Lutz Roeder 8 years ago
parent
commit
f78928eb19
7 changed files with 188 additions and 78 deletions
  1. 1 0
      setup.py
  2. 83 34
      src/coreml-model.js
  3. 80 0
      src/coreml-operator.json
  4. 0 8
      src/onnx-model.js
  5. 0 13
      src/tf-model.js
  6. 0 13
      src/tflite-model.js
  7. 24 10
      src/view.js

+ 1 - 0
setup.py

@@ -75,6 +75,7 @@ setuptools.setup(
             'tf-model.js', 'tf.js', 'tf-operator.pb',
             'tflite-model.js', 'tflite.js', 'tflite-operator.json',
             'keras-model.js', 'keras-operator.json', 'hdf5.js',
+            'coreml-model.js', 'coreml-operator.json',
             'view-browser.html', 'view-browser.js',
             'view.js', 'view.css', 'view-render.css', 'view-render.js', 'view-template.js'
         ]

+ 83 - 34
src/coreml-model.js

@@ -25,7 +25,9 @@ class CoreMLModel {
         try {
             var decodedBuffer = coreml.Model.decode(buffer);
             var model = new CoreMLModel(decodedBuffer, identifier);
-            callback(null, model);
+            CoreMLOperatorMetadata.open(host, (err, metadata) => {
+                callback(null, model);
+            });
         }
         catch (err) {
             callback(err, null);
@@ -35,7 +37,6 @@ class CoreMLModel {
     constructor(model, identifier) {
         this._model = model;
         this._graphs = [ new CoreMLGraph(this._model, identifier) ];
-        this._activeGraph = this._graphs[0];
     }
 
     get properties() {
@@ -68,14 +69,6 @@ class CoreMLModel {
     get graphs() {
         return this._graphs;
     }
-
-    get activeGraph() {
-        return this._activeGraph;
-    }
-
-    updateActiveGraph(name) {
-        this._activeGraph = (name == this._graphs[0]._graph.name) ? this._graph : null;
-    }
 }
 
 class CoreMLGraph {
@@ -83,15 +76,8 @@ class CoreMLGraph {
     constructor(model, identifier)
     {
         this._model = model;
-        this._identifier = identifier;
-    }
 
-    get name() {
-        return this._identifier;
-    }
-
-    get inputs() {
-        return this._model.description.input.map((input) => {
+        this._inputs = this._model.description.input.map((input) => {
             return {
                 id: input.name,
                 name: input.name,
@@ -99,10 +85,8 @@ class CoreMLGraph {
                 type: CoreMLGraph.formatFeatureType(input.type) 
             };
         });
-    }
 
-    get outputs() {
-        return this._model.description.output.map((output) => {
+        this._outputs = this._model.description.output.map((output) => {
             return {
                 id: output.name,
                 name: output.name,
@@ -110,30 +94,50 @@ class CoreMLGraph {
                 type: CoreMLGraph.formatFeatureType(output.type) 
             };
         });
-    }
-
-    get nodes() {
 
+        this._nodes = [];
         if (this._model.neuralNetworkClassifier) {
-            var results = [];
             this._model.neuralNetworkClassifier.layers.forEach((layer) => {
                 var node = new CoreMLNode(layer);
-                results.push(node);
+                this._nodes.push(node);
             });
-            return results;
+            this._name = "Neural Network Classifier";
         }
-
-        if (this._model.neuralNetwork) {
-            var results = [];
+        else if (this._model.neuralNetwork) {
             this._model.neuralNetwork.layers.forEach((layer) => {
                 var node = new CoreMLNode(layer);
-                results.push(node);
+                this._nodes.push(node);
             });
-            return results;
+            this._name = "Neural Network";
+        }
+        else if (this._model.pipelineClassifier) {
+            debugger;
+            this._name = "Pipeline Classifier";
+        }
+        else if (this._model.glmClassifier) {
+            debugger;
+            this._name = "Generalized Linear Classifier";
         }
+        else {
+            debugger;
+            this._name = identifier;
+        }
+    }
 
-        debugger;
-        return [];
+    get name() {
+        return this._name;
+    }
+
+    get inputs() {
+        return this._inputs;
+    }
+
+    get outputs() {
+        return this._outputs;
+    }
+
+    get nodes() {
+        return this._nodes;
     }
 
     static formatFeatureType(type) {
@@ -213,6 +217,10 @@ class CoreMLNode {
         return this._layer.layer;
     }
 
+    get category() {
+        return CoreMLOperatorMetadata.operatorMetadata.getOperatorCategory(this.operator);
+    }
+    
     get name() {
         return this._layer.name;
     }
@@ -268,3 +276,44 @@ class CoreMLAttribute {
         return JSON.stringify(this._value);
     }
 }
+
+class CoreMLOperatorMetadata 
+{
+
+    static open(host, callback) {
+        if (CoreMLOperatorMetadata.operatorMetadata) {
+            callback(null, CoreMLOperatorMetadata.operatorMetadata);
+        }
+        else {
+            host.request('/coreml-operator.json', (err, data) => {
+                CoreMLOperatorMetadata.operatorMetadata = new CoreMLOperatorMetadata(data);
+                callback(null, CoreMLOperatorMetadata.operatorMetadata);
+            });
+        }    
+    }
+
+    constructor(data) {
+        this._map = {};
+        if (data) {
+            var items = JSON.parse(data);
+            if (items) {
+                items.forEach((item) => {
+                    if (item.name && item.schema)
+                    {
+                        var name = item.name;
+                        var schema = item.schema;
+                        this._map[name] = schema;
+                    }
+                });
+            }
+        }
+    }
+
+    getOperatorCategory(operator) {
+        var schema = this._map[operator];
+        if (schema && schema.category) {
+            return schema.category;
+        }
+        return null;
+    }
+}

+ 80 - 0
src/coreml-operator.json

@@ -0,0 +1,80 @@
+[
+  {
+    "name": "convolution",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "innerProduct",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "uniDirectionalLSTM",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "gru",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "activation",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "softmax",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "batchnorm",
+    "schema": {
+      "category": "Normalization"
+    }
+  },
+  {
+    "name": "lrn",
+    "schema": {
+      "category": "Normalization"
+    }
+  },
+  {
+    "name": "pooling",
+    "schema": {
+      "category": "Pool"
+    }
+  },
+  {
+    "name": "permute",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "flatten",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "reshape",
+    "schema": {
+      "category": "Shape"
+    }
+  },
+  {
+    "name": "concat",
+    "schema": {
+      "category": "Tensor"
+    }
+  }
+]

+ 0 - 8
src/onnx-model.js

@@ -114,14 +114,6 @@ class OnnxModel {
     get graphs() {
         return this._graphs;
     }
-
-    get activeGraph() {
-        return this._activeGraph;
-    }
-
-    updateActiveGraph(name) {
-        this._activeGraph = (name == this._graphs[0]._graph.name) ? this._graph : null;
-    }
 }
 
 class OnnxGraph {

+ 0 - 13
src/tf-model.js

@@ -85,19 +85,6 @@ class TensorFlowModel {
     get graphs() {
         return this._graphs;    
     }
-
-    get activeGraph() {
-        return this._activeGraph;
-    }
-
-    updateActiveGraph(name) {
-        this.graphs.forEach((graph) => {
-            if (name == graph.name) {
-                this._activeGraph = graph;
-                return;
-            }            
-        });
-    }
 }
 
 class TensorFlowGraph {

+ 0 - 13
src/tflite-model.js

@@ -73,19 +73,6 @@ class TensorFlowLiteModel {
     get graphs() {
         return this._graphs;
     }
-
-    get activeGraph() {
-        return this._activeGraph;
-    }
-
-    updateActiveGraph(name) {
-        this.graphs.forEach((graph) => {
-            if (name == graph.name) {
-                this._activeGraph = graph;
-                return;
-            }            
-        });
-    }
 } 
 
 class TensorFlowLiteGraph {

+ 24 - 10
src/view.js

@@ -131,16 +131,19 @@ class View {
                     setTimeout(() => {
                         this._graph = false;
                         try {
-                            this.updateGraph(model);
+                            var graph = model.graphs.length > 0 ? model.graphs[0] : null;
+                            this.updateGraph(model, graph);
                             this._model = model;
+                            this._activeGraph = graph;
                             callback(null);
                         }
                         catch (err) {
                             try {
-                                this.updateGraph(this._model);
+                                this.updateGraph(this._model, this._activeGraph);
                             }
                             catch (obj) {
                                 this._model = null;
+                                this._activeGraph = null;
                             }
                             callback(err);
                         }
@@ -159,17 +162,28 @@ class View {
     updateActiveGraph(name) {
         this._sidebar.close();
         if (this._model) {
-            this._model.updateActiveGraph(name);
-            this.show('spinner');
-            setTimeout(() => {
-                this.updateGraph(this._model);
-            }, 250);
+            var model = this._model;
+            var graph = model.graphs.filter(graph => graph.name).shift();
+            if (graph) {
+                this.show('spinner');
+                setTimeout(() => {
+                    try {
+                        this.updateGraph(model, graph);
+                        this._model = model;
+                        this._activeGraph = graph;
+                    }
+                    catch (obj) {
+                        this._model = null;
+                        this._activeGraph = null;
+                    }
+                }, 250);
+    
+            }
         }
     }
     
-    updateGraph(model) {
-    
-        var graph = model.activeGraph;
+    updateGraph(model, graph) {
+
         if (!graph) {
             this.show('graph');
             return;