Prechádzať zdrojové kódy

Update ML.NET test files (#1541)

Lutz Roeder 3 týždňov pred
rodič
commit
9f92b7ff85
2 zmenil súbory, kde vykonal 192 pridanie a 54 odobranie
  1. 181 50
      source/mlnet.js
  2. 11 4
      test/models.json

+ 181 - 50
source/mlnet.js

@@ -21,6 +21,7 @@ mlnet.ModelFactory = class {
     async open(context) {
     async open(context) {
         const metadata = await context.metadata('mlnet-metadata.json');
         const metadata = await context.metadata('mlnet-metadata.json');
         const reader = new mlnet.ModelReader(context.value);
         const reader = new mlnet.ModelReader(context.value);
+        await reader.resolve(context);
         return new mlnet.Model(metadata, reader);
         return new mlnet.Model(metadata, reader);
     }
     }
 };
 };
@@ -77,7 +78,7 @@ mlnet.Module = class {
                     }
                     }
                 }
                 }
             }
             }
-            const node = new mlnet.Node(metadata, group, transformer, values);
+            const node = new mlnet.Node(metadata, transformer, group, values);
             this.nodes.push(node);
             this.nodes.push(node);
         };
         };
         /* eslint-disable no-use-before-define */
         /* eslint-disable no-use-before-define */
@@ -99,8 +100,12 @@ mlnet.Module = class {
                     break;
                     break;
             }
             }
         };
         };
-        /* eslint-enable no-use-before-define */
         const scope = new Map();
         const scope = new Map();
+        if (reader.schema && reader.schema.inputs) {
+            for (const input of reader.schema.inputs) {
+                scope[input.name] = { argument: input.name, counter: 0 };
+            }
+        }
         if (reader.dataLoaderModel) {
         if (reader.dataLoaderModel) {
             loadTransformer(scope, '', reader.dataLoaderModel);
             loadTransformer(scope, '', reader.dataLoaderModel);
         }
         }
