|
@@ -21,6 +21,7 @@ mlnet.ModelFactory = class {
|
|
|
async open(context) {
|
|
async open(context) {
|
|
|
const metadata = await context.metadata('mlnet-metadata.json');
|
|
const metadata = await context.metadata('mlnet-metadata.json');
|
|
|
const reader = new mlnet.ModelReader(context.value);
|
|
const reader = new mlnet.ModelReader(context.value);
|
|
|
|
|
+ await reader.resolve(context);
|
|
|
return new mlnet.Model(metadata, reader);
|
|
return new mlnet.Model(metadata, reader);
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
@@ -77,7 +78,7 @@ mlnet.Module = class {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- const node = new mlnet.Node(metadata, group, transformer, values);
|
|
|
|
|
|
|
+ const node = new mlnet.Node(metadata, transformer, group, values);
|
|
|
this.nodes.push(node);
|
|
this.nodes.push(node);
|
|
|
};
|
|
};
|
|
|
/* eslint-disable no-use-before-define */
|
|
/* eslint-disable no-use-before-define */
|
|
@@ -99,8 +100,12 @@ mlnet.Module = class {
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
- /* eslint-enable no-use-before-define */
|
|
|
|
|
const scope = new Map();
|
|
const scope = new Map();
|
|
|
|
|
+ if (reader.schema && reader.schema.inputs) {
|
|
|
|
|
+ for (const input of reader.schema.inputs) {
|
|
|
|
|
+ scope[input.name] = { argument: input.name, counter: 0 };
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
if (reader.dataLoaderModel) {
|
|
if (reader.dataLoaderModel) {
|
|
|
loadTransformer(scope, '', reader.dataLoaderModel);
|
|
loadTransformer(scope, '', reader.dataLoaderModel);
|
|
|
}
|
|
}
|
|
@@ -136,41 +141,43 @@ mlnet.Value = class {
|
|
|
|
|
|
|
|
mlnet.Node = class {
|
|
mlnet.Node = class {
|
|
|
|
|
|
|
|
- constructor(metadata, group, transformer, values) {
|
|
|
|
|
- this.group = group;
|
|
|
|
|
- this.name = transformer.__name__;
|
|
|
|
|
|
|
+ constructor(metadata, obj, group, values) {
|
|
|
|
|
+ const op = obj.__type__;
|
|
|
|
|
+ this.type = metadata.type(op) || { name: op || '?' };
|
|
|
|
|
+ this.name = obj.__name__ || '';
|
|
|
|
|
+ this.group = group || '';
|
|
|
this.inputs = [];
|
|
this.inputs = [];
|
|
|
this.outputs = [];
|
|
this.outputs = [];
|
|
|
this.attributes = [];
|
|
this.attributes = [];
|
|
|
- const type = transformer.__type__;
|
|
|
|
|
- this.type = metadata.type(type) || { name: type };
|
|
|
|
|
- if (transformer.inputs) {
|
|
|
|
|
- let i = 0;
|
|
|
|
|
- for (const input of transformer.inputs) {
|
|
|
|
|
- const value = values.map(input.name);
|
|
|
|
|
- const argument = new mlnet.Argument(i.toString(), [value]);
|
|
|
|
|
- this.inputs.push(argument);
|
|
|
|
|
- i++;
|
|
|
|
|
|
|
+ if (values && obj.inputs) {
|
|
|
|
|
+ for (let i = 0; i < obj.inputs.length; i++) {
|
|
|
|
|
+ const value = values.map(obj.inputs[i].name);
|
|
|
|
|
+ this.inputs.push(new mlnet.Argument(i.toString(), [value]));
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- if (transformer.outputs) {
|
|
|
|
|
- let i = 0;
|
|
|
|
|
- for (const output of transformer.outputs) {
|
|
|
|
|
- const argument = new mlnet.Argument(i.toString(), [values.map(output.name)]);
|
|
|
|
|
- this.outputs.push(argument);
|
|
|
|
|
- i++;
|
|
|
|
|
|
|
+ if (values && obj.outputs) {
|
|
|
|
|
+ for (let i = 0; i < obj.outputs.length; i++) {
|
|
|
|
|
+ this.outputs.push(new mlnet.Argument(i.toString(), [values.map(obj.outputs[i].name)]));
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- for (const [name, obj] of Object.entries(transformer).filter(([key]) => !key.startsWith('_') && key !== 'inputs' && key !== 'outputs')) {
|
|
|
|
|
- const schema = metadata.attribute(transformer.__type__, name);
|
|
|
|
|
- let value = obj;
|
|
|
|
|
|
|
+ for (const [name, raw] of Object.entries(obj).filter(([key]) => !key.startsWith('_') && key !== 'inputs' && key !== 'outputs')) {
|
|
|
|
|
+ const schema = metadata.attribute(op, name);
|
|
|
|
|
+ let value = raw;
|
|
|
let type = null;
|
|
let type = null;
|
|
|
if (schema) {
|
|
if (schema) {
|
|
|
type = schema.type ? schema.type : null;
|
|
type = schema.type ? schema.type : null;
|
|
|
value = mlnet.Utility.enum(type, value);
|
|
value = mlnet.Utility.enum(type, value);
|
|
|
}
|
|
}
|
|
|
- const attribute = new mlnet.Argument(name, value, type);
|
|
|
|
|
- this.attributes.push(attribute);
|
|
|
|
|
|
|
+ if (value && typeof value === 'object' && !Array.isArray(value) && Array.isArray(value.nodes)) {
|
|
|
|
|
+ type = 'graph';
|
|
|
|
|
+ } else if (value && typeof value === 'object' && !Array.isArray(value) && value.__type__) {
|
|
|
|
|
+ value = new mlnet.Node(metadata, value);
|
|
|
|
|
+ type = 'object';
|
|
|
|
|
+ } else if (Array.isArray(value) && value.length > 0 && value.every((item) => item && item.__type__)) {
|
|
|
|
|
+ value = value.map((item) => new mlnet.Node(metadata, item));
|
|
|
|
|
+ type = 'object[]';
|
|
|
|
|
+ }
|
|
|
|
|
+ this.attributes.push(new mlnet.Argument(name, value, type));
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
@@ -228,7 +235,6 @@ mlnet.TensorShape = class {
|
|
|
mlnet.ModelReader = class {
|
|
mlnet.ModelReader = class {
|
|
|
|
|
|
|
|
constructor(entries) {
|
|
constructor(entries) {
|
|
|
-
|
|
|
|
|
const catalog = new mlnet.ComponentCatalog();
|
|
const catalog = new mlnet.ComponentCatalog();
|
|
|
catalog.register('AffineNormExec', mlnet.AffineNormSerializationUtils);
|
|
catalog.register('AffineNormExec', mlnet.AffineNormSerializationUtils);
|
|
|
catalog.register('AnomalyPredXfer', mlnet.AnomalyPredictionTransformer);
|
|
catalog.register('AnomalyPredXfer', mlnet.AnomalyPredictionTransformer);
|
|
@@ -309,34 +315,127 @@ mlnet.ModelReader = class {
|
|
|
catalog.register('TransformerChain', mlnet.TransformerChain);
|
|
catalog.register('TransformerChain', mlnet.TransformerChain);
|
|
|
catalog.register('ValueMappingTransformer', mlnet.ValueMappingTransformer);
|
|
catalog.register('ValueMappingTransformer', mlnet.ValueMappingTransformer);
|
|
|
catalog.register('XGBoostMulticlass', mlnet.XGBoostMulticlass);
|
|
catalog.register('XGBoostMulticlass', mlnet.XGBoostMulticlass);
|
|
|
-
|
|
|
|
|
- const root = new mlnet.ModelHeader(catalog, entries, '', null);
|
|
|
|
|
-
|
|
|
|
|
|
|
+ this._resolve = [];
|
|
|
|
|
+ const root = new mlnet.ModelHeader(catalog, entries, '', null, this._resolve);
|
|
|
const version = root.openText('TrainingInfo/Version.txt');
|
|
const version = root.openText('TrainingInfo/Version.txt');
|
|
|
if (version) {
|
|
if (version) {
|
|
|
[this.version] = version.split(/[\s+\r]+/);
|
|
[this.version] = version.split(/[\s+\r]+/);
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
const schemaReader = root.openBinary('Schema');
|
|
const schemaReader = root.openBinary('Schema');
|
|
|
if (schemaReader) {
|
|
if (schemaReader) {
|
|
|
this.schema = new mlnet.BinaryLoader(null, schemaReader).schema;
|
|
this.schema = new mlnet.BinaryLoader(null, schemaReader).schema;
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
const transformerChain = root.open('TransformerChain');
|
|
const transformerChain = root.open('TransformerChain');
|
|
|
if (transformerChain) {
|
|
if (transformerChain) {
|
|
|
this.transformerChain = transformerChain;
|
|
this.transformerChain = transformerChain;
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
const dataLoaderModel = root.open('DataLoaderModel');
|
|
const dataLoaderModel = root.open('DataLoaderModel');
|
|
|
if (dataLoaderModel) {
|
|
if (dataLoaderModel) {
|
|
|
this.dataLoaderModel = dataLoaderModel;
|
|
this.dataLoaderModel = dataLoaderModel;
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
const predictor = root.open('Predictor');
|
|
const predictor = root.open('Predictor');
|
|
|
if (predictor) {
|
|
if (predictor) {
|
|
|
this.predictor = predictor;
|
|
this.predictor = predictor;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ async resolve(context) {
|
|
|
|
|
+ const resolve = async (entry) => {
|
|
|
|
|
+ let module = null;
|
|
|
|
|
+ let content = '';
|
|
|
|
|
+ if (entry.format === 'tf') {
|
|
|
|
|
+ const protobuf = await import('./protobuf.js');
|
|
|
|
|
+ module = await context.require('./tf');
|
|
|
|
|
+ content = new mlnet.Context(context, 'model.pb', entry.bytes, protobuf);
|
|
|
|
|
+ } else if (entry.format === 'onnx') {
|
|
|
|
|
+ const protobuf = await import('./protobuf.js');
|
|
|
|
|
+ module = await context.require('./onnx');
|
|
|
|
|
+ content = new mlnet.Context(context, 'model.onnx', entry.bytes, protobuf);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ throw new mlnet.Error(`Unsupported ML.NET model format '${entry.format}'.`);
|
|
|
|
|
+ }
|
|
|
|
|
+ const factory = new module.ModelFactory();
|
|
|
|
|
+ await factory.match(content);
|
|
|
|
|
+ const model = await factory.open(content);
|
|
|
|
|
+ if (model && Array.isArray(model.modules) && model.modules.length > 0) {
|
|
|
|
|
+ return model.modules[0];
|
|
|
|
|
+ }
|
|
|
|
|
+ return null;
|
|
|
|
|
+ };
|
|
|
|
|
+ const results = await Promise.all(this._resolve.map((entry) => resolve(entry).catch(() => null)));
|
|
|
|
|
+ for (let i = 0; i < this._resolve.length; i++) {
|
|
|
|
|
+ if (results[i]) {
|
|
|
|
|
+ this._resolve[i].target.Model = results[i];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
|
|
+mlnet.Context = class {
|
|
|
|
|
+
|
|
|
|
|
+ constructor(context, identifier, bytes, protobuf) {
|
|
|
|
|
+ this._context = context;
|
|
|
|
|
+ this._protobuf = protobuf;
|
|
|
|
|
+ this._identifier = identifier;
|
|
|
|
|
+ this._stream = new base.BinaryStream(bytes);
|
|
|
|
|
+ this._tags = new Map();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ get identifier() {
|
|
|
|
|
+ return this._identifier;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ get stream() {
|
|
|
|
|
+ return this._stream;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ set(type, value) {
|
|
|
|
|
+ this.type = type;
|
|
|
|
|
+ this.value = value;
|
|
|
|
|
+ return type;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async require(id) {
|
|
|
|
|
+ return this._context.require(id);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async metadata(id) {
|
|
|
|
|
+ return this._context.metadata(id);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async request(file, encoding) {
|
|
|
|
|
+ return this._context.request(file, encoding);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async read(type) {
|
|
|
|
|
+ if (type === 'protobuf.binary') {
|
|
|
|
|
+ return this._protobuf.BinaryReader.open(this.stream);
|
|
|
|
|
+ }
|
|
|
|
|
+ throw new mlnet.Error(`Unsupported read type '${type}'.`);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async tags(type) {
|
|
|
|
|
+ if (!this._tags.has(type)) {
|
|
|
|
|
+ let tags = new Map();
|
|
|
|
|
+ try {
|
|
|
|
|
+ const reader = this._protobuf.BinaryReader.open(this.stream);
|
|
|
|
|
+ tags = type === 'pb+' ? reader.decode() : reader.signature();
|
|
|
|
|
+ } catch {
|
|
|
|
|
+ // continue regardless of error
|
|
|
|
|
+ }
|
|
|
|
|
+ this.stream.seek(0);
|
|
|
|
|
+ this._tags.set(type, tags);
|
|
|
|
|
+ }
|
|
|
|
|
+ return this._tags.get(type);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async peek() {
|
|
|
|
|
+ return undefined;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ error(err) {
|
|
|
|
|
+ this._context.error(err, false);
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.ComponentCatalog = class {
|
|
mlnet.ComponentCatalog = class {
|
|
@@ -360,20 +459,17 @@ mlnet.ComponentCatalog = class {
|
|
|
|
|
|
|
|
mlnet.ModelHeader = class {
|
|
mlnet.ModelHeader = class {
|
|
|
|
|
|
|
|
- constructor(catalog, entries, directory, data) {
|
|
|
|
|
-
|
|
|
|
|
|
|
+ constructor(catalog, entries, directory, data, resolve) {
|
|
|
this._entries = entries;
|
|
this._entries = entries;
|
|
|
this._catalog = catalog;
|
|
this._catalog = catalog;
|
|
|
this._directory = directory;
|
|
this._directory = directory;
|
|
|
-
|
|
|
|
|
|
|
+ this._resolve = resolve;
|
|
|
if (data) {
|
|
if (data) {
|
|
|
const reader = new mlnet.BinaryReader(data);
|
|
const reader = new mlnet.BinaryReader(data);
|
|
|
-
|
|
|
|
|
const decoder = new TextDecoder('ascii');
|
|
const decoder = new TextDecoder('ascii');
|
|
|
reader.assert('ML\0MODEL');
|
|
reader.assert('ML\0MODEL');
|
|
|
this.versionWritten = reader.uint32();
|
|
this.versionWritten = reader.uint32();
|
|
|
this.versionReadable = reader.uint32();
|
|
this.versionReadable = reader.uint32();
|
|
|
-
|
|
|
|
|
const modelBlockOffset = reader.uint64().toNumber();
|
|
const modelBlockOffset = reader.uint64().toNumber();
|
|
|
/* let modelBlockSize = */ reader.uint64();
|
|
/* let modelBlockSize = */ reader.uint64();
|
|
|
const stringTableOffset = reader.uint64().toNumber();
|
|
const stringTableOffset = reader.uint64().toNumber();
|
|
@@ -416,7 +512,6 @@ mlnet.ModelHeader = class {
|
|
|
}
|
|
}
|
|
|
reader.seek(tailOffset);
|
|
reader.seek(tailOffset);
|
|
|
reader.assert('LEDOM\0LM');
|
|
reader.assert('LEDOM\0LM');
|
|
|
-
|
|
|
|
|
this._reader = reader;
|
|
this._reader = reader;
|
|
|
this._reader.seek(modelBlockOffset);
|
|
this._reader.seek(modelBlockOffset);
|
|
|
}
|
|
}
|
|
@@ -441,7 +536,7 @@ mlnet.ModelHeader = class {
|
|
|
const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\'));
|
|
const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\'));
|
|
|
if (stream) {
|
|
if (stream) {
|
|
|
const buffer = stream.peek();
|
|
const buffer = stream.peek();
|
|
|
- const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer);
|
|
|
|
|
|
|
+ const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer, this._resolve);
|
|
|
const value = this._catalog.create(context.loaderSignature, context);
|
|
const value = this._catalog.create(context.loaderSignature, context);
|
|
|
value.__type__ = value.__type__ || context.loaderSignature;
|
|
value.__type__ = value.__type__ || context.loaderSignature;
|
|
|
value.__name__ = name;
|
|
value.__name__ = name;
|
|
@@ -475,6 +570,10 @@ mlnet.ModelHeader = class {
|
|
|
check(signature, verWrittenCur, verWeCanReadBack) {
|
|
check(signature, verWrittenCur, verWeCanReadBack) {
|
|
|
return signature === this.modelSignature && verWrittenCur >= this.modelVersionReadable && verWeCanReadBack <= this.modelVersionWritten;
|
|
return signature === this.modelSignature && verWrittenCur >= this.modelVersionReadable && verWeCanReadBack <= this.modelVersionWritten;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ resolve(target, format, bytes) {
|
|
|
|
|
+ this._resolve.push({ target, format, bytes });
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.BinaryReader = class {
|
|
mlnet.BinaryReader = class {
|
|
@@ -635,6 +734,7 @@ mlnet.BinaryLoader = class { // 'BINLOADR'
|
|
|
reader.seek(tableOfContentsOffset);
|
|
reader.seek(tableOfContentsOffset);
|
|
|
this.schema = {};
|
|
this.schema = {};
|
|
|
this.schema.inputs = [];
|
|
this.schema.inputs = [];
|
|
|
|
|
+ const columns = new Map();
|
|
|
for (let c = 0; c < columnCount; c ++) {
|
|
for (let c = 0; c < columnCount; c ++) {
|
|
|
const input = {};
|
|
const input = {};
|
|
|
input.name = reader.string();
|
|
input.name = reader.string();
|
|
@@ -643,6 +743,9 @@ mlnet.BinaryLoader = class { // 'BINLOADR'
|
|
|
input.rowsPerBlock = reader.leb128();
|
|
input.rowsPerBlock = reader.leb128();
|
|
|
input.lookupOffset = reader.int64();
|
|
input.lookupOffset = reader.int64();
|
|
|
input.metadataTocOffset = reader.int64();
|
|
input.metadataTocOffset = reader.int64();
|
|
|
|
|
+ columns.set(input.name, input);
|
|
|
|
|
+ }
|
|
|
|
|
+ for (const input of columns.values()) {
|
|
|
this.schema.inputs.push(input);
|
|
this.schema.inputs.push(input);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -825,7 +928,10 @@ mlnet.FieldAwareFactorizationMachinePredictionTransformer = class extends mlnet.
|
|
|
}
|
|
}
|
|
|
this.Threshold = reader.float32();
|
|
this.Threshold = reader.float32();
|
|
|
this.ThresholdColumn = context.string();
|
|
this.ThresholdColumn = context.string();
|
|
|
- this.inputs.push({ name: this.ThresholdColumn });
|
|
|
|
|
|
|
+ this.outputs = [];
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: 'Probability' });
|
|
|
|
|
+ this.outputs.push({ name: 'PredictedLabel' });
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -837,11 +943,16 @@ mlnet.SingleFeaturePredictionTransformerBase = class extends mlnet.PredictionTra
|
|
|
this.inputs = [];
|
|
this.inputs = [];
|
|
|
this.inputs.push({ name: featureColumn });
|
|
this.inputs.push({ name: featureColumn });
|
|
|
this.outputs = [];
|
|
this.outputs = [];
|
|
|
- this.outputs.push({ name: featureColumn });
|
|
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.ClusteringPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
mlnet.ClusteringPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
|
|
|
+
|
|
|
|
|
+ constructor(context) {
|
|
|
|
|
+ super(context);
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: 'PredictedLabel' });
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
@@ -851,6 +962,8 @@ mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePrediction
|
|
|
const reader = context.reader;
|
|
const reader = context.reader;
|
|
|
this.Threshold = reader.float32();
|
|
this.Threshold = reader.float32();
|
|
|
this.ThresholdColumn = context.string();
|
|
this.ThresholdColumn = context.string();
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: 'PredictedLabel' });
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -871,6 +984,11 @@ mlnet.AffineNormSerializationUtils = class {
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.RegressionPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
mlnet.RegressionPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
|
|
|
+
|
|
|
|
|
+ constructor(context) {
|
|
|
|
|
+ super(context);
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ }
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
|
|
@@ -880,6 +998,8 @@ mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionT
|
|
|
const reader = context.reader;
|
|
const reader = context.reader;
|
|
|
this.Threshold = reader.float32();
|
|
this.Threshold = reader.float32();
|
|
|
this.ThresholdColumn = context.string();
|
|
this.ThresholdColumn = context.string();
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: 'PredictedLabel' });
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -889,6 +1009,15 @@ mlnet.MulticlassPredictionTransformer = class extends mlnet.SingleFeaturePredict
|
|
|
super(context);
|
|
super(context);
|
|
|
this.TrainLabelColumn = context.string(null);
|
|
this.TrainLabelColumn = context.string(null);
|
|
|
this.inputs.push({ name: this.TrainLabelColumn });
|
|
this.inputs.push({ name: this.TrainLabelColumn });
|
|
|
|
|
+ if (context.modelVersionWritten >= 0x00010002) {
|
|
|
|
|
+ const scoreColumn = context.string(null);
|
|
|
|
|
+ const predictedLabelColumn = context.string(null);
|
|
|
|
|
+ this.outputs.push({ name: scoreColumn || 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: predictedLabelColumn || 'PredictedLabel' });
|
|
|
|
|
+ } else {
|
|
|
|
|
+ this.outputs.push({ name: 'Score' });
|
|
|
|
|
+ this.outputs.push({ name: 'PredictedLabel' });
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -936,11 +1065,11 @@ mlnet.ImageClassificationModelParameters = class extends mlnet.ModelParametersBa
|
|
|
this.imagePreprocessorTensorOutput = reader.string();
|
|
this.imagePreprocessorTensorOutput = reader.string();
|
|
|
this.graphInputTensor = reader.string();
|
|
this.graphInputTensor = reader.string();
|
|
|
this.graphOutputTensor = reader.string();
|
|
this.graphOutputTensor = reader.string();
|
|
|
- this.modelFile = 'TFModel';
|
|
|
|
|
- // const modelBytes = context.openBinary('TFModel');
|
|
|
|
|
- // first uint32 is size of TensorFlow model
|
|
|
|
|
- // inputType = new VectorDataViewType(uint8);
|
|
|
|
|
- // outputType = new VectorDataViewType(float32, classCount);
|
|
|
|
|
|
|
+ const modelReader = context.openBinary('TFModel');
|
|
|
|
|
+ if (modelReader) {
|
|
|
|
|
+ const size = modelReader.uint32();
|
|
|
|
|
+ context.resolve(this, 'tf', modelReader.read(size));
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
};
|
|
};
|
|
|
|
|
|
|
@@ -1559,9 +1688,11 @@ mlnet.OnnxTransformer = class extends mlnet.RowToRowTransformerBase {
|
|
|
constructor(context) {
|
|
constructor(context) {
|
|
|
super(context);
|
|
super(context);
|
|
|
const reader = context.reader;
|
|
const reader = context.reader;
|
|
|
- this.modelFile = 'OnnxModel';
|
|
|
|
|
- // const modelBytes = context.openBinary('OnnxModel');
|
|
|
|
|
- // first uint32 is size of .onnx model
|
|
|
|
|
|
|
+ const modelReader = context.openBinary('OnnxModel');
|
|
|
|
|
+ if (modelReader) {
|
|
|
|
|
+ const size = modelReader.uint32();
|
|
|
|
|
+ context.resolve(this, 'onnx', modelReader.read(size));
|
|
|
|
|
+ }
|
|
|
const numInputs = context.modelVersionWritten > 0x00010001 ? reader.int32() : 1;
|
|
const numInputs = context.modelVersionWritten > 0x00010001 ? reader.int32() : 1;
|
|
|
this.inputs = [];
|
|
this.inputs = [];
|
|
|
for (let i = 0; i < numInputs; i++) {
|
|
for (let i = 0; i < numInputs; i++) {
|