Lutz Roeder 4 лет назад
Родитель
Сommit
fd525332cc
1 измененных файлов с 29 добавлено и 14 удалено
  1. 29 14
      source/sklearn.js

+ 29 - 14
source/sklearn.js

@@ -7,21 +7,29 @@ sklearn.ModelFactory = class {
 
     match(context) {
         const obj = context.open('pkl');
-        const validate = (obj) => {
+        const validate = (obj, name) => {
             if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
                 const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
-                return key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.');
+                return key.startsWith(name);
             }
             return false;
         };
-        if (validate(obj)) {
-            return 'sklearn';
-        }
-        if (Array.isArray(obj) && obj.every((item) => validate(item))) {
-            return 'sklearn.list';
-        }
-        if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1]))) {
-            return 'sklearn.map';
+        const formats = [
+            { name: 'sklearn.', format: 'sklearn' },
+            { name: 'xgboost.sklearn.', format: 'sklearn' },
+            { name: 'lightgbm.sklearn.', format: 'sklearn' },
+            { name: 'scipy.', format: 'scipy' }
+        ];
+        for (const format of formats) {
+            if (validate(obj, format.name)) {
+                return format.format;
+            }
+            if (Array.isArray(obj) && obj.every((item) => validate(item, format.name))) {
+                return format.format + '.list';
+            }
+            if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1], format.name))) {
+                return format.format + '.map';
+            }
         }
         return undefined;
     }
@@ -37,16 +45,19 @@ sklearn.ModelFactory = class {
 sklearn.Model = class {
 
     constructor(metadata, match, obj) {
-        this._format = 'scikit-learn';
+        const formats = new Map([ [ 'sklearn', 'scikit-learn' ], [ 'scipy', 'SciPy' ] ]);
+        this._format = formats.get(match.split('.').shift());
         this._graphs = [];
         const version = [];
         switch (match) {
-            case 'sklearn': {
+            case 'sklearn':
+            case 'scipy': {
                 version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
                 this._graphs.push(new sklearn.Graph(metadata, '', obj));
                 break;
             }
-            case 'sklearn.list': {
+            case 'sklearn.list':
+            case 'scipy.list': {
                 const list = obj;
                 for (let i = 0; i < list.length; i++) {
                     const obj = list[i];
@@ -55,7 +66,8 @@ sklearn.Model = class {
                 }
                 break;
             }
-            case 'sklearn.map': {
+            case 'sklearn.map':
+            case 'scipy.map': {
                 for (const entry of Object.entries(obj)) {
                     const obj = entry[1];
                     this._graphs.push(new sklearn.Graph(metadata, entry[0], obj));
@@ -63,6 +75,9 @@ sklearn.Model = class {
                 }
                 break;
             }
+            default: {
+                throw new sklearn.Error("Unsupported scikit-learn format '" + match + "'.");
+            }
         }
         if (version.every((value) => value === version[0])) {
             this._format += version[0];