瀏覽代碼

Update rknn.js

Lutz Roeder 4 年之前
父節點
當前提交
25f3b9b633
共有 1 個文件被更改,包括 15 次插入33 次删除
  1. 15 33
      source/rknn.js

+ 15 - 33
source/rknn.js

@@ -77,6 +77,13 @@ rknn.Graph = class {
             const argument = new rknn.Argument(name, type, null);
             args.set(name, argument);
         }
+        const arg = (name) => {
+            if (!args.has(name)) {
+                const argument = new rknn.Argument(name, null, null);
+                args.set(name, argument);
+            }
+            return args.get(name);
+        };
 
         for (const node of model.nodes) {
             node.input = [];
@@ -98,11 +105,8 @@ rknn.Graph = class {
 
         for (const graph of model.graph) {
             const key = graph.right + ':' + graph.right_tensor_id.toString();
-            const argument = args.get(key);
-            if (!argument) {
-                throw new rknn.Error("Invalid argument '" + key + "'.");
-            }
-            const name = graph.left + ((graph.left_tensor_id === 0) ? '' : graph.left_tensor_id.toString());
+            const argument = arg(key);
+            const name = graph.left + (graph.left_tensor_id === 0 ? '' : graph.left_tensor_id.toString());
             const parameter = new rknn.Parameter(name, [ argument ]);
             switch (graph.left) {
                 case 'input': {
@@ -116,9 +120,7 @@ rknn.Graph = class {
             }
         }
 
-        for (const node of model.nodes) {
-            this._nodes.push(new rknn.Node(metadata, node, args));
-        }
+        this._nodes = model.nodes.map((node) => new rknn.Node(metadata, node, arg));
     }
 
     get name() {
@@ -184,7 +186,7 @@ rknn.Argument = class {
 
 rknn.Node = class {
 
-    constructor(metadata, node, args) {
+    constructor(metadata, node, arg) {
         this._name = node.name || '';
         this._metadata = metadata.type(node.op);
         this._type = node.op;
@@ -200,20 +202,10 @@ rknn.Node = class {
             const count = input.list ? node.input.length - i : 1;
             const list = node.input.slice(i, i + count).map((input) => {
                 if (input.right_tensor) {
-                    const key = input.right_tensor.type + ':' + input.right_tensor.tensor_id.toString();
-                    const argument = args.get(key);
-                    if (!argument) {
-                        throw new rknn.Error("Invalid input argument '" + key + "'.");
-                    }
-                    return argument;
+                    return arg(input.right_tensor.type + ':' + input.right_tensor.tensor_id.toString());
                 }
                 if (input.right_node) {
-                    const key = input.right_node.node_id.toString() + ':' + input.right_node.tensor_id.toString();
-                    const argument = args.get(key);
-                    if (!argument) {
-                        throw new rknn.Error("Invalid input argument '" + key + "'.");
-                    }
-                    return argument;
+                    return arg(input.right_node.node_id.toString() + ':' + input.right_node.tensor_id.toString());
                 }
                 throw new rknn.Error('Invalid input argument.');
             });
@@ -226,20 +218,10 @@ rknn.Node = class {
             const count = output.list ? node.output.length - i : 1;
             const list = node.output.slice(i, i + count).map((output) => {
                 if (output.right_tensor) {
-                    const key = output.right_tensor.type + ':' + output.right_tensor.tensor_id.toString();
-                    const argument = args.get(key);
-                    if (!argument) {
-                        throw new rknn.Error("Invalid output argument '" + key + "'.");
-                    }
-                    return argument;
+                    return arg(output.right_tensor.type + ':' + output.right_tensor.tensor_id.toString());
                 }
                 if (output.right_node) {
-                    const key = output.right_node.node_id.toString() + ':' + output.right_node.tensor_id.toString();
-                    const argument = args.get(key);
-                    if (!argument) {
-                        throw new rknn.Error("Invalid output argument '" + key + "'.");
-                    }
-                    return argument;
+                    return arg(output.right_node.node_id.toString() + ':' + output.right_node.tensor_id.toString());
                 }
                 throw new rknn.Error('Invalid output argument.');
             });