Quellcode durchsuchen

Add scikit-learn test file (#182)

Lutz Roeder vor 5 Jahren
Ursprung
Commit
fea056cb3f
3 geänderte Dateien mit 82 neuen und 24 gelöschten Zeilen
  1. 15 3
      source/python.js
  2. 60 21
      source/sklearn.js
  3. 7 0
      test/models.json

+ 15 - 3
source/python.js

@@ -1935,6 +1935,8 @@ python.Execution = class {
         this.registerType('sklearn.gaussian_process.kernels.ConstantKernel', class {});
         this.registerType('sklearn.gaussian_process.kernels.Product', class {});
         this.registerType('sklearn.gaussian_process.kernels.RBF', class {});
+        this.registerType('sklearn.grid_search._CVScoreTuple', class {});
+        this.registerType('sklearn.grid_search.GridSearchCV', class {});
         this.registerType('sklearn.impute._base.SimpleImputer', class {});
         this.registerType('sklearn.impute.SimpleImputer', class {});
         this.registerType('sklearn.isotonic.IsotonicRegression', class {});
@@ -2225,9 +2227,19 @@ python.Execution = class {
         this.registerFunction('collections.defaultdict', function(/* default_factory */) {
             return {};
         });
-        this.registerFunction('copy_reg._reconstructor', function(cls, base /* , state */) {
-            if (base == '__builtin__.object') {
-                return self.invoke(cls, []);
+        this.registerFunction('copy_reg._reconstructor', function(cls, base, state) {
+            // copyreg._reconstructor in Python 3
+            switch (base) {
+                case '__builtin__.object': {
+                    return self.invoke(cls, []);
+                }
+                case '__builtin__.tuple': {
+                    const obj = self.invoke(cls, []);
+                    for (let i = 0; i < state.length; i++) {
+                        obj[i] = state[i];
+                    }
+                    return obj;
+                }
             }
             throw new python.Error("Unknown copy_reg._reconstructor base type '" + base + "'.");
         });

+ 60 - 21
source/sklearn.js

@@ -8,11 +8,18 @@ sklearn.ModelFactory = class {
 
     match(context) {
         const obj = context.open('pkl');
-        if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
-            const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
-            if (key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.')) {
-                return true;
+        const validate = (obj) => {
+            if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
+                const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
+                return key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.');
             }
+            return false;
+        };
+        if (validate(obj)) {
+            return true;
+        }
+        if (Array.isArray(obj) && obj.every((item) => validate(item))) {
+            return true;
         }
         return false;
     }
@@ -28,8 +35,17 @@ sklearn.ModelFactory = class {
 sklearn.Model = class {
 
     constructor(metadata, obj) {
-        this._format = 'scikit-learn' + (obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
-        this._graphs = [ new sklearn.Graph(metadata, obj) ];
+        this._format = 'scikit-learn';
+        this._graphs = [];
+        if (!Array.isArray(obj)) {
+            this._format += obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '';
+            this._graphs.push(new sklearn.Graph(metadata, '', obj));
+        }
+        else {
+            for (let i = 0; i < obj.length; i++) {
+                this._graphs.push(new sklearn.Graph(metadata, i.toString(), obj[i]));
+            }
+        }
     }
 
     get format() {
@@ -43,8 +59,8 @@ sklearn.Model = class {
 
 sklearn.Graph = class {
 
-    constructor(metadata, obj) {
-        this._name = '';
+    constructor(metadata, name, obj) {
+        this._name = name || '';
         this._metadata = metadata;
         this._nodes = [];
         this._groups = false;
@@ -350,6 +366,9 @@ sklearn.Tensor = class {
             this._kind = 'NumPy Array';
             this._type = new sklearn.TensorType(value.dtype.name, new sklearn.TensorShape(value.shape));
             this._data = value.data;
+            if (value.dtype.name === 'string') {
+                this._itemsize = value.dtype.itemsize;
+            }
         }
         else {
             const type = value.__class__.__module__ + '.' + value.__class__.__name__;
@@ -422,7 +441,12 @@ sklearn.Tensor = class {
             case 'uint32':
             case 'int64':
             case 'uint64':
-                context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
+                context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
+                break;
+            case 'string':
+                context.data = this._data;
+                context.itemsize = this._itemsize;
+                context.decoder = new TextDecoder('utf-8');
                 break;
             default:
                 context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
@@ -442,36 +466,51 @@ sklearn.Tensor = class {
                     return results;
                 }
                 switch (context.dataType) {
-                    case 'float32':
-                        results.push(context.rawData.getFloat32(context.index, true));
+                    case 'float32': {
+                        results.push(context.view.getFloat32(context.index, true));
                         context.index += 4;
                         context.count++;
                         break;
-                    case 'float64':
-                        results.push(context.rawData.getFloat64(context.index, true));
+                    }
+                    case 'float64': {
+                        results.push(context.view.getFloat64(context.index, true));
                         context.index += 8;
                         context.count++;
                         break;
-                    case 'int32':
-                        results.push(context.rawData.getInt32(context.index, true));
+                    }
+                    case 'int32': {
+                        results.push(context.view.getInt32(context.index, true));
                         context.index += 4;
                         context.count++;
                         break;
-                    case 'uint32':
-                        results.push(context.rawData.getUint32(context.index, true));
+                    }
+                    case 'uint32': {
+                        results.push(context.view.getUint32(context.index, true));
                         context.index += 4;
                         context.count++;
                         break;
-                    case 'int64':
-                        results.push(context.rawData.getInt64(context.index, true));
+                    }
+                    case 'int64': {
+                        results.push(context.view.getInt64(context.index, true));
                         context.index += 8;
                         context.count++;
                         break;
-                    case 'uint64':
-                        results.push(context.rawData.getUint64(context.index, true));
+                    }
+                    case 'uint64': {
+                        results.push(context.view.getUint64(context.index, true));
                         context.index += 8;
                         context.count++;
                         break;
+                    }
+                    case 'string': {
+                        const buffer = context.data.subarray(context.index, context.index + context.itemsize);
+                        const index = buffer.indexOf(0);
+                        const text = context.decoder.decode(index >= 0 ? buffer.subarray(0, index) : buffer);
+                        results.push(text);
+                        context.index += context.itemsize;
+                        context.count++;
+                        break;
+                    }
                 }
             }
         }

+ 7 - 0
test/models.json

@@ -4840,6 +4840,13 @@
     "format": "scikit-learn v0.19.1",
     "link":   "https://github.com/lutzroeder/netron/issues/182"
   },
+  {
+    "type":   "sklearn",
+    "target": "celeb-classifier.nn4.small2.v1.pkl.zip",
+    "source": "https://github.com/lutzroeder/netron/files/6176216/celeb-classifier.nn4.small2.v1.pkl.zip",
+    "format": "scikit-learn",
+    "link":   "https://github.com/lutzroeder/netron/issues/182"
+  },
   {
     "type":   "sklearn",
     "target": "column_pipeline.pkl",