Explorar el Código

Update rknn.js (#639)

Lutz Roeder hace 3 años
padre
commit
65775592ff
Se han modificado 4 ficheros con 151 adiciones y 73 borrados
  1. 55 5
      source/rknn-metadata.json
  2. 13 12
      source/rknn-schema.js
  3. 68 42
      source/rknn.js
  4. 15 14
      tools/rknn.fbs

+ 55 - 5
source/rknn-metadata.json

@@ -153,15 +153,48 @@
   },
   {
     "name": "Conv",
-    "category": "Layer"
+    "category": "Layer",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weights" },
+      { "name": "bias" }
+    ]
+  },
+  {
+    "name": "ConvRelu",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weights" },
+      { "name": "bias" }
+    ]
+  },
+  {
+    "name": "ConvClip",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weights" },
+      { "name": "bias" }
+    ]
   },
   {
     "name": "Concat",
-    "category": "Layer"
+    "category": "Tensor",
+    "inputs": [
+      { "name": "inputs", "list": true }
+    ]
   },
   {
     "name": "BatchNormalization",
-    "category": "Normalization"
+    "category": "Normalization",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weight" },
+      { "name": "bias" },
+      { "name": "running_mean" },
+      { "name": "running_var" }
+    ]
   },
   {
     "name": "Relu",
@@ -181,15 +214,32 @@
   },
   {
     "name": "Reshape",
-    "category": "Shape"
+    "category": "Shape",
+    "inputs": [
+      { "name": "input" },
+      { "name": "shape" }
+    ]
   },
   {
     "name": "Transpose",
     "category": "Transform"
   },
