Browse Source

Check argument id

Lutz Roeder 6 years ago
parent
commit
d57944c692
20 changed files with 88 additions and 33 deletions
  1. 3 0
      src/armnn.js
  2. 3 1
      src/bigdl.js
  3. 3 0
      src/caffe.js
  4. 3 0
      src/caffe2.js
  5. 3 0
      src/chainer.js
  6. 3 0
      src/coreml.js
  7. 3 0
      src/darknet.js
  8. 3 0
      src/dl4j.js
  9. 3 0
      src/keras.js
  10. 3 0
      src/mediapipe.js
  11. 3 0
      src/mlnet.js
  12. 3 0
      src/mxnet.js
  13. 3 0
      src/ncnn.js
  14. 33 32
      src/onnx.js
  15. 3 0
      src/paddle.js
  16. 3 0
      src/pytorch.js
  17. 3 0
      src/sklearn.js
  18. 3 0
      src/tf.js
  19. 3 0
      src/torch.js
  20. 1 0
      test/models.json

+ 3 - 0
src/armnn.js

@@ -319,6 +319,9 @@ armnn.Parameter = class {
 armnn.Argument = class {
 
     constructor(id, tensorInfo, initializer) {
+        if (typeof id !== 'string') {
+            throw new armnn.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         const info = initializer ? initializer.info() : tensorInfo;
         this._id = id;
         this._type = new armnn.TensorType(info);

+ 3 - 1
src/bigdl.js

@@ -146,7 +146,9 @@ bigdl.Parameter = class {
 bigdl.Argument = class {
 
     constructor(id, type, initializer) {
-        id.toString();
+        if (typeof id !== 'string') {
+            throw new bigdl.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/caffe.js

@@ -389,6 +389,9 @@ caffe.Parameter = class {
 caffe.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new caffe.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/caffe2.js

@@ -362,6 +362,9 @@ caffe2.Parameter = class {
 caffe2.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new caffe2.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/chainer.js

@@ -334,6 +334,9 @@ chainer.Parameter = class {
 chainer.Argument = class {
 
     constructor(id, initializer) {
+        if (typeof id !== 'string') {
+            throw new chainer.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._initializer = initializer || null;
     }

+ 3 - 0
src/coreml.js

@@ -526,6 +526,9 @@ coreml.Parameter = class {
 coreml.Argument = class {
 
     constructor(id, type, description, initializer) {
+        if (typeof id !== 'string') {
+            throw new coreml.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
         this._description = description || null;

+ 3 - 0
src/darknet.js

@@ -697,6 +697,9 @@ darknet.Parameter = class {
 darknet.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new darknet.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
         this._initializer = initializer;

+ 3 - 0
src/dl4j.js

@@ -199,6 +199,9 @@ dl4j.Parameter = class {
 dl4j.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new dl4j.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
         this._initializer = initializer;

+ 3 - 0
src/keras.js

@@ -466,6 +466,9 @@ keras.Parameter = class {
 keras.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new keras.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/mediapipe.js

@@ -256,6 +256,9 @@ mediapipe.Parameter = class {
 mediapipe.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new mediapipe.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/mlnet.js

@@ -172,6 +172,9 @@ mlnet.Parameter = class {
 mlnet.Argument = class {
 
     constructor(id, type) {
+        if (typeof id !== 'string') {
+            throw new mlnet.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
     }

+ 3 - 0
src/mxnet.js

@@ -566,6 +566,9 @@ mxnet.Parameter = class {
 mxnet.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new mxnet.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/ncnn.js

@@ -273,6 +273,9 @@ ncnn.Parameter = class {
 ncnn.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new ncnn.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 33 - 32
src/onnx.js

@@ -258,35 +258,37 @@ onnx.Graph = class {
             this._name = graph.name || null;
             this._description = graph.doc_string || '';
 
-            let initializers = {};
+            let initializers = new Map();
             for (const tensor of graph.initializer) {
-                initializers[tensor.name] = new onnx.Tensor(tensor, 'Initializer');
+                initializers.set(tensor.name, new onnx.Tensor(tensor, 'Initializer'));
             }
             let nodes = [];
-            let outputCountMap = {};
-            let inputCountMap = {};
+            let inputCountMap = new Map();
+            let outputCountMap = new Map();
             for (const node of graph.node) {
                 for (const input of node.input) {
-                    inputCountMap[input] = (inputCountMap[input] || 0) + 1;
+                    inputCountMap.set(input, inputCountMap.has(input) ? inputCountMap.get(input) + 1 : 1);
                 }
                 for (const output of node.output) {
-                    outputCountMap[output] = (outputCountMap[output] || 0) + 1;
+                    outputCountMap.set(output, inputCountMap.has(output) ? inputCountMap.get(output) + 1 : 1);
                 }
             }
             for (const input of graph.input) {
-                delete inputCountMap[input];
+                inputCountMap.delete(input);
             }
             for (const output of graph.output) {
-                delete outputCountMap[output];
+                outputCountMap.delete(output);
             }
             for (const node of graph.node) {
                 let initializerNode = false;
                 if (node.op_type == 'Constant' && node.input.length == 0 && node.output.length == 1) {
-                    let name = node.output[0];
-                    if (inputCountMap[name] == 1 && outputCountMap[name] == 1 && node.attribute.length == 1) {
-                        let attribute = node.attribute[0];
+                    const name = node.output[0];
+                    if (inputCountMap.has(name) && inputCountMap.get(name) == 1 && 
+                        outputCountMap.has(name) && outputCountMap.get(name) == 1 &&
+                        node.attribute.length == 1) {
+                        const attribute = node.attribute[0];
                         if (attribute && attribute.name == 'value' && attribute.t) {
-                            initializers[name] = new onnx.Tensor(attribute.t, 'Constant');
+                            initializers.set(name, new onnx.Tensor(attribute.t, 'Constant'));
                             initializerNode = true;
                         }
                     }
@@ -296,18 +298,25 @@ onnx.Graph = class {
                 }
             }
 
-            this._arguments = {};
+            let args = new Map();
+            const arg = (id, type, doc_string, initializer, imageFormat) => {
+                if (!args.has(id)) {
+                    args.set(id, new onnx.Argument(id, type ? onnx.Tensor._formatType(type, imageFormat) : null, doc_string, initializer));
+                }
+                return args.get(id);
+            };
+
             for (const valueInfo of graph.value_info) {
-                this._argument(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers[valueInfo.name], imageFormat);
+                arg(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers.get(valueInfo.name), imageFormat);
             }
             for (const valueInfo of graph.input) {
-                let argument = this._argument(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers[valueInfo.name], imageFormat);
-                if (!initializers[valueInfo.name]) {
+                const argument = arg(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers.get(valueInfo.name), imageFormat);
+                if (!initializers.has(valueInfo.name)) {
                     this._inputs.push(new onnx.Parameter(valueInfo.name, [ argument ]));
                 }
             }
             for (const valueInfo of graph.output) {
-                let argument = this._argument(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers[valueInfo.name], imageFormat);
+                const argument = arg(valueInfo.name, valueInfo.type, valueInfo.doc_string, initializers.get(valueInfo.name), imageFormat);
                 this._outputs.push(new onnx.Parameter(valueInfo.name, [ argument ]));
             }
             for (const node of nodes) {
@@ -320,7 +329,7 @@ onnx.Graph = class {
                             if (inputIndex < node.input.length || inputSchema.option != 'optional') {
                                 let inputCount = (inputSchema.option == 'variadic') ? (node.input.length - inputIndex) : 1;
                                 let inputArguments = node.input.slice(inputIndex, inputIndex + inputCount).map((id) => {
-                                    return this._argument(id, null, null, initializers[id], imageFormat);
+                                    return arg(id, null, null, initializers.get(id), imageFormat);
                                 });
                                 inputIndex += inputCount;
                                 inputs.push(new onnx.Parameter(inputSchema.name, inputArguments));
@@ -330,7 +339,7 @@ onnx.Graph = class {
                     else {
                         inputs = inputs.concat(node.input.slice(inputIndex).map((id, index) => {
                             return new onnx.Parameter((inputIndex + index).toString(), [
-                                this._argument(id, null, null, null, imageFormat)
+                                arg(id, null, null, null, imageFormat)
                             ])
                         }));
                     }
@@ -343,7 +352,7 @@ onnx.Graph = class {
                             if (outputIndex < node.output.length || outputSchema.option != 'optional') {
                                 let outputCount = (outputSchema.option == 'variadic') ? (node.output.length - outputIndex) : 1;
                                 let outputArguments = node.output.slice(outputIndex, outputIndex + outputCount).map((id) => {
-                                    return this._argument(id, null, null, null, imageFormat);
+                                    return arg(id, null, null, null, imageFormat);
                                 });
                                 outputIndex += outputCount;
                                 outputs.push(new onnx.Parameter(outputSchema.name, outputArguments));
@@ -353,7 +362,7 @@ onnx.Graph = class {
                     else {
                         outputs = outputs.concat(node.output.slice(outputIndex).map((id, index) => {
                             return new onnx.Parameter((outputIndex + index).toString(), [
-                                this._argument(id, null, null, null, imageFormat)
+                                arg(id, null, null, null, imageFormat)
                             ]);
                         }));
                     }
@@ -361,8 +370,6 @@ onnx.Graph = class {
                 this._nodes.push(new onnx.Node(metadata, imageFormat, node.op_type, node.domain, node.name, node.doc_string, node.attribute, inputs, outputs));
             }
         }
-
-        delete this._arguments;
     }
 
     get name() {
@@ -392,15 +399,6 @@ onnx.Graph = class {
     toString() {
         return 'graph(' + this.name + ')';
     }
-
-    _argument(id, type, doc_string, initializer, imageFormat) {
-        let argument = this._arguments[id];
-        if (!argument) {
-            argument = new onnx.Argument(id, type ? onnx.Tensor._formatType(type, imageFormat) : null, doc_string, initializer);
-            this._arguments[id] = argument;
-        }
-        return argument;
-    }
 };
 
 onnx.Parameter = class {
@@ -426,6 +424,9 @@ onnx.Parameter = class {
 onnx.Argument = class {
 
     constructor(id, type, description, initializer) {
+        if (typeof id !== 'string') {
+            throw new onnx.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._description = description || '';

+ 3 - 0
src/paddle.js

@@ -183,6 +183,9 @@ paddle.Parameter = class {
 paddle.Argument = class {
 
     constructor(id, type, description, initializer) {
+        if (typeof id !== 'string') {
+            throw new paddle.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._description = description || null;

+ 3 - 0
src/pytorch.js

@@ -375,6 +375,9 @@ pytorch.Parameter = class {
 pytorch.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new pytorch.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
         this._initializer = initializer;

+ 3 - 0
src/sklearn.js

@@ -594,6 +594,9 @@ sklearn.Parameter = class {
 
 sklearn.Argument = class {
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new sklearn.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/tf.js

@@ -562,6 +562,9 @@ tf.Parameter = class {
 tf.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new tf.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type || null;
         this._initializer = initializer || null;

+ 3 - 0
src/torch.js

@@ -212,6 +212,9 @@ torch.Parameter = class {
 torch.Argument = class {
 
     constructor(id, type, initializer) {
+        if (typeof id !== 'string') {
+            throw new torch.Error("Invalid argument identifier '" + JSON.stringify(id) + "'.");
+        }
         this._id = id;
         this._type = type;
         this._initializer = initializer;

+ 1 - 0
test/models.json

@@ -5171,6 +5171,7 @@
     "type":   "tf",
     "target": "resnet_v2_fp16_savedmodel_NHWC_saved_model.pb",
     "source": "http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz[./resnet_v2_fp16_savedmodel_NHWC/1538686978/saved_model.pb]",
+    "render": "skip",
     "format": "TensorFlow Saved Model v1",
     "link":   "https://github.com/onnx/tensorflow-onnx/blob/master/tests/run_pretrained_models.yaml"
   },