Pārlūkot izejas kodu

CoreML imputer and oneHotEncoder support

Lutz Roeder 8 gadi atpakaļ
vecāks
revīzija
2646ed0638
1 mainītis faili ar 46 papildinājumiem un 0 dzēšanām
  1. 46 0
      src/coreml-model.js

+ 46 - 0
src/coreml-model.js

@@ -224,6 +224,13 @@ class CoreMLGraph {
             this.updateClassifierOutput(group, model.glmClassifier);
             return 'Generalized Linear Classifier';
         }
+        else if (model.glmRegressor) {
+            this._nodes.push(new CoreMLNode(group, 'glmRegressor', null, 
+                model.glmRegressor,
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'Generalized Linear Regressor';
+        }
         else if (model.dictVectorizer) {
             this._nodes.push(new CoreMLNode(group, 'dictVectorizer', null, model.dictVectorizer,
                 [ model.description.input[0].name ],
@@ -274,6 +281,42 @@ class CoreMLGraph {
                 [ model.description.output[0].name ]));
             return 'Support Vector Regressor';
         }
+        else if (model.arrayFeatureExtractor) {
+            this._nodes.push(new CoreMLNode(group, 'arrayFeatureExtractor', null, 
+                { extractIndex: model.arrayFeatureExtractor.extractIndex },
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'Array Feature Extractor';
+        }
+        else if (model.oneHotEncoder) {
+            var categoryType = model.oneHotEncoder.CategoryType;
+            var oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
+            oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
+            this._nodes.push(new CoreMLNode(group, 'oneHotEncoder', null, 
+                oneHotEncoderParams,
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'One Hot Encoder';
+        }
+        else if (model.imputer) {
+            var imputedValue = model.imputer.ImputedValue;
+            var replaceValue = model.imputer.ReplaceValue;
+            var imputerParams = {};
+            imputerParams[imputedValue] = model.imputer[imputedValue];
+            imputerParams[replaceValue] = model.imputer[replaceValue];
+            this._nodes.push(new CoreMLNode(group, 'oneHotEncoder', null, 
+                imputerParams,
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+            return 'Imputer';
+            
+        }
+        else if (model.normalizer) {
+            this._nodes.push(new CoreMLNode(group, 'normalizer', null, 
+                model.normalizer,
+                [ model.description.input[0].name ],
+                [ model.description.output[0].name ]));
+        }
         return 'Unknown';
     }
 
@@ -532,6 +575,9 @@ class CoreMLAttribute {
         if (Array.isArray(this._value)) {
             return this._value.map((item) => JSON.stringify(item)).join(', ');
         }
+        if (Number.isNaN(this._value)) {
+            return 'NaN';
+        }
         return JSON.stringify(this._value);
     }