Bläddra i källkod

CoreML pipeline support (#86)

Lutz Roeder 8 år sedan
förälder
incheckning
a012bc9400
1 ändrade filer med 88 tillägg och 92 borttagningar
  1. 88 92
      src/coreml-model.js

+ 88 - 92
src/coreml-model.js

@@ -96,8 +96,59 @@ class CoreMLGraph {
         });
 
         this._nodes = [];
+        this._type = this.loadModel(model);
+    }
+
+    get name() {
+        return '';
+    }
+
+    get type() {
+        return this._type;
+    }
+
+    get inputs() {
+        return this._inputs;
+    }
+
+    get outputs() {
+        return this._outputs;
+    }
+
+    get nodes() {
+        return this._nodes;
+    }
+
+    updateInput(name, newName) {
+        this._nodes.forEach((node) => {
+            node._inputs = node._inputs.map((input) => (input != name) ? input : newName);
+        });
+        return newName;
+    }
+
+    updateOutput(name, newName) {
+        this._nodes.forEach((node) => {
+            node._outputs = node._outputs.map((output) => (output != name) ? output : newName);
+        });
+        return newName;
+    }
+
+    updateClassifierOutput(classifier) {
+        var labelProbabilityLayerName = classifier.labelProbabilityLayerName;
+        if (!labelProbabilityLayerName && this._nodes.length > 0) {
+            labelProbabilityLayerName = this._nodes.slice(-1).pop()._outputs[0];
+        }
+        var predictedFeatureName = this._description.predictedFeatureName;
+        var predictedProbabilitiesName = this._description.predictedProbabilitiesName;
+        if (predictedFeatureName && predictedProbabilitiesName && labelProbabilityLayerName && classifier.ClassLabels) {
+            var labelProbabilityInput = this.updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
+            var operator = classifier.ClassLabels;
+            this._nodes.push(new CoreMLNode(operator, null, classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ]));
+        }
+    }
+
+    loadModel(model) {
         if (model.neuralNetworkClassifier) {
-            this._type = "Neural Network Classifier";
             var neuralNetworkClassifier = model.neuralNetworkClassifier;
             neuralNetworkClassifier.layers.forEach((layer) => {
                 var operator = layer.layer;
@@ -121,125 +172,70 @@ class CoreMLGraph {
                     this._nodes.push(node);
                 });
             }
+            return 'Neural Network Classifier';
         }
         else if (model.neuralNetwork) {
-            this._type = "Neural Network";
             model.neuralNetwork.layers.forEach((layer) => {
                 var operator = layer.layer;
                 this._nodes.push(new CoreMLNode(operator, layer.name, layer[operator], layer.input, layer.output));
             });
+            return 'Neural Network';
         }
-        else if (model.pipelineClassifier) {
-            this._type = "Pipeline Classifier";
-            this._nodes.push(new CoreMLNode('pipelineClassifier', null, model.pipelineClassifier, 
-                this._description.input.map((input) => input.name), 
-                this._description.output.map((output) => output.name)));
-            this.updateClassifierOutput(model.pipelineClassifier);
-            /*
-            model.pipelineClassifier.pipeline.models.forEach((subModel, index) => {
-                var buffer = coreml.Model.encode(subModel).finish();
-                require('fs').writeFileSync(require('os').homedir + '/' + identifier + '_' + index.toString(), buffer);
-                console.log();
+        else if (model.neuralNetworkRegressor) {
+            model.neuralNetworkRegressor.layers.forEach((layer) => {
+                var operator = layer.layer;
+                this._nodes.push(new CoreMLNode(operator, layer.name, layer[operator], layer.input, layer.output));
             });
-            */
-            debugger;
+            return 'Neural Network Regressor';
         }
         else if (model.pipeline) {
-            this._type = "Pipeline";
-            this._nodes.push(new CoreMLNode('pipeline', null, model.pipeline, 
-                this._description.input.map((input) => input.name), 
-                this._description.output.map((output) => output.name)));
-            /*
-            model.pipeline.models.forEach((subModel, index) => {
-                var buffer = coreml.Model.encode(subModel).finish();
-                require('fs').writeFileSync(require('os').homedir + '/' + identifier + '_' + index.toString(), buffer);
+            model.pipeline.models.forEach((subModel) => {
+                this.loadModel(subModel);
             });
-            */
-            debugger;
+            return 'Pipeline';
+        }
+        else if (model.pipelineClassifier) {
+            model.pipelineClassifier.pipeline.models.forEach((subModel) => {
+                this.loadModel(subModel);
+            });
+            return 'Pipeline Classifier';
+        }
+        else if (model.pipelineRegressor) {
+            model.pipelineRegressor.pipeline.models.forEach((subModel) => {
+                this.loadModel(subModel);
+            });
+            return 'Pipeline Regressor';
         }
         else if (model.glmClassifier) {
-            this._type = "Generalized Linear Classifier";
             this._nodes.push(new CoreMLNode('glmClassifier', null, 
                 { classEncoding: model.glmClassifier.classEncoding, 
                   offset: model.glmClassifier.offset, 
                   weights: model.glmClassifier.weights }, 
-                [ this._description.input[0].name ],
-                [ this._description.predictedProbabilitiesName ]));
+                [ model.description.input[0].name ],
+                [ model.description.predictedProbabilitiesName ]));
             this.updateClassifierOutput(model.glmClassifier);
