Browse Source

Add RKNN support (#639)

Lutz Roeder 5 years ago
parent
commit
67d1dd6753
3 changed files with 343 additions and 17 deletions
  1. 62 0
      source/rknn-metadata.json
  2. 259 16
      source/rknn.js
  3. 22 1
      test/models.json

+ 62 - 0
source/rknn-metadata.json

@@ -0,0 +1,62 @@
+[
+  {
+    "name": "VSI_NN_OP_CONV2D",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_DECONVOLUTION",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_RELU",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_LEAKY_RELU",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_LEAKY_SIGMOID",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_POOL",
+    "schema": {
+      "category": "Pool"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_FCL",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "VSI_NN_OP_RESHAPE",
+    "schema": {
+      "category": "Shape" 
+    }
+  },
+  {
+    "name": "VSI_NN_OP_PERMUTE",
+    "schema": {
+      "category": "Shape" 
+    }
+  },
+  {
+    "name": "VSI_NN_OP_CONCAT",
+    "schema": {
+      "category": "Tensor"
+    }
+  }
+]

+ 259 - 16
source/rknn.js

@@ -14,20 +14,20 @@ rknn.ModelFactory = class {
     }
 
     open(context, host) {
-        return Promise.resolve().then(() => {
+        return rknn.Metadata.open(host).then((metadata) => {
             const container = new rknn.Container(context.buffer);
-            return new rknn.Model(container.configuration);
+            return new rknn.Model(metadata, container.model);
         });
     }
 };
 
 rknn.Model = class {
 
-    constructor(configuration) {
-        this._version = configuration.version;
-        this._producer = configuration.ori_network_platform || configuration.network_platform || '';
-        this._runtime = configuration.target_platform ? configuration.target_platform.join(',') : '';
-        this._graphs = [ new rknn.Graph(configuration) ];
+    constructor(metadata, model) {
+        this._version = model.version;
+        this._producer = model.ori_network_platform || model.network_platform || '';
+        this._runtime = model.target_platform ? model.target_platform.join(',') : '';
+        this._graphs = [ new rknn.Graph(metadata, model) ];
     }
 
     get format() {
@@ -49,14 +49,71 @@ rknn.Model = class {
 
 rknn.Graph = class {
 
-    constructor(configuration) {
-        this._name = configuration.name || '';
+    constructor(metadata, model) {
+        this._name = model.name || '';
         this._inputs = [];
         this._outputs = [];
         this._nodes = [];
 
-        for (const node of configuration.nodes) {
-            this._nodes.push(new rknn.Node(node));
+        const args = new Map();
+        for (const const_tensor of model.const_tensor) {
+            const name = 'const_tensor:' + const_tensor.tensor_id.toString();
+            const shape = new rknn.TensorShape(const_tensor.size);
+            const type = new rknn.TensorType(const_tensor.dtype, shape);
+            const tensor = new rknn.Tensor(type);
+            const argument = new rknn.Argument(name, type, tensor);
+            args.set(name, argument);
+        }
+        for (const virtual_tensor of model.virtual_tensor) {
+            const name = virtual_tensor.node_id.toString() + ':' + virtual_tensor.output_port.toString();
+            const argument = new rknn.Argument(name, null, null);
+            args.set(name, argument);
+        }
+        for (const norm_tensor of model.norm_tensor) {
+            const name = 'norm_tensor:' + norm_tensor.tensor_id.toString();
+            const shape = new rknn.TensorShape(norm_tensor.size);
+            const type = new rknn.TensorType(norm_tensor.dtype, shape);
+            const argument = new rknn.Argument(name, type, null);
+            args.set(name, argument);
+        }
+
+        for (const node of model.nodes) {
+            node.input = [];
+            node.output = [];
+        }
+        for (const connection of model.connection) {
+            switch (connection.left) {
+                case 'input':
+                    model.nodes[connection.node_id].input.push(connection);
+                    if (connection.right_node) {
+                        model.nodes[connection.right_node.node_id].output[connection.right_node.tensor_id] = connection;
+                    }
+                    break;
+                case 'output':
+                    model.nodes[connection.node_id].output.push(connection);
+                    break;
+            }
+        }
+
+        for (const graph of model.graph) {
+            const key = graph.right + ':' + graph.right_tensor_id.toString();
+            const argument = args.get(key);
+            const name = graph.left + ((graph.left_tensor_id === 0) ? '' : graph.left_tensor_id.toString());
+            const parameter = new rknn.Parameter(name, [ argument ]);
+            switch (graph.left) {
+                case 'input': {
+                    this._inputs.push(parameter);
+                    break;
+                }
+                case 'output': {
+                    this._outputs.push(parameter);
+                    break;
+                }
+            }
+        }
+
+        for (const node of model.nodes) {
+            this._nodes.push(new rknn.Node(metadata, node, args));
         }
     }
 
@@ -77,15 +134,86 @@ rknn.Graph = class {
     }
 };
 
+rknn.Parameter = class {
+
+    constructor(name, args) {
+        this._name = name;
+        this._arguments = args;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get visible() {
+        return true;
+    }
+
+    get arguments() {
+        return this._arguments;
+    }
+};
+
+rknn.Argument = class {
+
+    constructor(name, type, initializer) {
+        if (typeof name !== 'string') {
+            throw new rknn.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
+        }
+        this._name = name;
+        this._type = type || null;
+        this._initializer = initializer || null;
+    }
+
+    get name() {
+        return this._name;
+    }
+
+    get type() {
+        return this._type;
+    }
+
+    get initializer() {
+        return this._initializer;
+    }
+};
+
 rknn.Node = class {
 
-    constructor(node) {
+    constructor(metadata, node, args) {
+        this._metadata = metadata;
         this._name = node.name || '';
         this._type = node.op;
         this._inputs = [];
         this._outputs = [];
         this._attributes = [];
 
+        for (const input of node.input) {
+            if (input.right_tensor) {
+                const name = input.right_tensor.type + ':' + input.right_tensor.tensor_id.toString();
+                const argument = args.get(name);
+                this._inputs.push(new rknn.Parameter('', [ argument ]));
+            }
+            if (input.right_node) {
+                const name = input.right_node.node_id.toString() + ':' + input.right_node.tensor_id.toString();
+                const argument = args.get(name);
+                this._inputs.push(new rknn.Parameter('', [ argument ]));
+            }
+        }
+
+        for (const output of node.output) {
+            if (output.right_tensor) {
+                const name = output.right_tensor.type + ':' + output.right_tensor.tensor_id.toString();
+                const argument = args.get(name);
+                this._outputs.push(new rknn.Parameter('', [ argument ]));
+            }
+            if (output.right_node) {
+                const name = output.right_node.node_id.toString() + ':' + output.right_node.tensor_id.toString();
+                const argument = args.get(name);
+                this._outputs.push(new rknn.Parameter('', [ argument ]));
+            }
+        }
+
         if (node.nn) {
             const nn = node.nn;
             for (const key of Object.keys(nn)) {
@@ -103,7 +231,12 @@ rknn.Node = class {
     }
 
     get type() {
-        return this._type;
+        const prefix = 'VSI_NN_OP_';
+        return this._type.startsWith(prefix) ? this._type.substring(prefix.length) : this.type;
+    }
+
+    get metadata() {
+        return this._metadata.type(this._type);
     }
 
     get inputs() {
@@ -135,6 +268,62 @@ rknn.Attribute = class {
     }
 };
 
+rknn.Tensor = class {
+
+    constructor(type) {
+        this._type = type;
+    }
+
+    get type() {
+        return this._type;
+    }
+};
+
+rknn.TensorType = class {
+
+    constructor(dataType, shape) {
+        switch (dataType.vx_type) {
+            case 'VSI_NN_TYPE_UINT8': this._dataType = 'uint8'; break;
+            case 'VSI_NN_TYPE_INT32': this._dataType = 'int32'; break;
+            case 'VSI_NN_TYPE_FLOAT16': this._dataType = 'float16'; break;
+            case 'VSI_NN_TYPE_FLOAT32': this._dataType = 'float32'; break;
+            default:
+                throw new rknn.Error("Invalid data type '" + JSON.stringify(dataType) + "'.");
+        }
+        this._shape = shape;
+    }
+
+    get dataType() {
+        return this._dataType;
+    }
+
+    get shape() {
+        return this._shape;
+    }
+
+    toString() {
+        return this.dataType + this._shape.toString();
+    }
+};
+
+rknn.TensorShape = class {
+
+    constructor(shape) {
+        this._dimensions = shape;
+    }
+
+    get dimensions() {
+        return this._dimensions;
+    }
+
+    toString() {
+        if (!this._dimensions || this._dimensions.length == 0) {
+            return '';
+        }
+        return '[' + this._dimensions.join(',') + ']';
+    }
+};
+
 rknn.Container = class {
 
     constructor(buffer) {
@@ -153,15 +342,69 @@ rknn.Container = class {
             position += size;
         }
         const reader = json.TextReader.create(blocks[1]);
-        this._configuration = reader.read();
+        this._model = reader.read();
     }
 
     get version() {
         return this._version;
     }
 
-    get configuration() {
-        return this._configuration;
+    get model() {
+        return this._model;
+    }
+};
+
+rknn.Metadata = class {
+
+    static open(host) {
+        if (rknn.Metadata._metadata) {
+            return Promise.resolve(rknn.Metadata._metadata);
+        }
+        return host.request(null, 'rknn-metadata.json', 'utf-8').then((data) => {
+            rknn.Metadata._metadata = new rknn.Metadata(data);
+            return rknn.Metadata._metadata;
+        }).catch(() => {
+            rknn.Metadata._metadata = new rknn.Metadata(null);
+            return rknn.Metadata._metadata;
+        });
+    }
+
+    constructor(data) {
+        this._map = new Map();
+        if (data) {
+            const items = JSON.parse(data);
+            if (items) {
+                for (const item of items) {
+                    item.schema.name = item.name;
+                    this._map.set(item.name, item.schema);
+                }
+            }
+        }
+    }
+
+    type(name) {
+        return this._map.has(name) ? this._map.get(name) : null;
+    }
+
+    attribute(type, name) {
+        const schema = this.type(type);
+        if (schema) {
+            let attributeMap = schema.attributeMap;
+            if (!attributeMap) {
+                attributeMap = {};
+                if (schema.attributes) {
+                    for (const attribute of schema.attributes) {
+                        attributeMap[attribute.name] = attribute;
+                    }
+                }
+                schema.attributeMap = attributeMap;
+            }
+            const attributeSchema = attributeMap[name];
+            if (attributeSchema) {
+                return attributeSchema;
+            }
+        }
+        return null;
     }
 };
 

+ 22 - 1
test/models.json

@@ -4642,10 +4642,31 @@
     "source": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
     "format": "PyTorch v0.1.10"
   },
+  {
+    "type":   "rknn",
+    "target": "autopilot.rknn",
+    "source": "https://github.com/lutzroeder/netron/files/5621074/autopilot.rknn.zip[autopilot.rknn]",
+    "format": "RKNN v1.3.0",
+    "link":   "https://github.com/lutzroeder/netron/issues/639"
+  },
+  {
+    "type":   "rknn",
+    "target": "deepfusion.rknn",
+    "source": "https://github.com/lutzroeder/netron/files/5621075/deepfusion.rknn.zip[deepfusion.rknn]",
+    "format": "RKNN v1.2.1",
+    "link":   "https://github.com/lutzroeder/netron/issues/639"
+  },
+  {
+    "type":   "rknn",
+    "target": "kindnet.rknn",
+    "source": "https://github.com/lutzroeder/netron/files/5621076/kindnet.rknn.zip[kindnet.rknn]",
+    "format": "RKNN v1.3.2",
+    "link":   "https://github.com/lutzroeder/netron/issues/639"
+  },
   {
     "type":   "rknn",
     "target": "resnet_18.rknn",
-    "source": "https://github.com/lutzroeder/netron/files/5615383/resnet_18.rknn.zip[resnet_18.rknn]",
+    "source": "https://github.com/lutzroeder/netron/files/5621078/resnet_18.rknn.zip[resnet_18.rknn]",
     "format": "RKNN v1.3.2",
     "link":   "https://github.com/lutzroeder/netron/issues/639"
   },