Procházet zdrojové kódy

CoreML tensor formatting (#89)

Lutz Roeder před 8 roky
rodič
revize
9c20dd9ea8
3 změnil soubory, kde provedl 250 přidání a 42 odebrání
  1. 157 38
      src/coreml-model.js
  2. 90 1
      src/coreml-operator.json
  3. 3 3
      src/onnx-model.js

+ 157 - 38
src/coreml-model.js

@@ -304,8 +304,11 @@ class CoreMLNode {
         this._attributes = [];
         this._initializers = [];
         if (data) {
+            var initializerMap = this.initializer(data);
             Object.keys(data).forEach((key) => {
-                this.initialize(key, data[key]);
+                if (!initializerMap[key]) {
+                    this._attributes.push(new CoreMLAttribute(this, key, data[key]));
+                }
             });
         }
     }
@@ -336,6 +339,7 @@ class CoreMLNode {
                 name: initializer.name,
                 connections: [ { 
                     id: initializer.id, 
+                    type: initializer.type,
                     initializer: initializer, } ]
             };
             if (CoreMLOperatorMetadata.operatorMetadata.getInputHidden(this._operator, initializer.name)) {
@@ -361,50 +365,104 @@ class CoreMLNode {
         return this._attributes;
     }
 
-    initialize(name, value) {
+    /**
+     * Creates CoreMLTensor initializers for the weight-carrying fields of this
+     * node's layer parameters and appends them to this._initializers.
+     *
+     * @param {Object} data - layer parameter object matching this._operator.
+     * @returns {Object} map keyed by field name of `data`. A truthy entry marks
+     *     a field that was consumed as an initializer; the constructor skips
+     *     creating a CoreMLAttribute for those keys. Falsy or missing entries
+     *     are still turned into attributes. Unknown operators return {}.
+     */
+    initializer(data) {
         switch (this._operator) {
-            case 'glmClassifier':
-                if (name == 'weights') {
-                    this._initializers.push(new CoreMLTensor(name, value));
-                    return;
-                }
-                break;
             case 'convolution':
+                var weightsShape = [ data.outputChannels, data.kernelChannels, data.kernelSize[0], data.kernelSize[1] ];
+                if (data.isDeconvolution) {
+                    // Deconvolution swaps the leading dimensions; nGroups == 0 is treated as 1.
+                    weightsShape[0] = data.kernelChannels;
+                    weightsShape[1] = Math.floor(data.outputChannels / (data.nGroups != 0 ? data.nGroups : 1));
+                }
+                this._initializers.push(new CoreMLTensor('weights', weightsShape, data.weights));
+                if (data.hasBias) {
+                    this._initializers.push(new CoreMLTensor('bias', [ data.bias.floatValue.length ], data.bias));
+                }
+                return { 'weights': true, 'bias': data.hasBias };
             case 'innerProduct':
-            case 'embedding':
+                this._initializers.push(new CoreMLTensor('weights', [ data.outputChannels, data.inputChannels ], data.weights));
+                if (data.hasBias) {
+                    this._initializers.push(new CoreMLTensor('bias', [ data.outputChannels ], data.bias));
+                }
+                return { 'weights': true, 'bias': data.hasBias };
             case 'batchnorm':
-            case 'bias':
-            case 'scale':
+                this._initializers.push(new CoreMLTensor('gamma', [ data.channels ], data.gamma));
+                this._initializers.push(new CoreMLTensor('beta', [ data.channels ], data.beta));
+                if (data.mean) {
+                    this._initializers.push(new CoreMLTensor('mean', [ data.channels ], data.mean));
+                }
+                if (data.variance) {
+                    this._initializers.push(new CoreMLTensor('variance', [ data.channels ], data.variance));
+                }
+                return { 'gamma': true, 'beta': true, 'mean': true, 'variance': true };
+            case 'embedding':
+                this._initializers.push(new CoreMLTensor('weights', [ data.inputDim, data.outputChannels ], data.weights));
+                return { 'weights': true };
             case 'loadConstant':    
+                this._initializers.push(new CoreMLTensor('data', data.shape, data.data));
+                return { 'data': true };
+            case 'scale':
+                this._initializers.push(new CoreMLTensor('scale', data.shapeScale, data.scale));
+                if (data.hasBias) {
+                    this._initializers.push(new CoreMLTensor('bias', data.shapeBias, data.bias));
+                }
+                return { 'scale': true, 'bias': data.hasBias };
+            case 'bias':
+                this._initializers.push(new CoreMLTensor('bias', data.shapeBias, data.bias));
+                return { 'bias': true };
             case 'simpleRecurrentLayer':
-            case 'gru':
-                if (value instanceof coreml.WeightParams) {
-                    this._initializers.push(new CoreMLTensor(name, value));
-                    return;
+                // Shapes are passed as null here, so these tensors report type '?'.
+                this._initializers.push(new CoreMLTensor('weights', null, data.weightMatrix));
+                this._initializers.push(new CoreMLTensor('recurrent', null, data.recursionMatrix));
+                if (data.hasBiasVectors) {
+                    this._initializers.push(new CoreMLTensor('bias', null, data.biasVector));
                 }
-                break;
+                return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
+            case 'gru':
+                this._initializers.push(new CoreMLTensor('updateGateWeightMatrix', null, data.updateGateWeightMatrix));
+                this._initializers.push(new CoreMLTensor('resetGateWeightMatrix', null, data.resetGateWeightMatrix));
+                this._initializers.push(new CoreMLTensor('outputGateWeightMatrix', null, data.outputGateWeightMatrix));
+                this._initializers.push(new CoreMLTensor('updateGateRecursionMatrix', null, data.updateGateRecursionMatrix));
+                this._initializers.push(new CoreMLTensor('resetGateRecursionMatrix', null, data.resetGateRecursionMatrix));
+                this._initializers.push(new CoreMLTensor('outputGateRecursionMatrix', null, data.outputGateRecursionMatrix));
+                if (data.hasBiasVectors) {
+                    this._initializers.push(new CoreMLTensor('updateGateBiasVector', null, data.updateGateBiasVector));
+                    this._initializers.push(new CoreMLTensor('resetGateBiasVector', null, data.resetGateBiasVector));
+                    this._initializers.push(new CoreMLTensor('outputGateBiasVector', null, data.outputGateBiasVector));
+                }
+                return {
+                    'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
+                    'updateGateRecursionMatrix': true, 'resetGateRecursionMatrix': true, 'outputGateRecursionMatrix': true,
+                    'updateGateBiasVector': data.hasBiasVectors, 'resetGateBiasVector': data.hasBiasVectors, 'outputGateBiasVector': data.hasBiasVectors };
             case 'uniDirectionalLSTM':
-                if (value instanceof coreml.LSTMWeightParams) {
-                    Object.keys(value).forEach((key) => {
-                        this._initializers.push(new CoreMLTensor(key, value[key]));
-                    });
-                    return;
-                }
-                break;
             case 'biDirectionalLSTM':
-                if (name == 'weightParams' && value.length == 2) {
-                    Object.keys(value[0]).forEach((key) => {
-                        this._initializers.push(new CoreMLTensor(key, value[0][key]));
-                    });
-                    Object.keys(value[1]).forEach((key) => {
-                        this._initializers.push(new CoreMLTensor(key + '_rev', value[1][key]));
-                    });
-                    return;
-                }    
-                break;
+                // The uni-directional layer carries one weight set; the bi-directional
+                // layer carries two, the second being suffixed with '_rev'.
+                var count = (this._operator == 'uniDirectionalLSTM') ? 1 : 2;
+                var matrixShape = [ data.outputVectorSize, data.inputVectorSize ];
+                // Recursion matrices are applied to the previous output, so the
+                // Core ML specification defines them as square in outputVectorSize.
+                var recursionShape = [ data.outputVectorSize, data.outputVectorSize ];
+                var vectorShape = [ data.outputVectorSize ];
+                for (var i = 0; i < count; i++) {
+                    var weights = count == 1 ? data.weightParams : data.weightParams[i];
+                    var suffix = (i == 0) ? '' : '_rev';
+                    this._initializers.push(new CoreMLTensor('inputGateWeightMatrix' + suffix, matrixShape, weights.inputGateWeightMatrix));
+                    this._initializers.push(new CoreMLTensor('forgetGateWeightMatrix' + suffix, matrixShape, weights.forgetGateWeightMatrix));
+                    this._initializers.push(new CoreMLTensor('blockInputWeightMatrix' + suffix, matrixShape, weights.blockInputWeightMatrix));
+                    this._initializers.push(new CoreMLTensor('outputGateWeightMatrix' + suffix, matrixShape, weights.outputGateWeightMatrix));
+                    this._initializers.push(new CoreMLTensor('inputGateRecursionMatrix' + suffix, recursionShape, weights.inputGateRecursionMatrix));
+                    this._initializers.push(new CoreMLTensor('forgetGateRecursionMatrix' + suffix, recursionShape, weights.forgetGateRecursionMatrix));
+                    this._initializers.push(new CoreMLTensor('blockInputRecursionMatrix' + suffix, recursionShape, weights.blockInputRecursionMatrix));
+                    this._initializers.push(new CoreMLTensor('outputGateRecursionMatrix' + suffix, recursionShape, weights.outputGateRecursionMatrix));
+                    if (data.params.hasBiasVectors) {
+                        this._initializers.push(new CoreMLTensor('inputGateBiasVector' + suffix, vectorShape, weights.inputGateBiasVector));
+                        this._initializers.push(new CoreMLTensor('forgetGateBiasVector' + suffix, vectorShape, weights.forgetGateBiasVector));
+                        this._initializers.push(new CoreMLTensor('blockInputBiasVector' + suffix, vectorShape, weights.blockInputBiasVector));
+                        this._initializers.push(new CoreMLTensor('outputGateBiasVector' + suffix, vectorShape, weights.outputGateBiasVector));
+                    }
+                    if (data.params.hasPeepholeVectors) {
+                        this._initializers.push(new CoreMLTensor('inputGatePeepholeVector' + suffix, vectorShape, weights.inputGatePeepholeVector));
+                        this._initializers.push(new CoreMLTensor('forgetGatePeepholeVector' + suffix, vectorShape, weights.forgetGatePeepholeVector));
+                        this._initializers.push(new CoreMLTensor('outputGatePeepholeVector' + suffix, vectorShape, weights.outputGatePeepholeVector));
+                    }
+                }
+                return { 'weightParams': true };
         }
-
-        this._attributes.push(new CoreMLAttribute(this, name, value));
+        return {};
     }
 }
 
@@ -437,9 +495,26 @@ class CoreMLAttribute {
 
 class CoreMLTensor {
 
-    constructor(name, value) {
+    // Creates a tensor wrapper from a name, a shape, and a weight record.
+    // The element type is taken from the first non-empty field of `data`,
+    // checked in the order floatValue, float16Value, rawValue. When `data`
+    // is missing or all three are empty, type and data both stay null.
+    constructor(name, shape, data) {
         this._name = name;
-        this._value = value;
+        this._shape = shape;
+        this._type = null;
+        this._data = null;
+        if (data) {
+            if (data.floatValue && data.floatValue.length > 0) {
+                this._data = data.floatValue;
+                this._type = 'float';
+            }
+            else if (data.float16Value && data.float16Value.length > 0) {
+                this._data = data.float16Value;
+                this._type = 'float16';
+            }
+            else if (data.rawValue && data.rawValue.length > 0) {
+                // Raw bytes are not kept: only the 'byte' type is recorded and
+                // the caller-supplied shape is replaced by an empty one.
+                this._data = null;
+                this._type = 'byte';
+                this._shape = [];
+            }
+        }
     }
 
     get id() {
@@ -450,10 +525,54 @@ class CoreMLTensor {
         return this._name;
     }
 
+    // Fixed label for this kind of object; always the literal 'Initializer'.
+    get title() {
+        return 'Initializer';
+    }
+
+    // Describes the tensor as '<type>[<dim>,<dim>,...]'. Returns '?' when
+    // either the element type or the shape is not known.
+    get type() {
+        if (!this._type || !this._shape) {
+            return '?';
+        }
+        var dimensions = this._shape.join(',');
+        return this._type + '[' + dimensions + ']';
+    }
+
     get value() {
-        return JSON.stringify(this._value);
+        // Renders the tensor data as JSON text indented by 4 spaces, nested
+        // according to this._shape. Returns '?' when there is nothing to lay out.
+        // read() indexes this._shape, so a tensor constructed with a null shape
+        // must be reported as unknown instead of dereferencing null.
+        if (this._data && this._shape) {
+            // _index and _count are scratch state used only while read() runs.
+            this._index = 0;
+            this._count = 0;
+            var result = this.read(0);
+            delete this._index;
+            delete this._count;
+            return JSON.stringify(result, null, 4);
+        }
+        return '?';
     }
 
+    // Recursively converts the flat this._data array into nested arrays that
+    // follow this._shape, starting at the given dimension index. Consumes
+    // elements through this._index and counts them in this._count; once more
+    // than 10000 elements have been emitted, a '...' marker is appended and
+    // the current level stops.
+    read(dimension) {
+        var extent = this._shape[dimension];
+        var isInnermost = (dimension == this._shape.length - 1);
+        var values = [];
+        for (var k = 0; k < extent; k++) {
+            if (this._count > 10000) {
+                values.push('...');
+                break;
+            }
+            if (isInnermost) {
+                values.push(this._data[this._index]);
+                this._index++;
+                this._count++;
+            }
+            else {
+                values.push(this.read(dimension + 1));
+            }
+        }
+        return values;
+    }
 }
 
 class CoreMLOperatorMetadata 

+ 90 - 1
src/coreml-operator.json

@@ -153,6 +153,13 @@
       "description": "A layer that performs batch normalization, which is performed along the channel axis, and repeated along the other axes, if present."
     }
   },
+  {
+    "name": "l2normalize",
+    "schema": {
+      "category": "Normalization",
+      "description": "A layer that performs L2 normalization, i.e. divides by the square root of the sum of squares of all elements of input."
+    }
+  },
   {
     "name": "lrn",
     "schema": {
@@ -177,6 +184,12 @@
       "description": "A layer that rearranges the dimensions and data of an input."
     }
   },
+  {
+    "name": "reduce",
+    "schema": {
+      "description": "A layer that reduces the input using a specified operation."
+    }
+  },
   {
     "name": "flatten",
     "schema": {
@@ -191,6 +204,20 @@
       "description": "A layer that recasts the input into a new shape."
     }
   },
+  {
+    "name": "reorganizeData",
+    "schema": {
+      "category": "Shape",
+      "description": "A layer that reorganizes data in the input in: 1. SPACE_TO_DEPTH, 2. DEPTH_TO_SPACE."
+    }
+  },
+  {
+    "name": "padding",
+    "schema": {
+      "category": "Layer",
+      "description": "Fill a constant value in the padded region."
+    }
+  },
   {
     "name": "crop",
     "schema": {
@@ -208,7 +235,8 @@
   {
     "name": "sequenceRepeat",
     "schema": {
-      "category": "Shape"
+      "category": "Shape",
+      "description": "A layer that repeats a sequence."
     }    
   },
   {
@@ -234,18 +262,73 @@
       ]
     }
   },
+  {
+    "name": "multiply",
+    "schema": {
+      "description": "A layer that performs elementwise multiplication."
+    }
+  },
+  {
+    "name": "max",
+    "schema": {
+      "description": "A layer that computes the elementwise maximum over the inputs."
+    }
+  },
+  {
+    "name": "min",
+    "schema": {
+      "description": "A layer that computes the elementwise minimum over the inputs."
+    }
+  },
+  {
+    "name": "average",
+    "schema": {
+      "description": "A layer that computes the elementwise average of the inputs."
+    }
+  },
+  {
+    "name": "unary",
+    "schema": {
+      "description": "A layer that applies a unary function."
+    }
+  },
+  {
+    "name": "mvn",
+    "schema": {
+      "description": "A layer that performs mean variance normalization."
+    }
+  },
+  {
+    "name": "dot",
+    "schema": {
+      "description": "A layer that computes the dot product of two inputs. If cosine similarity is enabled, inputs are normalized first, thereby computing the cosine similarity."
+    }
+  },
   {
     "name": "scale",
     "schema": {
+      "category": "Layer",
       "description": "A layer that performs elementwise multiplication by a scale factor and optionally adds a bias."
     }
   },
+  {
+    "name": "upsample",
+    "schema": {
+      "description": "A layer that scales up spatial dimensions. It supports two modes: nearest neighbour (default) and bilinear."
+    }
+  },
   {
     "name": "slice",
     "schema": {
       "description": "A layer that slices the input data along a given axis."
     }
   },
+  {
+    "name": "split",
+    "schema": {
+      "description": "A layer that uniformly splits across the channel dimension to produce a specified number of outputs."
+    }
+  },
   {
     "name": "embedding",
     "schema": {
@@ -253,6 +336,12 @@
       "description": "A layer that performs a matrix lookup and optionally adds a bias."
     }
   },
+  {
+    "name": "loadConstant",
+    "schema": {
+      "category": "Data"
+    }
+  },
   {
     "name": "stringClassLabels",
     "schema": {

+ 3 - 3
src/onnx-model.js

@@ -330,19 +330,19 @@ class OnnxAttribute {
     get value() {
         if (this._attribute.ints && this._attribute.ints.length > 0) {
             if (this._attribute.ints.length > 65536) {
-                return "Too large to render.";
+                return "...";
             }
             return this._attribute.ints.map((v) => { return v.toString(); }).join(', '); 
         }
         else if (this._attribute.floats && this._attribute.floats.length > 0) {
             if (this._attribute.floats.length > 65536) {
-                return "Too large to render.";
+                return "...";
             }
             return this._attribute.floats.map(v => v.toString()).join(', ');
         }
         else if (this._attribute.strings && this._attribute.strings.length > 0) {
             if (this._attribute.strings.length > 65536) {
-                return "Too large to render.";
+                return "...";
             }
             return this._attribute.strings.map((s) => {
                 if (s.filter(c => c <= 32 && c >= 128).length == 0) {