Przeglądaj źródła

Update rknn.js (#639)

Lutz Roeder 5 miesięcy temu
rodzic
commit
269c973d09
2 zmienionych plików z 34 dodań i 32 usunięć
  1. 8 0
      source/rknn-metadata.json
  2. 26 32
      source/rknn.js

+ 8 - 0
source/rknn-metadata.json

@@ -363,5 +363,13 @@
   {
     "name": "Slice",
     "category": "Tensor"
+  },
+  {
+    "name": "Pad",
+    "category": "Tensor"
+  },
+  {
+    "name": "exSoftmax13",
+    "category": "Activation"
   }
 ]

+ 26 - 32
source/rknn.js

@@ -111,10 +111,15 @@ rknn.Graph = class {
                 for (const const_tensor of model.const_tensor) {
                     const name = `const_tensor:${const_tensor.tensor_id}`;
                     const shape = new rknn.TensorShape(const_tensor.size);
-                    const type = new rknn.TensorType(dataType(const_tensor.dtype), shape);
-                    const tensor = new rknn.Tensor(type, const_tensor.offset, null);
-                    const value = new rknn.Value(name, type, tensor);
-                    values.set(name, value);
+                    if (const_tensor.data_type === 0) {
+                        const value = new rknn.Value(name, null, null);
+                        values.set(name, value);
+                    } else {
+                        const type = new rknn.TensorType(dataType(const_tensor.dtype), shape);
+                        const tensor = new rknn.Tensor(type, const_tensor.offset, null);
+                        const value = new rknn.Value(name, type, tensor);
+                        values.set(name, value);
+                    }
                 }
                 for (const virtual_tensor of model.virtual_tensor) {
                     const name = `${virtual_tensor.node_id}:${virtual_tensor.output_port}`;
@@ -124,9 +129,14 @@ rknn.Graph = class {
                 for (const norm_tensor of model.norm_tensor) {
                     const name = `norm_tensor:${norm_tensor.tensor_id}`;
                     const shape = new rknn.TensorShape(norm_tensor.size);
-                    const type = new rknn.TensorType(dataType(norm_tensor.dtype), shape);
-                    const value = new rknn.Value(name, type, null);
-                    values.set(name, value);
+                    if (norm_tensor.dtype === 0) {
+                        const value = new rknn.Value(name, null, null);
+                        values.set(name, value);
+                    } else {
+                        const type = new rknn.TensorType(dataType(norm_tensor.dtype), shape);
+                        const value = new rknn.Value(name, type, null);
+                        values.set(name, value);
+                    }
                 }
                 const value = (name) => {
                     if (!values.has(name)) {
@@ -173,13 +183,10 @@ rknn.Graph = class {
             }
             case 'flatbuffers': {
                 const graph = obj;
-                const dataTypes = ['unk0', 'int32', '?', 'int8', '?', 'int16', 'float32', 'int64', '?', '?', 'float16', '?', '?', 'unk13', '?', '?', 'bfloat16'];
+                const dataTypes = ['undefined', 'float32', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'boolean', 'float16', 'float64', 'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16'];
                 const args = graph.tensors.map((tensor) => {
                     const shape = new rknn.TensorShape(Array.from(tensor.shape));
                     const dataType = tensor.data_type < dataTypes.length ? dataTypes[tensor.data_type] : '?';
-                    if (dataType === '?') {
-                        throw new rknn.Error(`Unsupported tensor data type '${tensor.data_type}'.`);
-                    }
                     const type = new rknn.TensorType(dataType, shape);
                     const initializer = tensor.kind !== 4 && tensor.kind !== 5 ? null : new rknn.Tensor(type, 0, null);
                     return new rknn.Value(tensor.name, type, initializer);
@@ -424,33 +431,20 @@ rknn.Container = class extends Map {
                     const uint64 = () => {
                         const buffer = stream.read(8);
                         const reader = base.BinaryReader.open(buffer);
-                        return reader.uint64().toNumber();
+                        return reader.uint64();
                     };
                     stream.skip(8);
                     const version = uint64();
-                    const data_size = uint64();
-                    switch (version) {
-                        case 0x0001:
-                        case 0x1001:
-                            break;
-                        case 0x0002:
-                        case 0x1002:
-                        case 0x0003:
-                        case 0x1003:
-                        case 0x0004:
-                        case 0x1004:
-                        case 0x0005:
-                        case 0x0006:
-                            if (data_size > 0) {
-                                stream.skip(40);
-                            }
-                            break;
-                        default:
-                            throw new rknn.Error(`Unsupported RKNN container version '${version}'.`);
+                    if ((version >> 8n) !== 0n && (version >> 8n) !== 0x10n) {
+                        throw new rknn.Error(`Unsupported RKNN container version '${version}'.`);
+                    }
+                    const data_size = uint64().toNumber();
+                    if ((version & 0xffn) > 1n && data_size > 0) {
+                        stream.skip(40);
                     }
                     const signature = rknn.Container.signature(stream, data_size);
                     const data = stream.read(data_size);
-                    const json_size = uint64();
+                    const json_size = uint64().toNumber();
                     const json = stream.read(json_size);
                     this.set('json', json);
                     if (signature) {