@@ -136,41 +141,43 @@ mlnet.Value = class {
 
 
 mlnet.Node = class {
 mlnet.Node = class {
 
 
-    constructor(metadata, group, transformer, values) {
-        this.group = group;
-        this.name = transformer.__name__;
+    constructor(metadata, obj, group, values) {
+        const op = obj.__type__;
+        this.type = metadata.type(op) || { name: op || '?' };
+        this.name = obj.__name__ || '';
+        this.group = group || '';
         this.inputs = [];
         this.inputs = [];
         this.outputs = [];
         this.outputs = [];
         this.attributes = [];
         this.attributes = [];
-        const type = transformer.__type__;
-        this.type = metadata.type(type) || { name: type };
-        if (transformer.inputs) {
-            let i = 0;
-            for (const input of transformer.inputs) {
-                const value = values.map(input.name);
-                const argument = new mlnet.Argument(i.toString(), [value]);
-                this.inputs.push(argument);
-                i++;
+        if (values && obj.inputs) {
+            for (let i = 0; i < obj.inputs.length; i++) {
+                const value = values.map(obj.inputs[i].name);
+                this.inputs.push(new mlnet.Argument(i.toString(), [value]));
             }
             }
         }
         }
-        if (transformer.outputs) {
-            let i = 0;
-            for (const output of transformer.outputs) {
-                const argument = new mlnet.Argument(i.toString(), [values.map(output.name)]);
-                this.outputs.push(argument);
-                i++;
+        if (values && obj.outputs) {
+            for (let i = 0; i < obj.outputs.length; i++) {
+                this.outputs.push(new mlnet.Argument(i.toString(), [values.map(obj.outputs[i].name)]));
             }
             }
         }
         }
-        for (const [name, obj] of Object.entries(transformer).filter(([key]) => !key.startsWith('_') && key !== 'inputs' && key !== 'outputs')) {
-            const schema = metadata.attribute(transformer.__type__, name);
-            let value = obj;
+        for (const [name, raw] of Object.entries(obj).filter(([key]) => !key.startsWith('_') && key !== 'inputs' && key !== 'outputs')) {
+            const schema = metadata.attribute(op, name);
+            let value = raw;
             let type = null;
             let type = null;
             if (schema) {
             if (schema) {
                 type = schema.type ? schema.type : null;
                 type = schema.type ? schema.type : null;
                 value = mlnet.Utility.enum(type, value);
                 value = mlnet.Utility.enum(type, value);
             }
             }
-            const attribute = new mlnet.Argument(name, value, type);
-            this.attributes.push(attribute);
+            if (value && typeof value === 'object' && !Array.isArray(value) && Array.isArray(value.nodes)) {
+                type = 'graph';
+            } else if (value && typeof value === 'object' && !Array.isArray(value) && value.__type__) {
+                value = new mlnet.Node(metadata, value);
+                type = 'object';
+            } else if (Array.isArray(value) && value.length > 0 && value.every((item) => item && item.__type__)) {
+                value = value.map((item) => new mlnet.Node(metadata, item));
+                type = 'object[]';
+            }
+            this.attributes.push(new mlnet.Argument(name, value, type));
         }
         }
     }
     }
 };
 };
@@ -228,7 +235,6 @@ mlnet.TensorShape = class {
 mlnet.ModelReader = class {
 mlnet.ModelReader = class {
 
 
     constructor(entries) {
     constructor(entries) {
-
         const catalog = new mlnet.ComponentCatalog();
         const catalog = new mlnet.ComponentCatalog();
         catalog.register('AffineNormExec', mlnet.AffineNormSerializationUtils);
         catalog.register('AffineNormExec', mlnet.AffineNormSerializationUtils);
         catalog.register('AnomalyPredXfer', mlnet.AnomalyPredictionTransformer);
         catalog.register('AnomalyPredXfer', mlnet.AnomalyPredictionTransformer);
@@ -309,34 +315,127 @@ mlnet.ModelReader = class {
         catalog.register('TransformerChain', mlnet.TransformerChain);
         catalog.register('TransformerChain', mlnet.TransformerChain);
         catalog.register('ValueMappingTransformer', mlnet.ValueMappingTransformer);
         catalog.register('ValueMappingTransformer', mlnet.ValueMappingTransformer);
         catalog.register('XGBoostMulticlass', mlnet.XGBoostMulticlass);
         catalog.register('XGBoostMulticlass', mlnet.XGBoostMulticlass);
-
-        const root = new mlnet.ModelHeader(catalog, entries, '', null);
-
+        this._resolve = [];
+        const root = new mlnet.ModelHeader(catalog, entries, '', null, this._resolve);
         const version = root.openText('TrainingInfo/Version.txt');
         const version = root.openText('TrainingInfo/Version.txt');
         if (version) {
         if (version) {
             [this.version] = version.split(/[\s+\r]+/);
             [this.version] = version.split(/[\s+\r]+/);
         }
         }
-
         const schemaReader = root.openBinary('Schema');
         const schemaReader = root.openBinary('Schema');
         if (schemaReader) {
         if (schemaReader) {
             this.schema = new mlnet.BinaryLoader(null, schemaReader).schema;
             this.schema = new mlnet.BinaryLoader(null, schemaReader).schema;
         }
         }
-
         const transformerChain = root.open('TransformerChain');
         const transformerChain = root.open('TransformerChain');
         if (transformerChain) {
         if (transformerChain) {
             this.transformerChain = transformerChain;
             this.transformerChain = transformerChain;
         }
         }
-
         const dataLoaderModel = root.open('DataLoaderModel');
         const dataLoaderModel = root.open('DataLoaderModel');
         if (dataLoaderModel) {
         if (dataLoaderModel) {
             this.dataLoaderModel = dataLoaderModel;
             this.dataLoaderModel = dataLoaderModel;
         }
         }
-
         const predictor = root.open('Predictor');
         const predictor = root.open('Predictor');
         if (predictor) {
         if (predictor) {
             this.predictor = predictor;
             this.predictor = predictor;
         }
         }
     }
     }
+
+    async resolve(context) {
+        const resolve = async (entry) => {
+            let module = null;
+            let content = '';
+            if (entry.format === 'tf') {
+                const protobuf = await import('./protobuf.js');
+                module = await context.require('./tf');
+                content = new mlnet.Context(context, 'model.pb', entry.bytes, protobuf);
+            } else if (entry.format === 'onnx') {
+                const protobuf = await import('./protobuf.js');
+                module = await context.require('./onnx');
+                content = new mlnet.Context(context, 'model.onnx', entry.bytes, protobuf);
+            } else {
+                throw new mlnet.Error(`Unsupported ML.NET model format '${entry.format}'.`);
+            }
+            const factory = new module.ModelFactory();
+            await factory.match(content);
+            const model = await factory.open(content);
+            if (model && Array.isArray(model.modules) && model.modules.length > 0) {
+                return model.modules[0];
+            }
+            return null;
+        };
+        const results = await Promise.all(this._resolve.map((entry) => resolve(entry).catch(() => null)));
+        for (let i = 0; i < this._resolve.length; i++) {
+            if (results[i]) {
+                this._resolve[i].target.Model = results[i];
+            }
+        }
+    }
+};
+
+mlnet.Context = class {
+
+    constructor(context, identifier, bytes, protobuf) {
+        this._context = context;
+        this._protobuf = protobuf;
+        this._identifier = identifier;
+        this._stream = new base.BinaryStream(bytes);
+        this._tags = new Map();
+    }
+
+    get identifier() {
+        return this._identifier;
+    }
+
+    get stream() {
+        return this._stream;
+    }
+
+    set(type, value) {
+        this.type = type;
+        this.value = value;
+        return type;
+    }
+
+    async require(id) {
+        return this._context.require(id);
+    }
+
+    async metadata(id) {
+        return this._context.metadata(id);
+    }
+
+    async request(file, encoding) {
+        return this._context.request(file, encoding);
+    }
+
+    async read(type) {
+        if (type === 'protobuf.binary') {
+            return this._protobuf.BinaryReader.open(this.stream);
+        }
+        throw new mlnet.Error(`Unsupported read type '${type}'.`);
+    }
+
+    async tags(type) {
+        if (!this._tags.has(type)) {
+            let tags = new Map();
+            try {
+                const reader = this._protobuf.BinaryReader.open(this.stream);
+                tags = type === 'pb+' ? reader.decode() : reader.signature();
+            } catch {
+                // continue regardless of error
+            }
+            this.stream.seek(0);
+            this._tags.set(type, tags);
+        }
+        return this._tags.get(type);
+    }
+
+    async peek() {
+        return undefined;
+    }
+
+    error(err) {
+        this._context.error(err, false);
+    }
 };
 };
 
 
 mlnet.ComponentCatalog = class {
 mlnet.ComponentCatalog = class {
@@ -360,20 +459,17 @@ mlnet.ComponentCatalog = class {
 
 
 mlnet.ModelHeader = class {
 mlnet.ModelHeader = class {
 
 
-    constructor(catalog, entries, directory, data) {
-
+    constructor(catalog, entries, directory, data, resolve) {
         this._entries = entries;
         this._entries = entries;
         this._catalog = catalog;
         this._catalog = catalog;
         this._directory = directory;
         this._directory = directory;
-
+        this._resolve = resolve;
         if (data) {
         if (data) {
             const reader = new mlnet.BinaryReader(data);
             const reader = new mlnet.BinaryReader(data);
-
             const decoder = new TextDecoder('ascii');
             const decoder = new TextDecoder('ascii');
             reader.assert('ML\0MODEL');
             reader.assert('ML\0MODEL');
             this.versionWritten = reader.uint32();
             this.versionWritten = reader.uint32();
             this.versionReadable = reader.uint32();
             this.versionReadable = reader.uint32();
-
             const modelBlockOffset = reader.uint64().toNumber();
             const modelBlockOffset = reader.uint64().toNumber();
             /* let modelBlockSize = */ reader.uint64();
             /* let modelBlockSize = */ reader.uint64();
             const stringTableOffset = reader.uint64().toNumber();
             const stringTableOffset = reader.uint64().toNumber();
@@ -416,7 +512,6 @@ mlnet.ModelHeader = class {
             }
             }
             reader.seek(tailOffset);
             reader.seek(tailOffset);
             reader.assert('LEDOM\0LM');
             reader.assert('LEDOM\0LM');
-
             this._reader = reader;
             this._reader = reader;
             this._reader.seek(modelBlockOffset);
             this._reader.seek(modelBlockOffset);
         }
         }
@@ -441,7 +536,7 @@ mlnet.ModelHeader = class {
         const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\'));
         const stream = this._entries.get(key) || this._entries.get(key.replace(/\//g, '\\'));
         if (stream) {
         if (stream) {
             const buffer = stream.peek();
             const buffer = stream.peek();
-            const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer);
+            const context = new mlnet.ModelHeader(this._catalog, this._entries, name, buffer, this._resolve);
             const value = this._catalog.create(context.loaderSignature, context);
             const value = this._catalog.create(context.loaderSignature, context);
             value.__type__ = value.__type__ || context.loaderSignature;
             value.__type__ = value.__type__ || context.loaderSignature;
             value.__name__ = name;
             value.__name__ = name;
@@ -475,6 +570,10 @@ mlnet.ModelHeader = class {
     check(signature, verWrittenCur, verWeCanReadBack) {
     check(signature, verWrittenCur, verWeCanReadBack) {
         return signature === this.modelSignature && verWrittenCur >= this.modelVersionReadable && verWeCanReadBack <= this.modelVersionWritten;
         return signature === this.modelSignature && verWrittenCur >= this.modelVersionReadable && verWeCanReadBack <= this.modelVersionWritten;
     }
     }
+
+    resolve(target, format, bytes) {
+        this._resolve.push({ target, format, bytes });
+    }
 };
 };
 
 
 mlnet.BinaryReader = class {
 mlnet.BinaryReader = class {
@@ -635,6 +734,7 @@ mlnet.BinaryLoader = class { // 'BINLOADR'
         reader.seek(tableOfContentsOffset);
         reader.seek(tableOfContentsOffset);
         this.schema = {};
         this.schema = {};
         this.schema.inputs = [];
         this.schema.inputs = [];
+        const columns = new Map();
         for (let c = 0; c < columnCount; c  ++) {
         for (let c = 0; c < columnCount; c  ++) {
             const input = {};
             const input = {};
             input.name = reader.string();
             input.name = reader.string();
@@ -643,6 +743,9 @@ mlnet.BinaryLoader = class { // 'BINLOADR'
             input.rowsPerBlock = reader.leb128();
             input.rowsPerBlock = reader.leb128();
             input.lookupOffset = reader.int64();
             input.lookupOffset = reader.int64();
             input.metadataTocOffset = reader.int64();
             input.metadataTocOffset = reader.int64();
+            columns.set(input.name, input);
+        }
+        for (const input of columns.values()) {
             this.schema.inputs.push(input);
             this.schema.inputs.push(input);
         }
         }
     }
     }
@@ -825,7 +928,10 @@ mlnet.FieldAwareFactorizationMachinePredictionTransformer = class extends mlnet.
         }
         }
         this.Threshold = reader.float32();
         this.Threshold = reader.float32();
         this.ThresholdColumn = context.string();
         this.ThresholdColumn = context.string();
-        this.inputs.push({ name: this.ThresholdColumn });
+        this.outputs = [];
+        this.outputs.push({ name: 'Score' });
+        this.outputs.push({ name: 'Probability' });
+        this.outputs.push({ name: 'PredictedLabel' });
     }
     }
 };
 };
 
 
@@ -837,11 +943,16 @@ mlnet.SingleFeaturePredictionTransformerBase = class extends mlnet.PredictionTra
         this.inputs = [];
         this.inputs = [];
         this.inputs.push({ name: featureColumn });
         this.inputs.push({ name: featureColumn });
         this.outputs = [];
         this.outputs = [];
-        this.outputs.push({ name: featureColumn });
     }
     }
 };
 };
 
 
 mlnet.ClusteringPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
 mlnet.ClusteringPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
+
+    constructor(context) {
+        super(context);
+        this.outputs.push({ name: 'Score' });
+        this.outputs.push({ name: 'PredictedLabel' });
+    }
 };
 };
 
 
 mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
 mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
@@ -851,6 +962,8 @@ mlnet.AnomalyPredictionTransformer = class extends mlnet.SingleFeaturePrediction
         const reader = context.reader;
         const reader = context.reader;
         this.Threshold = reader.float32();
         this.Threshold = reader.float32();
         this.ThresholdColumn = context.string();
         this.ThresholdColumn = context.string();
+        this.outputs.push({ name: 'Score' });
+        this.outputs.push({ name: 'PredictedLabel' });
     }
     }
 };
 };
 
 
