Lutz Roeder преди 4 години
родител
ревизия
11a6a076d4
променени са 3 файла, в които са добавени 115 реда и са изтрити 64 реда
  1. 49 0
      source/rknn-metadata.json
  2. 59 61
      source/rknn.js
  3. 7 3
      test/models.json

+ 49 - 0
source/rknn-metadata.json

@@ -22,6 +22,15 @@
       { "name": "bias" }
     ]
   },
+  {
+    "name": "VSI_NN_OP_CONV_RELU_POOL",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weights" },
+      { "name": "bias" }
+    ]
+  },
   {
     "name": "VSI_NN_OP_CONV2D",
     "category": "Layer",
@@ -49,6 +58,15 @@
       { "name": "bias" }
     ]
   },
+  {
+    "name": "VSI_NN_OP_FCL_RELU",
+    "category": "Layer",
+    "inputs": [
+      { "name": "input" },
+      { "name": "weights" },
+      { "name": "bias" }
+    ]
+  },
   {
     "name": "VSI_NN_OP_LEAKY_RELU",
     "category": "Activation"
@@ -69,6 +87,14 @@
     "name": "VSI_NN_OP_RELU",
     "category": "Activation"
   },
+  {
+    "name": "VSI_NN_OP_PRELU",
+    "category": "Activation",
+    "inputs": [
+      { "name": "input" },
+      { "name": "slope" }
+    ]
+  },
   {
     "name": "VSI_NN_OP_RESHAPE",
     "category": "Shape"
@@ -76,5 +102,28 @@
   {
     "name": "VSI_NN_OP_SIGMOID",
     "category": "Activation"
+  },
+  {
+    "name": "VSI_NN_OP_SOFTMAX",
+    "category": "Activation"
+  },
+  {
+    "name": "VSI_NN_OP_RESIZE",
+    "category": "Shape"
+  },
+  {
+    "name": "VSI_NN_OP_LRN",
+    "category": "NORMALIZATION"
+  },
+  {
+    "name": "VSI_NN_OP_BATCH_NORM",
+    "category": "NORMALIZATION",
+    "inputs": [
+      { "name": "input" },
+      { "name": "gamma" },
+      { "name": "beta" },
+      { "name": "mean" },
+      { "name": "variance" }
+    ]
   }
 ]

+ 59 - 61
source/rknn.js

@@ -5,13 +5,13 @@ var json = json || require('./json');
 rknn.ModelFactory = class {
 
     match(context) {
-        return rknn.Container.open(context);
+        return rknn.Reader.open(context);
     }
 
     open(context, match) {
         return rknn.Metadata.open(context).then((metadata) => {
-            const container = match;
-            return new rknn.Model(metadata, container.model, container.weights);
+            const reader = match;
+            return new rknn.Model(metadata, reader.model, reader.weights);
         });
     }
 };
@@ -401,10 +401,15 @@ rknn.TensorType = class {
             case 'int64':
             case 'float16':
             case 'float32':
+            case 'vdata':
                 this._dataType = type;
                 break;
             default:
-                throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
+                if (dataType.vx_type !== '') {
+                    throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
+                }
+                this._dataType = '?';
+                break;
         }
         this._shape = shape;
     }
@@ -440,74 +445,68 @@ rknn.TensorShape = class {
     }
 };
 
