Selaa lähdekoodia

Add ML.NET test files (#170)

Lutz Roeder 6 vuotta sitten
vanhempi
sitoutus
63f6f78fee
2 muutettua tiedostoa jossa 315 lisäystä ja 70 poistoa
  1. 287 70
      src/mlnet.js
  2. 28 0
      test/models.json

+ 287 - 70
src/mlnet.js

@@ -364,7 +364,7 @@ mlnet.TensorType = class {
                 this._dataType = mlnet.TensorType._map.get(codec.itemType.name);
             }
             else {
-                throw new mlnet.Error("Unknown data type '" + codec.name + "'.");
+                throw new mlnet.Error("Unknown data type '" + codec.itemType.name + "'.");
             }
             this._shape = new mlnet.TensorShape(codec.dims);
         }
@@ -467,7 +467,9 @@ mlnet.ModelReader = class {
         catalog.register('BinaryPredXfer', mlnet.BinaryPredictionTransformer);
         catalog.register('BinaryLoader', mlnet.BinaryLoader);
         catalog.register('CaliPredExec', mlnet.CalibratedPredictor);
+        catalog.register('CdfNormalizeFunction', mlnet.CdfColumnFunction);
         catalog.register('CharToken', mlnet.TokenizingByCharactersTransformer);
+        catalog.register('ChooseColumnsTransform', mlnet.ColumnSelectingTransformer);       
         catalog.register('ClusteringPredXfer', mlnet.ClusteringPredictionTransformer);
         catalog.register('ConcatTransform', mlnet.ColumnConcatenatingTransformer);
         catalog.register('CopyTransform', mlnet.ColumnCopyingTransformer);
@@ -478,6 +480,7 @@ mlnet.ModelReader = class {
         catalog.register('FastForestBinaryExec', mlnet.FastForestClassificationPredictor);
         catalog.register('FastTreeBinaryExec', mlnet.FastTreeBinaryModelParameters);
         catalog.register('FastTreeTweedieExec', mlnet.FastTreeTweedieModelParameters);
+        catalog.register('FastTreeRankerExec', mlnet.FastTreeRankingModelParameters);
         catalog.register('FastTreeRegressionExec', mlnet.FastTreeRegressionModelParameters);
         catalog.register('FeatWCaliPredExec', mlnet.FeatureWeightsCalibratedModelParameters);
         catalog.register('FieldAwareFactMacPredict', mlnet.FieldAwareFactorizationMachineModelParameters);
@@ -496,15 +499,21 @@ mlnet.ModelReader = class {
         catalog.register('LightGBMRegressionExec', mlnet.LightGbmRegressionModelParameters);
         catalog.register('LightGBMBinaryExec', mlnet.LightGbmBinaryModelParameters);
         catalog.register('Linear2CExec', mlnet.LinearBinaryModelParameters);
+        catalog.register('LinearModelStats', mlnet.LinearModelParameterStatistics);
+        catalog.register('MaFactPredXf', mlnet.MatrixFactorizationPredictionTransformer);
+        catalog.register('MFPredictor', mlnet.MatrixFactorizationModelParameters);
         catalog.register('MulticlassLinear', mlnet.LinearMulticlassModelParameters);
         catalog.register('MultiClassLRExec', mlnet.MaximumEntropyModelParameters);
+        catalog.register('MultiClassNaiveBayesPred', mlnet.NaiveBayesMulticlassModelParameters);
         catalog.register('MultiClassNetPredictor', mlnet.MultiClassNetPredictor);
         catalog.register('MulticlassPredXfer', mlnet.MulticlassPredictionTransformer);
         catalog.register('NgramTransform', mlnet.NgramExtractingTransformer);
+        catalog.register('NgramHashTransform', mlnet.NgramHashingTransformer);
         catalog.register('NltTokenizeTransform', mlnet.NltTokenizeTransform);
         catalog.register('Normalizer', mlnet.NormalizingTransformer);
         catalog.register('NormalizeTransform', mlnet.NormalizeTransform);
         catalog.register('OnnxTransform', mlnet.OnnxTransformer);
+        catalog.register('OptColTransform', mlnet.OptionalColumnTransform);
         catalog.register('OVAExec', mlnet.OneVersusAllModelParameters);
         catalog.register('pcaAnomExec', mlnet.PcaModelParameters);
         catalog.register('PcaTransform', mlnet.PrincipalComponentAnalysisTransformer);
@@ -520,6 +529,7 @@ mlnet.ModelReader = class {
         catalog.register('SelectColumnsTransform', mlnet.ColumnSelectingTransformer);
         catalog.register('StopWordsTransform', mlnet.StopWordsTransform);
         catalog.register('TensorFlowTransform', mlnet.TensorFlowTransformer);
+        catalog.register('TermLookupTransform', mlnet.ValueMappingTransformer);
         catalog.register('TermTransform', mlnet.ValueToKeyMappingTransformer);
         catalog.register('TermManager', mlnet.TermManager);
         catalog.register('Text', mlnet.TextFeaturizingEstimator);
@@ -685,6 +695,10 @@ mlnet.ModelHeader = class {
         }
         return null;
     }
+    
+    check(signature, verWrittenCur, verWeCanReadBack) {
+        return signature === this.modelSignature && verWrittenCur >= this.modelVersionReadable && verWeCanReadBack <= this.modelVersionWritten; 
+    }
 };
 
 mlnet.Reader = class {
@@ -915,6 +929,49 @@ mlnet.TransformerChain = class {
     }
 };
 
+mlnet.TransformBase = class {
+
+    constructor(/* context */) {
+
+    }
+}
+
+mlnet.RowToRowTransformBase = class extends mlnet.TransformBase {
+
+    constructor(context) {
+        super(context);
+    }
+}
+
+mlnet.RowToRowTransformerBase = class {
+
+    constructor(/* context */) {
+    }
+}
+
+mlnet.RowToRowMapperTransformBase = class extends mlnet.RowToRowTransformBase {
+
+    constructor(context) {
+        super(context);
+    }
+}
+
+mlnet.OneToOneTransformerBase = class {
+
+    constructor(context) {
+        const reader = context.reader;
+        const n = reader.int32();
+        this.inputs = [];
+        this.outputs = [];
+        for (let i = 0; i < n; i++) {
+            const output = context.string();
+            const input = context.string();
+            this.outputs.push({ name: output });
+            this.inputs.push({ name: input });
+        }
+    }
+};
+
 mlnet.ColumnCopyingTransformer = class {
 
     constructor(context) {
@@ -932,7 +989,7 @@ mlnet.ColumnCopyingTransformer = class {
 mlnet.ColumnConcatenatingTransformer = class {
 
     constructor(context) {
-        let reader = context.reader;
+        const reader = context.reader;
         if (context.modelVersionReadable >= 0x00010003) {
             const count = reader.int32();
             for (let i = 0; i < count; i++) {
@@ -1010,6 +1067,35 @@ mlnet.PredictionTransformerBase = class {
     }
 };
 
+mlnet.MatrixFactorizationModelParameters = class {
+
+    constructor(context) {
+        const reader = context.reader;
+        this.NumberOfRows = reader.int32();
+        if (context.modelVersionWritten < 0x00010002) {
+            reader.uint64(); // mMin
+        }
+        this.NumberOfColumns = reader.int32();
+        if (context.modelVersionWritten < 0x00010002) {
+            reader.uint64(); // nMin
+        }
+        this.ApproximationRank = reader.int32();
+
+        this._leftFactorMatrix = reader.float32s(this.NumberOfRows * this.ApproximationRank);
+        this._rightFactorMatrix = reader.float32s(this.NumberOfColumns * this.ApproximationRank);
+    }
+}
+
+mlnet.MatrixFactorizationPredictionTransformer = class extends mlnet.PredictionTransformerBase {
+
+    constructor(context) {
+        super(context);
+        this.MatrixColumnIndexColumnName = context.string();
+        this.MatrixRowIndexColumnName = context.string();
+        // TODO
+    }
+}
+
 mlnet.FieldAwareFactorizationMachinePredictionTransformer = class extends mlnet.PredictionTransformerBase {
 
     constructor(context) {
@@ -1118,6 +1204,23 @@ mlnet.ModelParametersBase = class {
     }
 };
 
+mlnet.NaiveBayesMulticlassModelParameters = class extends mlnet.ModelParametersBase {
+
+    constructor(context) {
+        super(context);
+        const reader = context.reader;
+        this._labelHistogram = reader.int32s(reader.int32());
+        this._featureCount = reader.int32();
+        this._featureHistogram = [];
+        for (let i = 0; i < this._labelHistogram.length; i++) {
+            if (this._labelHistogram[i] > 0) {
+                this._featureHistogram.push(reader.int32s(this._featureCount));
+            }
+        }
+        this._absentFeaturesLogProb = reader.float64s(this._labelHistogram.length);
+    }
+}
+
 mlnet.LinearModelParameters = class extends mlnet.ModelParametersBase {
 
     constructor(context) {
@@ -1140,6 +1243,51 @@ mlnet.LinearBinaryModelParameters = class extends mlnet.LinearModelParameters {
     }
 }
 
+mlnet.ModelStatisticsBase = class {
+
+    constructor(context) {
+        const reader = context.reader;
+        this.ParametersCount = reader.int32();
+        this.TrainingExampleCount = reader.int64();
+        this.Deviance = reader.float32();
+        this.NullDeviance = reader.float32();
+
+    }    
+}
+
+mlnet.LinearModelParameterStatistics = class extends mlnet.ModelStatisticsBase {
+
+    constructor(context) {
+        super(context);
+        const reader = context.reader;
+        if (context.modelVersionWritten < 0x00010002) {
+            if (!reader.boolean()) {
+                return;
+            }
+        }
+        const stdErrorValues = reader.float32s(this.ParametersCount);
+        const length = reader.int32();
+        if (length == this.ParametersCount) {
+            this._coeffStdError = stdErrorValues;
+        }
+        else {
+            this.stdErrorIndices = reader.int32s(this.ParametersCount);
+            this._coeffStdError = stdErrorValues;
+        }
+        this._bias = reader.float32();
+        const isWeightsDense = reader.byte();
+        const weightsLength = reader.int32();
+        const weightsValues = reader.float32s(weightsLength);
+
+        if (isWeightsDense) {
+            this._weights = weightsValues;
+        }
+        else {
+            this.weightsIndices = reader.int32s(weightsLength);
+        }
+    }
+}
+
 mlnet.LinearMulticlassModelParametersBase = class extends mlnet.ModelParametersBase {
 
     constructor(context) {
@@ -1225,22 +1373,6 @@ mlnet.MaximumEntropyModelParameters = class extends mlnet.LinearMulticlassModelP
     }
 };
 
-mlnet.OneToOneTransformerBase = class {
-
-    constructor(context) {
-        const reader = context.reader;
-        const n = reader.int32();
-        this.inputs = [];
-        this.outputs = [];
-        for (let i = 0; i < n; i++) {
-            const output = context.string();
-            const input = context.string();
-            this.outputs.push({ name: output });
-            this.inputs.push({ name: input });
-        }
-    }
-};
-
 mlnet.TokenizingByCharactersTransformer = class extends mlnet.OneToOneTransformerBase {
 
     constructor(context) {
@@ -1290,6 +1422,38 @@ mlnet.NgramExtractingTransformer = class extends mlnet.OneToOneTransformerBase {
 
 // mlnet.NgramExtractingTransformer.WeightingCriteria
 
+mlnet.NgramHashingTransformer = class extends mlnet.RowToRowTransformerBase {
+
+    constructor(context) {
+        super(context);
+        const loadLegacy = context.modelVersionWritten < 0x00010003
+        const reader = context.reader;
+        if (loadLegacy) {
+            reader.int32(); // cbFloat
+        }
+        this.inputs = [];
+        this.outputs = [];
+        const columnsLength = reader.int32();
+        if (loadLegacy) {
+            /* TODO
+            for (let i = 0; i < columnsLength; i++) {
+                this.Columns.push(new NgramHashingEstimator.ColumnOptions(context));
+            } */
+        }
+        else {
+            for (let i = 0; i < columnsLength; i++) {
+                this.outputs.push(context.string());
+                let csrc = reader.int32();
+                for (let j = 0; j < csrc; j++) {
+                    let src = context.string();
+                    this.inputs.push(src);
+                    // TODO inputs[i][j] = src;
+                }
+            }
+        }
+    }
+}
+
 mlnet.WordTokenizingTransformer = class extends mlnet.OneToOneTransformerBase {
 
     constructor(context) {
@@ -1498,17 +1662,29 @@ mlnet.NormalizingTransformer = class extends mlnet.OneToOneTransformerBase {
         const reader = context.reader;
         this.Options = [];
         for (let i = 0; i < this.inputs.length; i++) {
-            const name = 'Normalizer_' + ('00' + i).slice(-3);
-            /* let isVector = */ reader.boolean();
-            /* let vectorSize = */ reader.int32();
-            const itemKind = reader.byte();
+            let isVector = false;
+            let shape = 0;
+            let itemKind = '';
+            if (context.modelVersionWritten < 0x00010002) {
+                isVector = reader.boolean();
+                shape = [ reader.int32() ];
+                itemKind = reader.byte();
+            }
+            else {
+                isVector = reader.boolean();
+                itemKind = reader.byte();
+                shape = reader.int32s(reader.int32());
+            }
+            let itemType = '';
             switch (itemKind) {
-                case 9: this.itemType = 'float32'; break;
-                case 10: this.itemType = 'float64'; break;
+                case 9: itemType = 'float32'; break;
+                case 10: itemType = 'float64'; break;
                 default: throw new mlnet.Error("Unknown NormalizingTransformer item kind '" + itemKind + "'.");
             }
+            const type = itemType + (!isVector ? '' : '[' + shape.map((dim) => dim.toString()).join(',') + ']');
+            const name = 'Normalizer_' + ('00' + i).slice(-3);
             const func = context.open(name);
-            this.Options.push({ func: func });
+            this.Options.push({ type: type, func: func });
         }
     }
 }
@@ -1595,7 +1771,11 @@ mlnet.ValueMappingTransformer = class extends mlnet.OneToOneTransformerBase {
 
     constructor(context) {
         super(context);
-        // debugger;
+        this.keyColumnName = 'Key';
+        if (context.check('TXTLOOKT', 0x00010002, 0x00010002)) {
+            this.keyColumnName = 'Term';
+        }
+        // TODO
     }
 }
 
@@ -1617,7 +1797,7 @@ mlnet.CompositeDataLoader = class {
         /* let loader = */ context.open('Loader');
         const reader = context.reader;
         // LoadTransforms
-        this.floatSize = reader.int32();
+        reader.int32(); // floatSize
         const cxf = reader.int32();
         const tagData = [];
         for (let i = 0; i < cxf; i++) {
@@ -1638,20 +1818,6 @@ mlnet.CompositeDataLoader = class {
     }
 };
 
-mlnet.TransformBase = class {
-
-    constructor(/* context */) {
-
-    }
-}
-
-mlnet.RowToRowTransformBase = class extends mlnet.TransformBase {
-
-    constructor(context) {
-        super(context);
-    }
-}
-
 mlnet.RowToRowMapperTransform = class extends mlnet.RowToRowTransformBase {
 
     constructor(context) {
@@ -1664,12 +1830,6 @@ mlnet.RowToRowMapperTransform = class extends mlnet.RowToRowTransformBase {
     }
 }
 
-mlnet.RowToRowTransformerBase = class {
-
-    constructor(/* context */) {
-    }
-}
-
 mlnet.ImageClassificationTransformer = class extends mlnet.RowToRowTransformerBase {
 
     constructor(context) {
@@ -1732,6 +1892,13 @@ mlnet.OnnxTransformer = class extends mlnet.RowToRowTransformerBase {
     }
 }
 
+mlnet.OptionalColumnTransform = class extends mlnet.RowToRowMapperTransformBase {
+
+    constructor(context) {
+        super(context);
+    }
+}
+
 mlnet.TensorFlowTransformer = class extends mlnet.RowToRowTransformerBase {
 
     constructor(context) {
@@ -1797,7 +1964,7 @@ mlnet.TextLoader = class {
 
     constructor(context) {
         const reader = context.reader;
-        this.FloatSize = reader.int32();
+        reader.int32(); // floatSize
         this.MaxRows = reader.int64();
         this.Flags = reader.uint32();
         this.InputSize = reader.int32();
@@ -2023,6 +2190,17 @@ mlnet.FastTreeTweedieModelParameters = class extends mlnet.TreeEnsembleModelPara
     get VerCategoricalSplitSerialized() { return 0x00010003; }
 }
 
+mlnet.FastTreeRankingModelParameters = class extends mlnet.TreeEnsembleModelParametersBasedOnRegressionTree {
+
+    constructor(context) {
+        super(context);
+    }
+
+    get VerNumFeaturesSerialized() { return 0x00010002; }
+    get VerDefaultValueSerialized() { return 0x00010004; }
+    get VerCategoricalSplitSerialized() { return 0x00010005; }
+}
+
 mlnet.FastTreeBinaryModelParameters = class extends mlnet.TreeEnsembleModelParametersBasedOnRegressionTree {
 
     constructor(context) {
@@ -2105,22 +2283,19 @@ mlnet.Codec = class {
         reader = new mlnet.Reader(data);
 
         switch (this.name) {
-            case 'Boolean':
-                break;
-            case 'Single':
-                break;
-            case 'Double':
-                break;
-            case 'Byte':
-                break;
-            case 'UInt32':
-                break;
-            case 'TextSpan':
-                break;
+            case 'Boolean': break;
+            case 'Single': break;
+            case 'Double': break;
+            case 'Byte': break;
+            case 'Int32': break;
+            case 'UInt32': break;
+            case 'Int64': break;
+            case 'TextSpan': break;
             case 'VBuffer':
                 this.itemType = new mlnet.Codec(reader);
                 this.dims = reader.int32s(reader.int32());
                 break;
+            case 'Key':
             case 'Key2':
                 this.itemType = new mlnet.Codec(reader);
                 this.count = reader.uint64();
@@ -2138,6 +2313,16 @@ mlnet.Codec = class {
                     values.push(reader.float32());
                 }
                 break;
+            case 'Int32':
+                for (let i = 0; i < count; i++) {
+                    values.push(reader.int32());
+                }
+                break;
+            case 'Int64':
+                for (let i = 0; i < count; i++) {
+                    values.push(reader.int64());
+                }
+                break;
             default:
                 throw new mlnet.Error("Unknown codec read operation '" + this.name + "'.");
         }
@@ -2326,19 +2511,44 @@ mlnet.ColumnSelectingTransformer = class {
 
     constructor(context) {
         const reader = context.reader;
-        const keepColumns = reader.boolean();
-        this.KeepHidden = reader.boolean();
-        this.IgnoreMissing = reader.boolean();
-        const length = reader.int32();
-        this.inputs = [];
-        for (let i = 0; i < length; i++) {
-            this.inputs.push({ name: context.string() });
+        if (context.check('DRPCOLST', 0x00010002, 0x00010002)) {
+            throw new mlnet.Error("'LoadDropColumnsTransform' not supported.");
         }
-        if (keepColumns) {
-            this.ColumnsToKeep = this.inputs;
+        else if (context.check('CHSCOLSF', 0x00010001, 0x00010001)) {
+            reader.int32(); // cbFloat
+            this.KeepHidden = this._getHiddenOption(reader.byte());
+            const count = reader.int32();
+            this.inputs = [];
+            for (let colIdx = 0; colIdx < count; colIdx++) {
+                const dst = context.string();
+                this.inputs.push(dst);
+                context.string(); // src 
+                this._getHiddenOption(reader.byte()); // colKeepHidden
+            }
         }
         else {
-            this.ColumnsToDrop = this.inputs;
+            const keepColumns = reader.boolean();
+            this.KeepHidden = reader.boolean();
+            this.IgnoreMissing = reader.boolean();
+            const length = reader.int32();
+            this.inputs = [];
+            for (let i = 0; i < length; i++) {
+                this.inputs.push({ name: context.string() });
+            }
+            if (keepColumns) {
+                this.ColumnsToKeep = this.inputs;
+            }
+            else {
+                this.ColumnsToDrop = this.inputs;
+            }
+        }
+    }
+
+    _getHiddenOption(value) {
+        switch (value) {
+            case 1: return true;
+            case 2: return false;
+            default: throw new mlnet.Error('Unsupported hide option specified');
         }
     }
 }
@@ -2357,6 +2567,13 @@ mlnet.GenericScoreTransform = class {}
 
 mlnet.NormalizeTransform = class {}
 
+mlnet.CdfColumnFunction = class {
+
+    constructor(/* context, typeSrc */) {
+        // TODO
+    }
+}
+
 mlnet.MultiClassNetPredictor = class {}
 
 mlnet.ProtonNNMCPred = class {}

+ 28 - 0
test/models.json

@@ -2306,6 +2306,20 @@
     "format": "ML.NET v1.4.28305.1",
     "link":   "https://github.com/dotnet/machinelearning-samples"
   },
+  {
+    "type":   "mlnet",
+    "target": "ep_model1.zip",
+    "source": "https://github.com/lutzroeder/netron/files/4216033/ep_model1.zip",
+    "format": "ML.NET v1.0.0.0",
+    "link":   "https://github.com/lutzroeder/netron/issues/170"
+  },
+  {
+    "type":   "mlnet",
+    "target": "ep_model3.zip",
+    "source": "https://github.com/lutzroeder/netron/files/4216304/ep_model3.zip",
+    "format": "ML.NET v1.0.0.0",
+    "link":   "https://github.com/lutzroeder/netron/issues/170"
+  },
   {
     "type":   "mlnet",
     "target": "FastTreeModel.zip",
@@ -2355,6 +2369,13 @@
     "format": "ML.NET v1.0.0.0",
     "link":   "https://github.com/dotnet/machinelearning-samples"
   },
+  {
+    "type":   "mlnet",
+    "target": "ngram.zip",
+    "source": "https://github.com/lutzroeder/netron/files/4216079/ngram.zip?raw=true",
+    "format": "ML.NET v3.10.29.504",
+    "link":   "https://github.com/lutzroeder/netron/issues/170"
+  },
   {
     "type":   "mlnet",
     "target": "PoissonModel.zip",
@@ -2411,6 +2432,13 @@
     "format": "ML.NET v1.0.27701.1",
     "link":   "https://github.com/dotnet/machinelearning-samples"
   },
+  {
+    "type":   "mlnet",
+    "target": "termlookup_with_key.zip",
+    "source": "https://github.com/lutzroeder/netron/files/4216237/termlookup_with_key.zip?raw=true",
+    "format": "ML.NET v1.0.0.0",
+    "link":   "https://github.com/lutzroeder/netron/issues/170"
+  },
   {
     "type":   "mlnet",
     "target": "TinyYoloModel.zip",