|
|
@@ -8,11 +8,18 @@ sklearn.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
const obj = context.open('pkl');
|
|
|
- if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
|
|
|
- const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
|
|
|
- if (key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.')) {
|
|
|
- return true;
|
|
|
+ const validate = (obj) => {
|
|
|
+ if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
|
|
|
+ const key = obj.__class__.__module__ + '.' + obj.__class__.__name__;
|
|
|
+ return key.startsWith('sklearn.') || key.startsWith('xgboost.sklearn.') || key.startsWith('lightgbm.sklearn.');
|
|
|
}
|
|
|
+ return false;
|
|
|
+ };
|
|
|
+ if (validate(obj)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (Array.isArray(obj) && obj.every((item) => validate(item))) {
|
|
|
+ return true;
|
|
|
}
|
|
|
return false;
|
|
|
}
|
|
|
@@ -28,8 +35,17 @@ sklearn.ModelFactory = class {
|
|
|
sklearn.Model = class {
|
|
|
|
|
|
constructor(metadata, obj) {
|
|
|
- this._format = 'scikit-learn' + (obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '');
|
|
|
- this._graphs = [ new sklearn.Graph(metadata, obj) ];
|
|
|
+ this._format = 'scikit-learn';
|
|
|
+ this._graphs = [];
|
|
|
+ if (!Array.isArray(obj)) {
|
|
|
+ this._format += obj._sklearn_version ? ' v' + obj._sklearn_version.toString() : '';
|
|
|
+ this._graphs.push(new sklearn.Graph(metadata, '', obj));
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ for (let i = 0; i < obj.length; i++) {
|
|
|
+ this._graphs.push(new sklearn.Graph(metadata, i.toString(), obj[i]));
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -43,8 +59,8 @@ sklearn.Model = class {
|
|
|
|
|
|
sklearn.Graph = class {
|
|
|
|
|
|
- constructor(metadata, obj) {
|
|
|
- this._name = '';
|
|
|
+ constructor(metadata, name, obj) {
|
|
|
+ this._name = name || '';
|
|
|
this._metadata = metadata;
|
|
|
this._nodes = [];
|
|
|
this._groups = false;
|
|
|
@@ -350,6 +366,9 @@ sklearn.Tensor = class {
|
|
|
this._kind = 'NumPy Array';
|
|
|
this._type = new sklearn.TensorType(value.dtype.name, new sklearn.TensorShape(value.shape));
|
|
|
this._data = value.data;
|
|
|
+ if (value.dtype.name === 'string') {
|
|
|
+ this._itemsize = value.dtype.itemsize;
|
|
|
+ }
|
|
|
}
|
|
|
else {
|
|
|
const type = value.__class__.__module__ + '.' + value.__class__.__name__;
|
|
|
@@ -422,7 +441,12 @@ sklearn.Tensor = class {
|
|
|
case 'uint32':
|
|
|
case 'int64':
|
|
|
case 'uint64':
|
|
|
- context.rawData = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
|
|
|
+ context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
|
|
|
+ break;
|
|
|
+ case 'string':
|
|
|
+ context.data = this._data;
|
|
|
+ context.itemsize = this._itemsize;
|
|
|
+ context.decoder = new TextDecoder('utf-8');
|
|
|
break;
|
|
|
default:
|
|
|
context.state = "Tensor data type '" + context.dataType + "' is not implemented.";
|
|
|
@@ -442,36 +466,51 @@ sklearn.Tensor = class {
|
|
|
return results;
|
|
|
}
|
|
|
switch (context.dataType) {
|
|
|
- case 'float32':
|
|
|
- results.push(context.rawData.getFloat32(context.index, true));
|
|
|
+ case 'float32': {
|
|
|
+ results.push(context.view.getFloat32(context.index, true));
|
|
|
context.index += 4;
|
|
|
context.count++;
|
|
|
break;
|
|
|
- case 'float64':
|
|
|
- results.push(context.rawData.getFloat64(context.index, true));
|
|
|
+ }
|
|
|
+ case 'float64': {
|
|
|
+ results.push(context.view.getFloat64(context.index, true));
|
|
|
context.index += 8;
|
|
|
context.count++;
|
|
|
break;
|
|
|
- case 'int32':
|
|
|
- results.push(context.rawData.getInt32(context.index, true));
|
|
|
+ }
|
|
|
+ case 'int32': {
|
|
|
+ results.push(context.view.getInt32(context.index, true));
|
|
|
context.index += 4;
|
|
|
context.count++;
|
|
|
break;
|
|
|
- case 'uint32':
|
|
|
- results.push(context.rawData.getUint32(context.index, true));
|
|
|
+ }
|
|
|
+ case 'uint32': {
|
|
|
+ results.push(context.view.getUint32(context.index, true));
|
|
|
context.index += 4;
|
|
|
context.count++;
|
|
|
break;
|
|
|
- case 'int64':
|
|
|
- results.push(context.rawData.getInt64(context.index, true));
|
|
|
+ }
|
|
|
+ case 'int64': {
|
|
|
+ results.push(context.view.getInt64(context.index, true));
|
|
|
context.index += 8;
|
|
|
context.count++;
|
|
|
break;
|
|
|
- case 'uint64':
|
|
|
- results.push(context.rawData.getUint64(context.index, true));
|
|
|
+ }
|
|
|
+ case 'uint64': {
|
|
|
+ results.push(context.view.getUint64(context.index, true));
|
|
|
context.index += 8;
|
|
|
context.count++;
|
|
|
break;
|
|
|
+ }
|
|
|
+ case 'string': {
|
|
|
+ const buffer = context.data.subarray(context.index, context.index + context.itemsize);
|
|
|
+ const index = buffer.indexOf(0);
|
|
|
+ const text = context.decoder.decode(index >= 0 ? buffer.subarray(0, index) : buffer);
|
|
|
+ results.push(text);
|
|
|
+ context.index += context.itemsize;
|
|
|
+ context.count++;
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|