@@ -871,6 +984,11 @@ mlnet.AffineNormSerializationUtils = class {
 };
 };
 
 
 mlnet.RegressionPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
 mlnet.RegressionPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
+
+    constructor(context) {
+        super(context);
+        this.outputs.push({ name: 'Score' });
+    }
 };
 };
 
 
 mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
 mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionTransformerBase {
@@ -880,6 +998,8 @@ mlnet.BinaryPredictionTransformer = class extends mlnet.SingleFeaturePredictionT
         const reader = context.reader;
         const reader = context.reader;
         this.Threshold = reader.float32();
         this.Threshold = reader.float32();
         this.ThresholdColumn = context.string();
         this.ThresholdColumn = context.string();
+        this.outputs.push({ name: 'Score' });
+        this.outputs.push({ name: 'PredictedLabel' });
     }
     }
 };
 };
 
 
@@ -889,6 +1009,15 @@ mlnet.MulticlassPredictionTransformer = class extends mlnet.SingleFeaturePredict
         super(context);
         super(context);
         this.TrainLabelColumn = context.string(null);
         this.TrainLabelColumn = context.string(null);
         this.inputs.push({ name: this.TrainLabelColumn });
         this.inputs.push({ name: this.TrainLabelColumn });
+        if (context.modelVersionWritten >= 0x00010002) {
+            const scoreColumn = context.string(null);
+            const predictedLabelColumn = context.string(null);
+            this.outputs.push({ name: scoreColumn || 'Score' });
+            this.outputs.push({ name: predictedLabelColumn || 'PredictedLabel' });
+        } else {
+            this.outputs.push({ name: 'Score' });
+            this.outputs.push({ name: 'PredictedLabel' });
+        }
     }
     }
 };
 };
 
 