+            return 'Generalized Linear Classifier';
         }
         else if (model.dictVectorizer) {
-            this._type = "Dictionary Vectorizer";
             this._nodes.push(new CoreMLNode('dictVectorizer', null, model.dictVectorizer,
-                [ this._description.input[0].name ],
-                [ this._description.output[0].name ]));
-            debugger;
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'Dictionary Vectorizer';
         }
         else if (model.featureVectorizer) {
-            this._type = "Feature Vectorizer";
             this._nodes.push(new CoreMLNode('featureVectorizer', null, model.featureVectorizer, 
-                [ this._description.input[0].name ],
-                [ this._description.output[0].name ]));
-            debugger;
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'Feature Vectorizer';
         }
         else if (model.treeEnsembleClassifier) {
-            this._type = "Tree Ensemble Classifier";
             this._nodes.push(new CoreMLNode('treeEnsembleClassifier', null, model.treeEnsembleClassifier.treeEnsemble, 
-                [ this._description.input[0].name ],
-                [ this._description.output[0].name ]));
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
             this.updateClassifierOutput(model.treeEnsembleClassifier);
-            debugger;          
-        }
-        else {
-            debugger;
-        }
-    }
-
-    get name() {
-        return '';
-    }
-
-    get type() {
-        return this._type;
-    }
-
-    get inputs() {
-        return this._inputs;
-    }
-
-    get outputs() {
-        return this._outputs;
-    }
-
-    get nodes() {
-        return this._nodes;
-    }
-
-    updateInput(name, newName) {
-        this._nodes.forEach((node) => {
-            node._inputs = node._inputs.map((input) => (input != name) ? input : newName);
-        });
-        return newName;
-    }
-
-    updateOutput(name, newName) {
-        this._nodes.forEach((node) => {
-            node._outputs = node._outputs.map((output) => (output != name) ? output : newName);
-        });
-        return newName;
-    }
-
-    updateClassifierOutput(classifier) {
-        var labelProbabilityLayerName = classifier.labelProbabilityLayerName;
-        if (!labelProbabilityLayerName && this._nodes.length > 0) {
-            labelProbabilityLayerName = this._nodes.slice(-1).pop()._outputs[0];
-        }
-        var predictedFeatureName = this._description.predictedFeatureName;
-        var predictedProbabilitiesName = this._description.predictedProbabilitiesName;
-        if (predictedFeatureName && predictedProbabilitiesName && labelProbabilityLayerName && classifier.ClassLabels) {
-            var labelProbabilityInput = this.updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
-            var operator = classifier.ClassLabels;
-            this._nodes.push(new CoreMLNode(operator, null, classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ]));
+            return 'Tree Ensemble Classifier';
         }
+        return 'Unknown';
     }
 
     static formatFeatureType(type) {
@@ -599,7 +595,7 @@ class CoreMLOperatorMetadata
             var template = Handlebars.compile(operatorTemplate, 'utf-8');
             return template(schema);
         }
-        return "";
+        return '';
     }
 
     markdown(text) {