فهرست منبع

Add scikit-learn test file

Lutz Roeder 4 سال پیش
والد
کامیت
1c8316d6a9
2فایلهای تغییر یافته به همراه65 افزوده شده و 33 حذف شده
  1. 36 12
      source/sklearn.js
  2. 29 21
      test/models.json

+ 36 - 12
source/sklearn.js

@@ -18,33 +18,55 @@ sklearn.ModelFactory = class {
             return 'sklearn';
         }
         if (Array.isArray(obj) && obj.every((item) => validate(item))) {
-            return 'sklearn';
+            return 'sklearn.list';
+        }
+        if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1]))) {
+            return 'sklearn.map';
         }
         return undefined;
     }
 
-    open(context) {
+    open(context, match) {
         return sklearn.Metadata.open(context).then((metadata) => {
             const obj = context.open('pkl');
-            return new sklearn.Model(metadata, obj);
+            return new sklearn.Model(metadata, match, obj);
         });
     }
 };
 
 sklearn.Model = class {
 
-    constructor(metadata, obj) {
+    constructor(metadata, match, obj) {
         this._format = 'scikit-learn';
         this._graphs = [];
-        if (!Array.isArray(obj)) {
-            this._format += obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '';
-            this._graphs.push(new sklearn.Graph(metadata, '', obj));
-        }
-        else {
-            for (let i = 0; i < obj.length; i++) {
-                this._graphs.push(new sklearn.Graph(metadata, i.toString(), obj[i]));
+        const version = [];
+        switch (match) {
+            case 'sklearn': {
+                version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
+                this._graphs.push(new sklearn.Graph(metadata, '', obj));
+                break;
+            }
+            case 'sklearn.list': {
+                const list = obj;
+                for (let i = 0; i < list.length; i++) {
+                    const obj = list[i];
+                    this._graphs.push(new sklearn.Graph(metadata, i.toString(), obj));
+                    version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
+                }
+                break;
+            }
+            case 'sklearn.map': {
+                for (const entry of Object.entries(obj)) {
+                    const obj = entry[1];
+                    this._graphs.push(new sklearn.Graph(metadata, entry[0], obj));
+                    version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
+                }
+                break;
             }
         }
+        if (version.every((value) => value === version[0])) {
+            this._format += version[0];
+        }
     }
 
     get format() {
@@ -98,7 +120,9 @@ sklearn.Graph = class {
                 const outputs = [];
                 this._add(subgroup, output, obj, inputs, [ output ]);
                 for (const transformer of obj.transformers){
-                    outputs.push(...this._process(subgroup, transformer[0], transformer[1], [ output ]));
+                    if (transformer[1] !== 'passthrough') {
+                        outputs.push(...this._process(subgroup, transformer[0], transformer[1], [ output ]));
+                    }
                 }
                 return outputs;
             }

+ 29 - 21
test/models.json

@@ -5098,27 +5098,6 @@
     "format": "scikit-learn v0.22.1",
     "link":   "https://github.com/lutzroeder/netron/issues/498"
   },
-  {
-    "type":   "sklearn",
-    "target": "iris_svc.joblib.z",
-    "source": "https://github.com/lutzroeder/netron/files/5728549/iris_svc.joblib.z.zip[iris_svc.joblib.z]",
-    "format": "scikit-learn v0.24.0rc1",
-    "link":   "https://github.com/lutzroeder/netron/issues/182"
-  },
-  {
-    "type":   "sklearn",
-    "target": "LDA_model.pkl",
-    "source": "https://raw.githubusercontent.com/rainer85ah/DCS/master/Output/LDA_model.pkl",
-    "format": "scikit-learn",
-    "link":   "https://github.com/rainer85ah/DCS"
-  },
-  {
-    "type":   "sklearn",
-    "target": "phoenix_ml_classifier.pkl",
-    "source": "https://raw.githubusercontent.com/imeraj/Phoenix_Playground/master/1.4/phoenix_ml/lib/phoenix_ml/model/classifier.pkl",
-    "format": "scikit-learn v0.19.1",
-    "link":   "https://github.com/imeraj/Phoenix_Playground/tree/master/1.4/phoenix_ml/lib/phoenix_ml/model"
-  },
   {
     "type":   "sklearn",
     "target": "forest_iris_ExtraTreesClassifier.pkl",
@@ -5147,6 +5126,27 @@
     "format": "scikit-learn v0.19.2",
     "link":   "https://github.com/lutzroeder/netron/issues/182"
   },
+  {
+    "type":   "sklearn",
+    "target": "iris_svc.joblib.z",
+    "source": "https://github.com/lutzroeder/netron/files/5728549/iris_svc.joblib.z.zip[iris_svc.joblib.z]",
+    "format": "scikit-learn v0.24.0rc1",
+    "link":   "https://github.com/lutzroeder/netron/issues/182"
+  },
+  {
+    "type":   "sklearn",
+    "target": "LDA_model.pkl",
+    "source": "https://raw.githubusercontent.com/rainer85ah/DCS/master/Output/LDA_model.pkl",
+    "format": "scikit-learn",
+    "link":   "https://github.com/rainer85ah/DCS"
+  },
+  {
+    "type":   "sklearn",
+    "target": "phoenix_ml_classifier.pkl",
+    "source": "https://raw.githubusercontent.com/imeraj/Phoenix_Playground/master/1.4/phoenix_ml/lib/phoenix_ml/model/classifier.pkl",
+    "format": "scikit-learn v0.19.1",
+    "link":   "https://github.com/imeraj/Phoenix_Playground/tree/master/1.4/phoenix_ml/lib/phoenix_ml/model"
+  },
   {
     "type":   "sklearn",
     "target": "pima.xgboost.joblib.pkl",
@@ -5175,6 +5175,14 @@
     "format": "scikit-learn",
     "link":   "https://github.com/dfridovi/imagelib/blob/master/svm.pkl"
   },
+  {
+    "type":   "sklearn",
+    "target": "tree.pkl",
+    "source": "https://github.com/wjqkkky/TTS-front-end/raw/56e09ef79bb6679e51a7c8ff3d302ccd917c0992/ChineseRhythmPredictor/tree.pkl",
+    "error":  "Invalid string length",
+    "format": "scikit-learn v0.21.3",
+    "link":   "https://github.com/wjqkkky/TTS-front-end"
+  },
   {
     "type":   "sklearn",
     "target": "wiki-aa-mlp.pkl",