Ver código fonte

Update .dnn prototype (#581)

Lutz Roeder 5 anos atrás
pai
commit
065bb26828
3 arquivos alterados com 155 adições e 128 exclusões
  1. 77 77
      source/dnn-proto.js
  2. 39 12
      source/dnn.js
  3. 39 39
      tools/dnn.proto

+ 77 - 77
source/dnn-proto.js

@@ -22,10 +22,10 @@ $root.dnn.Model = class Model {
                     message.name = reader.string();
                     break;
                 case 2:
-                    message.version = reader.int64();
+                    message.version = reader.int32();
                     break;
                 case 4:
-                    message.input_shape = reader.array(message.input_shape, () => reader.int64(), tag);
+                    message.input_shape = reader.array(message.input_shape, () => reader.int32(), tag);
                     break;
                 case 7:
                     message.input_name.push(reader.string());
@@ -52,7 +52,7 @@ $root.dnn.Model = class Model {
 };
 
 $root.dnn.Model.prototype.name = "";
-$root.dnn.Model.prototype.version = protobuf.Int64.create(0);
+$root.dnn.Model.prototype.version = 0;
 $root.dnn.Model.prototype.a014 = 0;
 
 $root.dnn.Parameter = class Parameter {
@@ -96,16 +96,16 @@ $root.dnn.Shape = class Shape {
             const tag = reader.uint32();
             switch (tag >>> 3) {
                 case 1:
-                    message.dim0 = reader.int64();
+                    message.dim0 = reader.int32();
                     break;
                 case 2:
-                    message.dim1 = reader.int64();
+                    message.dim1 = reader.int32();
                     break;
                 case 3:
-                    message.dim2 = reader.int64();
+                    message.dim2 = reader.int32();
                     break;
                 case 4:
-                    message.dim3 = reader.int64();
+                    message.dim3 = reader.int32();
                     break;
                 default:
                     reader.skipType(tag & 7);
@@ -116,10 +116,10 @@ $root.dnn.Shape = class Shape {
     }
 };
 
-$root.dnn.Shape.prototype.dim0 = protobuf.Int64.create(0);
-$root.dnn.Shape.prototype.dim1 = protobuf.Int64.create(0);
-$root.dnn.Shape.prototype.dim2 = protobuf.Int64.create(0);
-$root.dnn.Shape.prototype.dim3 = protobuf.Int64.create(0);
+$root.dnn.Shape.prototype.dim0 = 0;
+$root.dnn.Shape.prototype.dim1 = 0;
+$root.dnn.Shape.prototype.dim2 = 0;
+$root.dnn.Shape.prototype.dim3 = 0;
 
 $root.dnn.Node = class Node {
 
@@ -173,88 +173,88 @@ $root.dnn.Layer = class Layer {
                     message.type = reader.string();
                     break;
                 case 3:
-                    message.a003 = reader.uint64();
+                    message.filters = reader.int32();
                     break;
                 case 7:
-                    message.a007 = reader.uint64();
+                    message.a007 = reader.int32();
                     break;
                 case 8:
-                    message.a008 = reader.uint64();
+                    message.a008 = reader.int32();
                     break;
                 case 9:
-                    message.a009 = reader.uint64();
+                    message.groups = reader.int32();
                     break;
                 case 10:
-                    message.a010 = reader.uint64();
+                    message.a010 = reader.int32();
                     break;
                 case 11:
-                    message.a011 = reader.uint64();
+                    message.a011 = reader.int32();
                     break;
                 case 14:
-                    message.a014 = reader.float();
+                    message.slope = reader.float();
                     break;
                 case 15:
-                    message.a015 = reader.float();
+                    message.intercept = reader.float();
                     break;
                 case 50:
                     message.weight.push($root.dnn.Tensor.decode(reader, reader.uint32()));
                     break;
                 case 72:
-                    message.operation = reader.uint64();
+                    message.operation = reader.int32();
                     break;
                 case 65:
-                    message.axis = reader.uint64();
+                    message.axis = reader.int32();
                     break;
                 case 77:
-                    message.a077 = reader.uint64();
+                    message.a077 = reader.int32();
                     break;
                 case 79:
-                    message.a079 = reader.float();
+                    message.scale = reader.float();
                     break;
                 case 80:
-                    message.a080 = reader.uint64();
+                    message.pad_1 = reader.int32();
                     break;
                 case 81:
-                    message.a081 = reader.uint64();
+                    message.pad_2 = reader.int32();
                     break;
                 case 82:
-                    message.a082 = reader.uint64();
+                    message.pad_3 = reader.int32();
                     break;
                 case 83:
-                    message.a083 = reader.uint64();
+                    message.pad_4 = reader.int32();
                     break;
                 case 84:
-                    message.a084 = reader.uint64();
+                    message.pad_5 = reader.int32();
                     break;
                 case 85:
-                    message.a085 = reader.uint64();
+                    message.a085 = reader.int32();
                     break;
                 case 90:
-                    message.a090 = reader.uint64();
+                    message.a090 = reader.int32();
                     break;
                 case 101:
-                    message.a101 = reader.uint64();
+                    message.is_quantized = reader.bool();
                     break;
                 case 104:
-                    message.a104 = $root.dnn.Buffer.decode(reader, reader.uint32());
+                    message.quantization = $root.dnn.Buffer.decode(reader, reader.uint32());
                     break;
                 case 109:
-                    message.a109 = reader.uint64();
+                    message.stride_w = reader.int32();
                     break;
                 case 110:
-                    message.a110 = reader.uint64();
+                    message.stride_h = reader.int32();
                     break;
                 case 111:
-                    message.a111 = reader.uint64();
+                    message.kernel_w = reader.int32();
                     break;
                 case 112:
-                    message.a112 = reader.uint64();
+                    message.kernel_h = reader.int32();
                     break;
                 case 115:
-                    message.a115 = reader.uint64();
+                    message.a115 = reader.int32();
                     break;
                 case 116:
-                    message.a116 = reader.uint64();
+                    message.a116 = reader.int32();
                     break;
                 default:
                     reader.skipType(tag & 7);
@@ -267,33 +267,33 @@ $root.dnn.Layer = class Layer {
 
 $root.dnn.Layer.prototype.name = "";
 $root.dnn.Layer.prototype.type = "";
-$root.dnn.Layer.prototype.a003 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a007 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a008 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a009 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a010 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a011 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a014 = 0;
-$root.dnn.Layer.prototype.a015 = 0;
-$root.dnn.Layer.prototype.operation = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.axis = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a077 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a079 = 0;
-$root.dnn.Layer.prototype.a080 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a081 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a082 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a083 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a084 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a085 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a090 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a101 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a104 = null;
-$root.dnn.Layer.prototype.a109 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a110 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a111 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a112 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a115 = protobuf.Uint64.create(0);
-$root.dnn.Layer.prototype.a116 = protobuf.Uint64.create(0);
+$root.dnn.Layer.prototype.filters = 0;
+$root.dnn.Layer.prototype.a007 = 0;
+$root.dnn.Layer.prototype.a008 = 0;
+$root.dnn.Layer.prototype.groups = 0;
+$root.dnn.Layer.prototype.a010 = 0;
+$root.dnn.Layer.prototype.a011 = 0;
+$root.dnn.Layer.prototype.slope = 0;
+$root.dnn.Layer.prototype.intercept = 0;
+$root.dnn.Layer.prototype.operation = 0;
+$root.dnn.Layer.prototype.axis = 0;
+$root.dnn.Layer.prototype.a077 = 0;
+$root.dnn.Layer.prototype.scale = 0;
+$root.dnn.Layer.prototype.pad_1 = 0;
+$root.dnn.Layer.prototype.pad_2 = 0;
+$root.dnn.Layer.prototype.pad_3 = 0;
+$root.dnn.Layer.prototype.pad_4 = 0;
+$root.dnn.Layer.prototype.pad_5 = 0;
+$root.dnn.Layer.prototype.a085 = 0;
+$root.dnn.Layer.prototype.a090 = 0;
+$root.dnn.Layer.prototype.is_quantized = false;
+$root.dnn.Layer.prototype.quantization = null;
+$root.dnn.Layer.prototype.stride_w = 0;
+$root.dnn.Layer.prototype.stride_h = 0;
+$root.dnn.Layer.prototype.kernel_w = 0;
+$root.dnn.Layer.prototype.kernel_h = 0;
+$root.dnn.Layer.prototype.a115 = 0;
+$root.dnn.Layer.prototype.a116 = 0;
 
 $root.dnn.Buffer = class Buffer {
 
@@ -332,22 +332,22 @@ $root.dnn.Tensor = class Tensor {
             const tag = reader.uint32();
             switch (tag >>> 3) {
                 case 1:
-                    message.dim0 = reader.int64();
+                    message.dim0 = reader.int32();
                     break;
                 case 2:
-                    message.dim1 = reader.int64();
+                    message.dim1 = reader.int32();
                     break;
                 case 3:
-                    message.dim2 = reader.int64();
+                    message.dim2 = reader.int32();
                     break;
                 case 4:
-                    message.dim3 = reader.int64();
+                    message.dim3 = reader.int32();
                     break;
                 case 5:
-                    message.data1 = reader.bytes();
+                    message.data = reader.bytes();
                     break;
                 case 6:
-                    message.data2 = reader.bytes();
+                    message.quantized_data = reader.bytes();
                     break;
                 default:
                     reader.skipType(tag & 7);
@@ -358,9 +358,9 @@ $root.dnn.Tensor = class Tensor {
     }
 };
 
-$root.dnn.Tensor.prototype.dim0 = protobuf.Int64.create(0);
-$root.dnn.Tensor.prototype.dim1 = protobuf.Int64.create(0);
-$root.dnn.Tensor.prototype.dim2 = protobuf.Int64.create(0);
-$root.dnn.Tensor.prototype.dim3 = protobuf.Int64.create(0);
-$root.dnn.Tensor.prototype.data1 = new Uint8Array([]);
-$root.dnn.Tensor.prototype.data2 = new Uint8Array([]);
+$root.dnn.Tensor.prototype.dim0 = 0;
+$root.dnn.Tensor.prototype.dim1 = 0;
+$root.dnn.Tensor.prototype.dim2 = 0;
+$root.dnn.Tensor.prototype.dim3 = 0;
+$root.dnn.Tensor.prototype.data = new Uint8Array([]);
+$root.dnn.Tensor.prototype.quantized_data = new Uint8Array([]);

+ 39 - 12
source/dnn.js

@@ -1,5 +1,7 @@
 /* jshint esversion: 6 */
 
+// Experimental
+
 var dnn = dnn || {};
 
 dnn.ModelFactory = class {
@@ -90,19 +92,19 @@ dnn.Graph = class {
 
         for (const input of model.input) {
             const shape = input.shape;
-            const type = new dnn.TensorType('?', new dnn.TensorShape([ shape.dim0.toNumber(), shape.dim1.toNumber(), shape.dim2.toNumber(), shape.dim3.toNumber() ]));
+            const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]));
             this._inputs.push(new dnn.Parameter(input.name, [ arg(input.name, type) ]));
         }
         for (const output of model.output) {
             const shape = output.shape;
-            const type = new dnn.TensorType('?', new dnn.TensorShape([ shape.dim0.toNumber(), shape.dim1.toNumber(), shape.dim2.toNumber(), shape.dim3.toNumber() ]));
+            const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape.dim0, shape.dim1, shape.dim2, shape.dim3 ]));
             this._outputs.push(new dnn.Parameter(output.name, [ arg(output.name, type) ]));
         }
         if (this._inputs.length === 0 && model.input_name && model.input_shape && model.input_shape.length === model.input_name.length * 4) {
             for (let i = 0; i < model.input_name.length; i++) {
                 const name = model.input_name[i];
                 const shape = model.input_shape.slice(i * 4, (i * 4 + 4));
-                const type = new dnn.TensorType('?', new dnn.TensorShape([ shape[1].toNumber(), shape[3].toNumber(), shape[2].toNumber(), shape[0].toNumber() ]));
+                const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape[1], shape[3], shape[2], shape[0] ]));
                 this._inputs.push(new dnn.Parameter(name, [ arg(name, type) ]));
             }
         }
@@ -110,7 +112,7 @@ dnn.Graph = class {
             model.node.length > 0 && model.node[0].input.length > 0) {
             const name = model.node[0].input[0];
             const shape = model.input_shape;
-            const type = new dnn.TensorType('?', new dnn.TensorShape([ shape[1].toNumber(), shape[3].toNumber(), shape[2].toNumber(), shape[0].toNumber() ]));
+            const type = new dnn.TensorType('float32', new dnn.TensorShape([ shape[1], shape[3], shape[2], shape[0] ]));
             this._inputs.push(new dnn.Parameter(name, [ arg(name, type) ]));
         }
 
@@ -154,13 +156,14 @@ dnn.Parameter = class {
 
 dnn.Argument = class {
 
-    constructor(name, type, initializer) {
+    constructor(name, type, initializer, quantization) {
         if (typeof name !== 'string') {
             throw new dnn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
         }
         this._name = name;
         this._type = type || null;
         this._initializer = initializer || null;
+        this._quantization = quantization || null;
     }
 
     get name() {
@@ -171,6 +174,13 @@ dnn.Argument = class {
         return this._type;
     }
 
+    get quantization() {
+        if (this._quantization) {
+            return this._quantization.map((value, index) => index.toString() + ' = ' + value.toString()).join('; ');
+        }
+        return null;
+    }
+
     get initializer() {
         return this._initializer;
     }
@@ -189,8 +199,17 @@ dnn.Node = class {
 
         const inputs = node.input.map((input) => { return arg(input); });
         for (const weight of layer.weight) {
-            const initializer = new dnn.Tensor(weight);
-            inputs.push(new dnn.Argument('', initializer.type, initializer));
+            let quantization = null;
+            if (layer.is_quantized && weight === layer.weight[0] && layer.quantization && layer.quantization.data) {
+                const data = layer.quantization.data;
+                quantization = new Array(data.length >> 2);
+                const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
+                for (let i = 0; i < quantization.length; i++) {
+                    quantization[i] = view.getFloat32(i << 2, true);
+                }
+            }
+            const initializer = new dnn.Tensor(weight, quantization);
+            inputs.push(new dnn.Argument('', initializer.type, initializer, quantization));
         }
         const outputs = node.output.map((output) => { return arg(output); });
 
@@ -220,8 +239,16 @@ dnn.Node = class {
         }
 
         for (const key of Object.keys(layer)) {
-            if (key !== 'name' && key !== 'type' && key !== 'weight') {
-                this._attributes.push(new dnn.Attribute(metadata.attribute(this._type, key), key, layer[key]));
+            switch (key) {
+                case 'name':
+                case 'type':
+                case 'weight':
+                case 'is_quantized':
+                case 'quantization':
+                    break;
+                default:
+                    this._attributes.push(new dnn.Attribute(metadata.attribute(this._type, key), key, layer[key]));
+                    break;
             }
         }
     }
@@ -269,9 +296,9 @@ dnn.Attribute = class {
 
 dnn.Tensor = class {
 
-    constructor(weight) {
-        const shape = new dnn.TensorShape([ weight.dim0.toNumber(), weight.dim1.toNumber(), weight.dim2.toNumber(), weight.dim3.toNumber() ]);
-        this._data = weight.data1.length > 0 ? weight.data1 : weight.data2;
+    constructor(weight, quantization) {
+        const shape = new dnn.TensorShape([ weight.dim0, weight.dim1, weight.dim2, weight.dim3 ]);
+        this._data = quantization ? weight.quantized_data : weight.data;
 
         const size = shape.dimensions.reduce((a, b) => a * b, 1);
         const itemSize = Math.floor(this._data.length / size);

+ 39 - 39
tools/dnn.proto

@@ -5,8 +5,8 @@ package dnn;
 
 message Model {
     string name = 1;
-    int64 version = 2;
-    repeated int64 input_shape = 4;
+    int32 version = 2;
+    repeated int32 input_shape = 4;
     repeated string input_name = 7;
     repeated Node node = 10;
     repeated Parameter input = 12;
@@ -20,10 +20,10 @@ message Parameter {
 }
 
 message Shape {
-    int64 dim0 = 1;
-    int64 dim1 = 2;
-    int64 dim2 = 3;
-    int64 dim3 = 4;
+    int32 dim0 = 1;
+    int32 dim1 = 2;
+    int32 dim2 = 3;
+    int32 dim3 = 4;
 }
 
 message Node {
@@ -35,34 +35,34 @@ message Node {
 message Layer {
     string name = 1;
     string type = 2;
-    uint64 a003 = 3;   // conv
-    uint64 a007 = 7;   // pool
-    uint64 a008 = 8;   // pool
-    uint64 a009 = 9;   // conv
-    uint64 a010 = 10;  // conv, pool
-    uint64 a011 = 11;  // pool
-    float a014 = 14;   // linear
-    float a015 = 15;   // linear
+    int32 filters = 3;
+    int32 a007 = 7;   // pool
+    int32 a008 = 8;   // pool
+    int32 groups = 9;
+    int32 a010 = 10;  // conv, pool
+    int32 a011 = 11;  // pool
+    float slope = 14;   // linear
+    float intercept = 15;   // linear
     repeated Tensor weight = 50;
-    uint64 operation = 72;
-    uint64 axis = 65;  // concat
-    uint64 a077 = 77;  // conv
-    float a079 = 79;   // resize
-    uint64 a080 = 80;  // pad
-    uint64 a081 = 81;  // pad
-    uint64 a082 = 82;  // pad
-    uint64 a083 = 83;  // pad
-    uint64 a084 = 84;  // pad
-    uint64 a085 = 85;  // resize
-    uint64 a090 = 90;  // pool
-    uint64 a101 = 101; // [conv]
-    Buffer a104 = 104; // [conv]
-    uint64 a109 = 109; // conv
-    uint64 a110 = 110; // [conv]
-    uint64 a111 = 111; // conv
-    uint64 a112 = 112; // conv
-    uint64 a115 = 115; // conv
-    uint64 a116 = 116; // [conv]
+    int32 operation = 72;
+    int32 axis = 65;  // concat
+    int32 a077 = 77;  // conv
+    float scale = 79;  // resize
+    int32 pad_1 = 80;  // pad
+    int32 pad_2 = 81;  // pad
+    int32 pad_3 = 82;  // pad
+    int32 pad_4 = 83;  // pad
+    int32 pad_5 = 84;  // pad
+    int32 a085 = 85;  // resize
+    int32 a090 = 90;  // pool
+    bool is_quantized = 101;
+    Buffer quantization = 104;
+    int32 stride_w = 109;
+    int32 stride_h = 110;
+    int32 kernel_w = 111;
+    int32 kernel_h = 112;
+    int32 a115 = 115; // conv
+    int32 a116 = 116; // [conv]
 }
 
 message Buffer {
@@ -70,10 +70,10 @@ message Buffer {
 }
 
 message Tensor {
-    int64 dim0 = 1;
-    int64 dim1 = 2;
-    int64 dim2 = 3;
-    int64 dim3 = 4;
-    bytes data1 = 5;
-    bytes data2 = 6;
+    int32 dim0 = 1;
+    int32 dim1 = 2;
+    int32 dim2 = 3;
+    int32 dim3 = 4;
+    bytes data = 5;
+    bytes quantized_data = 6;
 }