+  {
+    "name": "Add",
+    "inputs": [
+      { "name": "A" },
+      { "name": "B" }
+    ],
+    "outputs": [
+      { "name": "C" }
+    ]
+  },
   {
     "name": "Split",
-    "category": "Tensor"
+    "category": "Tensor",
+    "outputs": [
+      { "name": "output", "list": true }
+    ]
   },
   {
     "name": "PoolingLayer2",

+ 13 - 12
source/rknn-schema.js

@@ -67,24 +67,25 @@ $root.rknn.Tensor = class Tensor {
 
     static decode(reader, position) {
         const $ = new $root.rknn.Tensor();
-        $.var01 = reader.int32_(position, 4, 0);
-        $.var02 = reader.int32_(position, 6, 0);
-        $.var03 = reader.int32_(position, 8, 0);
-        $.var04 = reader.int32_(position, 10, 0);
-        $.var05 = reader.int32_(position, 12, 0);
+        $.data_type = reader.int8_(position, 4, 0);
+        $.var02 = reader.int8_(position, 6, 0);
+        $.kind = reader.int8_(position, 8, 0);
+        $.var04 = reader.typedArray(position, 10, Int32Array);
+        $.shape = reader.typedArray(position, 12, Int32Array);
         $.name = reader.string_(position, 14, null);
-        $.var06 = reader.int32_(position, 16, 0);
-        $.var07 = reader.int32_(position, 18, 0);
-        $.var08 = reader.int32_(position, 20, 0);
-        $.var09 = reader.int32_(position, 22, 0);
-        $.var10 = reader.int32_(position, 24, 0);
-        $.var11 = reader.int32_(position, 26, 0);
-        $.var12 = reader.int32_(position, 28, 0);
+        $.var06 = reader.typedArray(position, 16, Int8Array);
+        $.var07 = reader.string_(position, 18, null);
+        $.var08 = reader.typedArray(position, 20, Int8Array);
+        $.var09 = reader.typedArray(position, 22, Int8Array);
+        $.var10 = reader.typedArray(position, 24, Int8Array);
+        $.var11 = reader.typedArray(position, 26, Int8Array);
+        $.size = reader.int32_(position, 28, 0);
         $.var13 = reader.int32_(position, 30, 0);
         $.var14 = reader.int32_(position, 32, 0);
         $.var15 = reader.int32_(position, 34, 0);
         $.var16 = reader.int32_(position, 36, 0);
         $.var17 = reader.int32_(position, 38, 0);
+        $.index = reader.int32_(position, 40, 0);
         return $;
     }
 };

+ 68 - 42
source/rknn.js

@@ -111,12 +111,32 @@ rknn.Graph = class {
         this._nodes = [];
         switch (type) {
             case 'json': {
+                const dataType = (value) => {
+                    const type = value.vx_type.startsWith('VSI_NN_TYPE_') ? value.vx_type.split('_').pop().toLowerCase() : value.vx_type;
+                    switch (type) {
+                        case 'uint8':
+                        case 'int8':
+                        case 'int16':
+                        case 'int32':
+                        case 'int64':
+                        case 'float16':
+                        case 'float32':
+                        case 'float64':
+                        case 'vdata':
+                            return type;
+                        default:
+                            if (value.vx_type !== '') {
+                                throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
+                            }
+                            return '?';
+                    }
+                };
                 const model = obj;
                 const args = new Map();
                 for (const const_tensor of model.const_tensor) {
                     const name = 'const_tensor:' + const_tensor.tensor_id.toString();
                     const shape = new rknn.TensorShape(const_tensor.size);
-                    const type = new rknn.TensorType(const_tensor.dtype, shape);
+                    const type = new rknn.TensorType(dataType(const_tensor.dtype), shape);
                     const tensor = new rknn.Tensor(type, const_tensor.offset, next.value);
                     const argument = new rknn.Argument(name, type, tensor);
                     args.set(name, argument);
@@ -129,7 +149,7 @@ rknn.Graph = class {
                 for (const norm_tensor of model.norm_tensor) {
                     const name = 'norm_tensor:' + norm_tensor.tensor_id.toString();
                     const shape = new rknn.TensorShape(norm_tensor.size);
-                    const type = new rknn.TensorType(norm_tensor.dtype, shape);
+                    const type = new rknn.TensorType(dataType(norm_tensor.dtype), shape);
                     const argument = new rknn.Argument(name, type, null);
                     args.set(name, argument);
                 }
@@ -180,7 +200,17 @@ rknn.Graph = class {
             }
             case 'flatbuffers': {
                 const graph = obj;
-                const args = graph.tensors.map((tensor) => new rknn.Argument(tensor.name));
+                const dataTypes = [ 'unk0', '?', '?', 'int8', '?', 'int16', 'float32', 'int64', '?', '?', 'float16', '?', '?', 'unk13' ];
+                const args = graph.tensors.map((tensor) => {
+                    const shape = new rknn.TensorShape(Array.from(tensor.shape));
+                    const dataType = tensor.data_type < dataTypes.length ? dataTypes[tensor.data_type] : '?';
+                    if (dataType === '?') {
+                        throw new rknn.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
+                    }
+                    const type = new rknn.TensorType(dataType, shape);
+                    const initializer = tensor.kind !== 4 && tensor.kind !== 5 ? null : new rknn.Tensor(type, 0, null);
+                    return new rknn.Argument(tensor.name, type, initializer);
+                });
                 const arg = (index) => {
                     if (index >= args.length) {
                         throw new rknn.Error("Invalid tensor index '" + index.toString() + "'.");
@@ -265,6 +295,9 @@ rknn.Argument = class {
 rknn.Node = class {
 
     constructor(metadata, type, node, arg, next) {
+        this._inputs = [];
+        this._outputs = [];
+        this._attributes = [];
         switch (type) {
             case 'json': {
                 this._name = node.name || '';
@@ -285,9 +318,6 @@ rknn.Node = class {
                         this._type.name = this._type.name.startsWith(prefix) ? this._type.name.substring(prefix.length) : this._type.name;
                     }
                 }
-                this._inputs = [];
-                this._outputs = [];
-                this._attributes = [];
                 node.input = node.input || [];
                 for (let i = 0; i < node.input.length; ) {
                     const input = this._type && this._type.inputs && i < this._type.inputs.length ? this._type.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
@@ -335,23 +365,35 @@ rknn.Node = class {
             case 'flatbuffers': {
                 this._name = node.name;
                 this._type = metadata.type(node.type);
-                this._inputs = Array.from(node.inputs).map((input, index) => {
-                    const argument = arg(input);
-                    return new rknn.Parameter(index.toString(), [ argument ]);
-                });
-                this._outputs = Array.from(node.outputs).map((output, index) => {
-                    const argument = arg(output);
-                    return new rknn.Parameter(index.toString(), [ argument ]);
-                });
-                this._attributes = [];
+                if (node.inputs.length > 0) {
+                    const inputs = this._type.inputs || (node.inputs.length === 1 ? [ { name: "input" } ] : [ { name: "inputs", list: true } ]);
+                    if (Array.isArray(inputs) && inputs.length > 0 && inputs[0].list === true) {
+                        this._inputs = [new rknn.Parameter(inputs[0].name, Array.from(node.inputs).map((input) => arg(input))) ];
+                    }
+                    else {
+                        this._inputs = Array.from(node.inputs).map((input, index) => {
+                            const argument = arg(input);
+                            return new rknn.Parameter(index < inputs.length ? inputs[index].name : index.toString(), [ argument ]);
+                        });
+                    }
+                }
+                if (node.outputs.length > 0) {
+                    const outputs = this._type.outputs || (node.outputs.length === 1 ? [ { name: "output" } ] : [ { name: "outputs", list: true } ]);
+                    if (Array.isArray(outputs) && outputs.length > 0 && outputs[0].list === true) {
+                        this._outputs = [ new rknn.Parameter(outputs[0].name, Array.from(node.outputs).map((output) => arg(output))) ];
+                    }
+                    else {
+                        this._outputs = Array.from(node.outputs).map((output, index) => {
+                            const argument = arg(output);
+                            return new rknn.Parameter(index < outputs.length ? outputs[index].name : index.toString(), [ argument ]);
+                        });
+                    }
+                }
                 break;
             }
             case 'openvx': {
                 this._name = '';
                 this._type = metadata.type(node.type);
-                this._inputs = [];
-                this._outputs = [];
-                this._attributes = [];
                 break;
             }
             default: {
@@ -401,6 +443,7 @@ rknn.Tensor = class {
 
     constructor(type, offset, weights) {
         this._type = type;
+        this._data = null;
         let size = 0;
         switch (this._type.dataType) {
             case 'uint8': size = 1; break;
@@ -414,10 +457,12 @@ rknn.Tensor = class {
             case 'vdata': size = 1; break;
             default: throw new rknn.Error("Unsupported tensor data type '" + this._type.dataType + "'.");
         }
-        const shape = type.shape.dimensions;
-        size = size * shape.reduce((a, b) => a * b, 1);
-        if (size > 0) {
-            this._data = weights.slice(offset, offset + size);
+        if (weights) {
+            const shape = type.shape.dimensions;
+            size = size * shape.reduce((a, b) => a * b, 1);
+            if (size > 0) {
+                this._data = weights.slice(offset, offset + size);
+            }
         }
     }
 
@@ -541,26 +586,7 @@ rknn.Tensor = class {
 rknn.TensorType = class {
 
     constructor(dataType, shape) {
-        const type = dataType.vx_type.startsWith('VSI_NN_TYPE_') ? dataType.vx_type.split('_').pop().toLowerCase() : dataType.vx_type;
-        switch (type) {
-            case 'uint8':
-            case 'int8':
-            case 'int16':
-            case 'int32':
-            case 'int64':
-            case 'float16':
-            case 'float32':
-            case 'float64':
-            case 'vdata':
-                this._dataType = type;
-                break;
-            default:
-                if (dataType.vx_type !== '') {
-                    throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
-                }
-                this._dataType = '?';
-                break;
-        }
+        this._dataType = dataType;
         this._shape = shape;
     }
 

+ 15 - 14
tools/rknn.fbs

@@ -1,7 +1,7 @@
 namespace rknn;
 
-file_identifier 'RKNN'
-root_type Model
+file_identifier 'RKNN';
+root_type Model;
 
 table Model {
     var1: int;
@@ -42,24 +42,25 @@ table Node {
 }
 
 table Tensor {
-    var01: int;
-    var02: int;
-    var03: int;
-    var04: int;
-    var05: int;
+    data_type: byte;
+    var02: byte;
+    kind: byte;
+    var04: [int];
+    shape: [int];
     name: string;
-    var06: int;
-    var07: int;
-    var08: int;
-    var09: int;
-    var10: int;
-    var11: int;
-    var12: int;
+    var06: [byte];
+    var07: string;
+    var08: [byte];
+    var09: [byte];
+    var10: [byte];
+    var11: [byte];
+    size: int;
     var13: int;
     var14: int;
     var15: int;
     var16: int;
     var17: int;
+    index: int;
 }
 
 table Type1 {