|
|
@@ -7,21 +7,29 @@ sklearn.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
const obj = context.open('pkl');
|
|
|
- const validate = (obj) => {
|
|
|
+ const validate = (obj, name) => {
|
|
|
if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
|
|
|
const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
|
|
|
- return key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.');
|
|
|
+ return key.startsWith(name);
|
|
|
}
|
|
|
return false;
|
|
|
};
|
|
|
- if (validate(obj)) {
|
|
|
- return 'sklearn';
|
|
|
- }
|
|
|
- if (Array.isArray(obj) && obj.every((item) => validate(item))) {
|
|
|
- return 'sklearn.list';
|
|
|
- }
|
|
|
- if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1]))) {
|
|
|
- return 'sklearn.map';
|
|
|
+ const formats = [
|
|
|
+ { name: 'sklearn.', format: 'sklearn' },
|
|
|
+ { name: 'xgboost.sklearn.', format: 'sklearn' },
|
|
|
+ { name: 'lightgbm.sklearn.', format: 'sklearn' },
|
|
|
+ { name: 'scipy.', format: 'scipy' }
|
|
|
+ ];
|
|
|
+ for (const format of formats) {
|
|
|
+ if (validate(obj, format.name)) {
|
|
|
+ return format.format;
|
|
|
+ }
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => validate(item, format.name))) {
|
|
|
+ return format.format + '.list';
|
|
|
+ }
|
|
|
+ if ((Object(obj) === obj) && Object.entries(obj).every((entry) => validate(entry[1], format.name))) {
|
|
|
+ return format.format + '.map';
|
|
|
+ }
|
|
|
}
|
|
|
return undefined;
|
|
|
}
|
|
|
@@ -37,16 +45,19 @@ sklearn.ModelFactory = class {
|
|
|
sklearn.Model = class {
|
|
|
|
|
|
constructor(metadata, match, obj) {
|
|
|
- this._format = 'scikit-learn';
|
|
|
+ const formats = new Map([ [ 'sklearn', 'scikit-learn' ], [ 'scipy', 'SciPy' ] ]);
|
|
|
+ this._format = formats.get(match.split('.').shift());
|
|
|
this._graphs = [];
|
|
|
const version = [];
|
|
|
switch (match) {
|
|
|
- case 'sklearn': {
|
|
|
+ case 'sklearn':
|
|
|
+ case 'scipy': {
|
|
|
version.push(obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
|
|
|
this._graphs.push(new sklearn.Graph(metadata, '', obj));
|
|
|
break;
|
|
|
}
|
|
|
- case 'sklearn.list': {
|
|
|
+ case 'sklearn.list':
|
|
|
+ case 'scipy.list': {
|
|
|
const list = obj;
|
|
|
for (let i = 0; i < list.length; i++) {
|
|
|
const obj = list[i];
|
|
|
@@ -55,7 +66,8 @@ sklearn.Model = class {
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
- case 'sklearn.map': {
|
|
|
+ case 'sklearn.map':
|
|
|
+ case 'scipy.map': {
|
|
|
for (const entry of Object.entries(obj)) {
|
|
|
const obj = entry[1];
|
|
|
this._graphs.push(new sklearn.Graph(metadata, entry[0], obj));
|
|
|
@@ -63,6 +75,9 @@ sklearn.Model = class {
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
+ default: {
|
|
|
+ throw new sklearn.Error("Unsupported scikit-learn format '" + match + "'.");
|
|
|
+ }
|
|
|
}
|
|
|
if (version.every((value) => value === version[0])) {
|
|
|
this._format += version[0];
|