@@ -936,11 +1065,11 @@ mlnet.ImageClassificationModelParameters = class extends mlnet.ModelParametersBa
         this.imagePreprocessorTensorOutput = reader.string();
         this.imagePreprocessorTensorOutput = reader.string();
         this.graphInputTensor = reader.string();
         this.graphInputTensor = reader.string();
         this.graphOutputTensor = reader.string();
         this.graphOutputTensor = reader.string();
-        this.modelFile = 'TFModel';
-        // const modelBytes = context.openBinary('TFModel');
-        // first uint32 is size of TensorFlow model
-        // inputType = new VectorDataViewType(uint8);
-        // outputType = new VectorDataViewType(float32, classCount);
+        const modelReader = context.openBinary('TFModel');
+        if (modelReader) {
+            const size = modelReader.uint32();
+            context.resolve(this, 'tf', modelReader.read(size));
+        }
     }
     }
 };
 };
 
 
@@ -1559,9 +1688,11 @@ mlnet.OnnxTransformer = class extends mlnet.RowToRowTransformerBase {
     constructor(context) {
     constructor(context) {
         super(context);
         super(context);
         const reader = context.reader;
         const reader = context.reader;
-        this.modelFile = 'OnnxModel';
-        // const modelBytes = context.openBinary('OnnxModel');
-        // first uint32 is size of .onnx model
+        const modelReader = context.openBinary('OnnxModel');
+        if (modelReader) {
+            const size = modelReader.uint32();
+            context.resolve(this, 'onnx', modelReader.read(size));
+        }
         const numInputs = context.modelVersionWritten > 0x00010001 ? reader.int32() : 1;
         const numInputs = context.modelVersionWritten > 0x00010001 ? reader.int32() : 1;
         this.inputs = [];
         this.inputs = [];
         for (let i = 0; i < numInputs; i++) {
         for (let i = 0; i < numInputs; i++) {

+ 11 - 4
test/models.json

@@ -3167,14 +3167,14 @@
     "target":   "ep_model1.zip",
     "target":   "ep_model1.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216033/ep_model1.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216033/ep_model1.zip",
     "format":   "ML.NET v1.0.0.0",
     "format":   "ML.NET v1.0.0.0",
-    "link":     "https://github.com/lutzroeder/netron/issues/170"
+    "link":     "https://github.com/lutzroeder/netron/issues/1541"
   },
   },
   {
   {
     "type":     "mlnet",
     "type":     "mlnet",
     "target":   "ep_model3.zip",
     "target":   "ep_model3.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216304/ep_model3.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216304/ep_model3.zip",
     "format":   "ML.NET v1.0.0.0",
     "format":   "ML.NET v1.0.0.0",
-    "link":     "https://github.com/lutzroeder/netron/issues/170"
+    "link":     "https://github.com/lutzroeder/netron/issues/1541"
   },
   },
   {
   {
     "type":     "mlnet",
     "type":     "mlnet",
@@ -3211,6 +3211,13 @@
     "format":   "ML.NET v1.0.0.0",
     "format":   "ML.NET v1.0.0.0",
     "link":     "https://github.com/dotnet/machinelearning-samples"
     "link":     "https://github.com/dotnet/machinelearning-samples"
   },
   },
+  {
+    "type":     "mlnet",
+    "target":   "MLModel1.mlnet",
+    "source":   "https://github.com/joesatriani10/Image-Recognition/raw/master/Image%20Recognition/MLModel1.mlnet",
+    "format":   "ML.NET v3.0.0-preview.23323.1",
+    "link":     "https://github.com/joesatriani10/Image-Recognition"
+  },
   {
   {
     "type":     "mlnet",
     "type":     "mlnet",
     "target":   "MovieRecommender_Model.zip",
     "target":   "MovieRecommender_Model.zip",
@@ -3223,7 +3230,7 @@
     "target":   "ngram.zip",
     "target":   "ngram.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216079/ngram.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216079/ngram.zip",
     "format":   "ML.NET v3.10.29.504",
     "format":   "ML.NET v3.10.29.504",
-    "link":     "https://github.com/lutzroeder/netron/issues/170"
+    "link":     "https://github.com/lutzroeder/netron/issues/1541"
   },
   },
   {
   {
     "type":     "mlnet",
     "type":     "mlnet",
@@ -3286,7 +3293,7 @@
     "target":   "termlookup_with_key.zip",
     "target":   "termlookup_with_key.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216237/termlookup_with_key.zip",
     "source":   "https://github.com/lutzroeder/netron/files/4216237/termlookup_with_key.zip",
     "format":   "ML.NET v1.0.0.0",
     "format":   "ML.NET v1.0.0.0",
-    "link":     "https://github.com/lutzroeder/netron/issues/170"
+    "link":     "https://github.com/lutzroeder/netron/issues/1541"
   },
   },
   {
   {
     "type":     "mlnet",
     "type":     "mlnet",