-rknn.Container = class {
+rknn.Reader = class {
 
     static open(context) {
         const stream = context.stream;
-        const signature = [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ];
-        if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
-            return new rknn.Container(stream);
+        if (stream.length >= 8) {
+            const buffer = stream.read(8);
+            const decoder = new TextDecoder();
+            const signature = decoder.decode(buffer);
+            if (signature === 'RKNN\0\0\0\0' || signature === 'CYPTRKNN') {
+                return new rknn.Reader(stream, signature);
+            }
         }
         return null;
     }
 
-    constructor(stream) {
-        this._reader = new rknn.Container.StreamReader(stream);
+    constructor(stream, signature) {
+        this._stream = stream;
+        this._signature = signature;
     }
 
     get version() {
-        this._read();
+        this._decode();
         return this._version;
     }
 
     get weights() {
-        this._read();
+        this._decode();
         return this._weights;
     }
 
     get model() {
-        this._read();
+        this._decode();
         return this._model;
     }
 
-    _read() {
-        if (this._reader) {
-            this._reader.uint64();
-            this._version = this._reader.uint64();
-            this._weights = this._reader.read();
-            const buffer = this._reader.read();
-            const reader = json.TextReader.open(buffer);
+    _decode() {
+        if (this._stream) {
+            if (this._signature === 'CYPTRKNN') {
+                throw new rknn.Error('Invalid file content. File contains undocumented encrypted RKNN data.');
+            }
+            this._version = this._uint64();
+            const weights_size = this._uint64();
+            if (this._version > 1) {
+                this._stream.read(40);
+            }
+            this._weights = this._stream.read(weights_size);
+            const model_size = this._uint64();
+            const model_buffer = this._stream.read(model_size);
+            const reader = json.TextReader.open(model_buffer);
             this._model = reader.read();
-            delete this._reader;
+            delete this._stream;
         }
     }
-};
-
-rknn.Container.StreamReader = class {
 
-    constructor(stream) {
-        this._stream = stream;
-        this._length = stream.length;
-        this._position = 0;
-    }
-
-    skip(offset) {
-        this._position += offset;
-        if (this._position > this._length) {
-            throw new rknn.Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
-        }
-    }
-
-    uint64() {
-        this.skip(8);
+    _uint64() {
         const buffer = this._stream.read(8);
         const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
         return view.getUint64(0, true).toNumber();
     }
 
-    read() {
-        const size = this.uint64();
-        this.skip(size);
+    _read() {
+        const size = this._uint64();
         return this._stream.read(size);
     }
 };
@@ -528,36 +527,35 @@ rknn.Metadata = class {
     }
 
     constructor(data) {
-        this._map = new Map();
+        this._types = new Map();
+        this._attributes = new Map();
         if (data) {
-            const metadata = JSON.parse(data);
-            this._map = new Map(metadata.map((item) => [ item.name, item ]));
+            const items = JSON.parse(data);
+            for (const item of items) {
+                this._types.set(item.name, item);
+            }
         }
     }
 
     type(name) {
-        return this._map.has(name) ? this._map.get(name) : null;
+        if (!this._types.has(name)) {
+            this._types.set(name, { name: name });
+        }
+        return this._types.get(name);
     }
 
     attribute(type, name) {
-        const schema = this.type(type);
-        if (schema) {
-            let attributeMap = schema.attributeMap;
-            if (!attributeMap) {
-                attributeMap = {};
-                if (schema.attributes) {
-                    for (const attribute of schema.attributes) {
-                        attributeMap[attribute.name] = attribute;
-                    }
+        const key = type + ':' + name;
+        if (!this._attributes.has(key)) {
+            this._attributes.set(key, null);
+            const metadata = this.type(type);
+            if (metadata && Array.isArray(metadata.attributes)) {
+                for (const attribute of metadata.attributes) {
+                    this._attributes.set(type + ':' + attribute.name, attribute);
                 }
-                schema.attributeMap = attributeMap;
-            }
-            const attributeSchema = attributeMap[name];
-            if (attributeSchema) {
-                return attributeSchema;
             }
         }
-        return null;
+        return this._attributes.get(key);
     }
 };
 

+ 7 - 3
test/models.json

@@ -5100,9 +5100,6 @@
     "format": "TorchScript v1.5",
     "link":   "https://github.com/lutzroeder/netron/issues/842"
   },
-
-
-  
   {
     "type":   "rknn",
     "target": "autopilot.rknn",
@@ -5131,6 +5128,13 @@
     "format": "RKNN v1.3.2",
     "link":   "https://github.com/lutzroeder/netron/issues/639"
   },
+  {
+    "type":   "rknn",
+    "target": "resnet18_for_rk356x.rknn",
+    "source": "https://github.com/rockchip-linux/rknn-toolkit2/raw/master/rknn-toolkit-lite2-v1.2.0/examples/inference_with_lite/resnet18_for_rk356x.rknn",
+    "format": "RKNN v1.2.0",
+    "link":   "https://github.com/rockchip-linux/rknn-toolkit2"
+  },
   {
     "type":   "sklearn",
     "target": "best_boston.pb",