Przeglądaj źródła

Fix ML.NET ImageClassificationPred (#439)

Lutz Roeder 6 lat temu
rodzic
commit
ab7353211d
1 zmienionych plików z 22 dodań i 0 usunięć
  1. 22 0
      src/mlnet.js

+ 22 - 0
src/mlnet.js

@@ -454,6 +454,7 @@ mlnet.ModelReader = class {
         catalog.register('IidChangePointDetector', mlnet.IidChangePointDetector);
         catalog.register('IidSpikeDetector', mlnet.IidSpikeDetector);
         catalog.register('ImageClassificationTrans', mlnet.ImageClassificationTransformer);
+        catalog.register('ImageClassificationPred', mlnet.ImageClassificationModelParameters);
         catalog.register('ImageLoaderTransform', mlnet.ImageLoadingTransformer);
         catalog.register('ImageScalerTransform', mlnet.ImageResizingTransformer);
         catalog.register('ImagePixelExtractor', mlnet.ImagePixelExtractingTransformer);
@@ -1169,6 +1170,24 @@ mlnet.ModelParametersBase = class {
     }
 };
 
+mlnet.ImageClassificationModelParameters = class extends mlnet.ModelParametersBase {
+
+    constructor(context) {
+        super(context);
+        const reader = context.reader;
+        this.classCount = reader.int32();
+        this.imagePreprocessorTensorInput = reader.string();
+        this.imagePreprocessorTensorOutput = reader.string();
+        this.graphInputTensor = reader.string();
+        this.graphOutputTensor = reader.string();
+        this.modelFile = 'TFModel';
+        // const modelBytes = context.openBinary('TFModel');
+        // first uint32 is size of TensorFlow model
+        // inputType = new VectorDataViewType(uint8);
+        // outputType = new VectorDataViewType(float32, classCount);
+    }
+};
+
 mlnet.NaiveBayesMulticlassModelParameters = class extends mlnet.ModelParametersBase {
 
     constructor(context) {
@@ -1834,6 +1853,9 @@ mlnet.OnnxTransformer = class extends mlnet.RowToRowTransformerBase {
     constructor(context) {
         super(context);
         const reader = context.reader;
+        this.modelFile = 'OnnxModel';
+        // const modelBytes = context.openBinary('OnnxModel');
+        // first uint32 is size of .onnx model
         const numInputs = context.modelVersionWritten > 0x00010001 ? reader.int32() : 1;
         this.inputs = [];
         for (let i = 0; i < numInputs; i++) {