Parcourir la source

Update rknn.js (#734)

Lutz Roeder il y a 4 ans
Parent
commit
f399baeb3f
1 fichiers modifiés avec 19 ajouts et 15 suppressions
  1. 19 15
      source/rknn.js

+ 19 - 15
source/rknn.js

@@ -185,16 +185,18 @@ rknn.Argument = class {
 rknn.Node = class {
 
     constructor(metadata, node, args) {
-        this._metadata = metadata;
         this._name = node.name || '';
+        this._metadata = metadata.type(node.op);
         this._type = node.op;
+        for (const prefix of [ 'VSI_NN_OP_', 'RKNN_OP_' ]) {
+            this._type = this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this._type;
+        }
         this._inputs = [];
         this._outputs = [];
         this._attributes = [];
-        const schema = this._metadata.type(this._type);
         node.input = node.input || [];
         for (let i = 0; i < node.input.length; ) {
-            const input = schema && schema.inputs && i < schema.inputs.length ? schema.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
+            const input = this._metadata && this._metadata.inputs && i < this._metadata.inputs.length ? this._metadata.inputs[i] : { name: i === 0 ? 'input' : i.toString() };
             const count = input.list ? node.input.length - i : 1;
             const list = node.input.slice(i, i + count).map((input) => {
                 if (input.right_tensor) {
@@ -220,7 +222,7 @@ rknn.Node = class {
         }
         node.output = node.output || [];
         for (let i = 0; i < node.output.length; ) {
-            const output = schema && schema.outputs && i < schema.outputs.length ? schema.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
+            const output = this._metadata && this._metadata.outputs && i < this._metadata.outputs.length ? this._metadata.outputs[i] : { name: i === 0 ? 'output' : i.toString() };
             const count = output.list ? node.output.length - i : 1;
             const list = node.output.slice(i, i + count).map((output) => {
                 if (output.right_tensor) {
@@ -261,12 +263,11 @@ rknn.Node = class {
     }
 
     get type() {
-        const prefix = 'VSI_NN_OP_';
-        return this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this._type;
+        return this._type;
     }
 
     get metadata() {
-        return this._metadata.type(this._type);
+        return this._metadata;
     }
 
     get inputs() {
@@ -420,14 +421,17 @@ rknn.Tensor = class {
 rknn.TensorType = class {
 
     constructor(dataType, shape) {
-        switch (dataType.vx_type) {
-            case 'VSI_NN_TYPE_UINT8': this._dataType = 'uint8'; break;
-            case 'VSI_NN_TYPE_INT8': this._dataType = 'int8'; break;
-            case 'VSI_NN_TYPE_INT16': this._dataType = 'int16'; break;
-            case 'VSI_NN_TYPE_INT32': this._dataType = 'int32'; break;
-            case 'VSI_NN_TYPE_INT64': this._dataType = 'int64'; break;
-            case 'VSI_NN_TYPE_FLOAT16': this._dataType = 'float16'; break;
-            case 'VSI_NN_TYPE_FLOAT32': this._dataType = 'float32'; break;
+        const type = dataType.vx_type.startsWith('VSI_NN_TYPE_') ? dataType.vx_type.split('_').pop().toLowerCase() : dataType.vx_type;
+        switch (type) {
+            case 'uint8':
+            case 'int8':
+            case 'int16':
+            case 'int32':
+            case 'int64':
+            case 'float16':
+            case 'float32':
+                this._dataType = type;
+                break;
             default:
                 throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
         }