Browse Source

Add ONNX Runtime test files (#767)

Lutz Roeder 4 years ago
parent
commit
7dda94de18
4 changed files with 164 additions and 54 deletions
  1. 36 5
      source/onnx-metadata.json
  2. 103 42
      source/onnx.js
  3. 15 0
      test/models.json
  4. 10 7
      tools/onnx-script.py

+ 36 - 5
source/onnx-metadata.json

@@ -4081,7 +4081,8 @@
         "summary": "clip_default_int8",
         "code": "node = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', 'min'],\n    outputs=['y'],\n)\nmin_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, min_val, np.iinfo(np.int8).max)\nexpect(node, inputs=[x, min_val], outputs=[y],\n       name='test_clip_default_int8_min')\n\nno_min = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, 'max'],\n    outputs=['y'],\n)\nmax_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, np.iinfo(np.int8).min, max_val)\nexpect(node, inputs=[x, max_val], outputs=[y],\n       name='test_clip_default_int8_max')\n\nno_max = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, no_max],\n    outputs=['y'],\n)\n\nx = np.array([-1, 0, 1]).astype(np.int8)\ny = np.array([-1, 0, 1]).astype(np.int8)\nexpect(node, inputs=[x], outputs=[y],\n       name='test_clip_default_int8_inbounds')"
       }
-    ]
+    ],
+    "category": "Activation"
   },
   {
     "name": "Clip",
@@ -4147,7 +4148,8 @@
         "summary": "clip_default_int8",
         "code": "node = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', 'min'],\n    outputs=['y'],\n)\nmin_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, min_val, np.iinfo(np.int8).max)\nexpect(node, inputs=[x, min_val], outputs=[y],\n       name='test_clip_default_int8_min')\n\nno_min = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, 'max'],\n    outputs=['y'],\n)\nmax_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, np.iinfo(np.int8).min, max_val)\nexpect(node, inputs=[x, max_val], outputs=[y],\n       name='test_clip_default_int8_max')\n\nno_max = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, no_max],\n    outputs=['y'],\n)\n\nx = np.array([-1, 0, 1]).astype(np.int8)\ny = np.array([-1, 0, 1]).astype(np.int8)\nexpect(node, inputs=[x], outputs=[y],\n       name='test_clip_default_int8_inbounds')"
       }
-    ]
+    ],
+    "category": "Activation"
   },
   {
     "name": "Clip",
@@ -4210,7 +4212,8 @@
         "summary": "clip_default_int8",
         "code": "node = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', 'min'],\n    outputs=['y'],\n)\nmin_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, min_val, np.iinfo(np.int8).max)\nexpect(node, inputs=[x, min_val], outputs=[y],\n       name='test_clip_default_int8_min')\n\nno_min = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, 'max'],\n    outputs=['y'],\n)\nmax_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, np.iinfo(np.int8).min, max_val)\nexpect(node, inputs=[x, max_val], outputs=[y],\n       name='test_clip_default_int8_max')\n\nno_max = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, no_max],\n    outputs=['y'],\n)\n\nx = np.array([-1, 0, 1]).astype(np.int8)\ny = np.array([-1, 0, 1]).astype(np.int8)\nexpect(node, inputs=[x], outputs=[y],\n       name='test_clip_default_int8_inbounds')"
       }
-    ]
+    ],
+    "category": "Activation"
   },
   {
     "name": "Clip",
@@ -4281,7 +4284,8 @@
         "summary": "clip_default_int8",
         "code": "node = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', 'min'],\n    outputs=['y'],\n)\nmin_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, min_val, np.iinfo(np.int8).max)\nexpect(node, inputs=[x, min_val], outputs=[y],\n       name='test_clip_default_int8_min')\n\nno_min = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, 'max'],\n    outputs=['y'],\n)\nmax_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, np.iinfo(np.int8).min, max_val)\nexpect(node, inputs=[x, max_val], outputs=[y],\n       name='test_clip_default_int8_max')\n\nno_max = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, no_max],\n    outputs=['y'],\n)\n\nx = np.array([-1, 0, 1]).astype(np.int8)\ny = np.array([-1, 0, 1]).astype(np.int8)\nexpect(node, inputs=[x], outputs=[y],\n       name='test_clip_default_int8_inbounds')"
       }
-    ]
+    ],
+    "category": "Activation"
   },
   {
     "name": "Clip",
@@ -4353,7 +4357,8 @@
         "summary": "clip_default_int8",
         "code": "node = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', 'min'],\n    outputs=['y'],\n)\nmin_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, min_val, np.iinfo(np.int8).max)\nexpect(node, inputs=[x, min_val], outputs=[y],\n       name='test_clip_default_int8_min')\n\nno_min = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, 'max'],\n    outputs=['y'],\n)\nmax_val = np.int8(0)\nx = np.random.randn(3, 4, 5).astype(np.int8)\ny = np.clip(x, np.iinfo(np.int8).min, max_val)\nexpect(node, inputs=[x, max_val], outputs=[y],\n       name='test_clip_default_int8_max')\n\nno_max = \"\"  # optional input, not supplied\nnode = onnx.helper.make_node(\n    'Clip',\n    inputs=['x', no_min, no_max],\n    outputs=['y'],\n)\n\nx = np.array([-1, 0, 1]).astype(np.int8)\ny = np.array([-1, 0, 1]).astype(np.int8)\nexpect(node, inputs=[x], outputs=[y],\n       name='test_clip_default_int8_inbounds')"
       }
-    ]
+    ],
+    "category": "Activation"
   },
   {
     "name": "Compress",
@@ -33901,5 +33906,31 @@
         ]
       }
     ]
+  },
+  {
+    "name": "FusedConv",
+    "module": "com.microsoft",
+    "version": 1,
+    "inputs": [
+      {
+        "name": "input",
+        "type": "T"
+      },
+      {
+        "name": "weights",
+        "type": "T"
+      },
+      {
+        "name": "bias",
+        "type": "T"
+      }
+    ],
+    "outputs": [
+      {
+        "name": "output",
+        "type": "T"
+      }
+    ],
+    "category": "Layer"
   }
 ]

+ 103 - 42
source/onnx.js

@@ -132,10 +132,10 @@ onnx.ModelFactory = class {
         if (tags.has('graph') && extension !== 'model') {
             return 'onnx.pbtxt.ModelProto';
         }
-        if (context.tags('flatbuffers').get('file_identifier') === 'ORTM') {
+        if (onnx.Runtime.Reader.open(stream, extension)) {
             return 'onnx.flatbuffers';
         }
-        if (onnx.TextReader.open(stream)) {
+        if (onnx.Text.Reader.open(stream)) {
             return 'onnx.text';
         }
         return undefined;
@@ -230,43 +230,8 @@ onnx.ModelFactory = class {
                     try {
                         onnx.schema = flatbuffers.get('ort').onnxruntime.fbs;
                         const stream = context.stream;
-                        const reader = flatbuffers.BinaryReader.open(stream);
-                        const session = onnx.schema.InferenceSession.create(reader);
-                        const model = session.model;
-                        const graph = model.graph;
-                        graph.node = graph.nodes;
-                        graph.doc_string = model.graph_doc_string;
-                        graph.value_info = graph.node_args;
-                        graph.input = graph.inputs.map((input) => {
-                            return { name: input };
-                        });
-                        graph.output = graph.outputs.map((output) => {
-                            return { name: output };
-                        });
-                        graph.initializer = graph.initializers.map((tensor) => {
-                            tensor.data_location = onnx.DataLocation.DEFAULT;
-                            return tensor;
-                        });
-                        graph.sparse_initializer = graph.sparse_initializers.map((tensor) => {
-                            tensor.values.data_location = onnx.DataLocation.DEFAULT;
-                            tensor.indices.data_location = onnx.DataLocation.DEFAULT;
-                            return tensor;
-                        });
-                        delete graph.nodes;
-                        delete graph.node_args;
-                        delete graph.inputs;
-                        delete graph.outputs;
-                        delete graph.initializers;
-                        delete graph.sparse_initializers;
-                        delete model.graph_doc_string;
-                        for (const node of graph.node) {
-                            node.input = node.inputs;
-                            node.output = node.outputs;
-                            node.attribute = node.attributes;
-                            delete node.inputs;
-                            delete node.outputs;
-                            delete node.attributes;
-                        }
+                        const reader = onnx.Runtime.Reader.open(stream, 'ort');
+                        const model = reader.read();
                         const format = 'ONNX Runtime' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
                         return open(model, format);
                     }
@@ -281,7 +246,7 @@ onnx.ModelFactory = class {
                     try {
                         onnx.proto = protobuf.get('onnx').onnx;
                         const stream = context.stream;
-                        const reader = onnx.TextReader.open(stream);
+                        const reader = onnx.Text.Reader.open(stream);
                         const model = reader.read();
                         const format = 'ONNX Text' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
                         return open(model, format);
@@ -1750,7 +1715,103 @@ onnx.GraphContext = class {
     }
 };
 
-onnx.TextReader = class {
+onnx.Runtime = {};
+
+onnx.Runtime.Reader = class {
+
+    static open(stream, extension) {
+        if (stream.length >= 8) {
+            const buffer = stream.peek(Math.min(32, stream.length));
+            const reader = flatbuffers.BinaryReader.open(buffer);
+            const identifier = reader.identifier;
+            if (identifier === 'ORTM') {
+                return new onnx.Runtime.Reader(stream);
+            }
+            if (extension === 'ort') {
+                const signature = [ 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
+                if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
+                    return new onnx.Runtime.Reader(stream);
+                }
+            }
+        }
+        return null;
+    }
+
+    constructor(stream) {
+        this._stream = stream;
+    }
+
+    read() {
+        this._graphs = new Set();
+        const reader = flatbuffers.BinaryReader.open(this._stream);
+        const session = onnx.schema.InferenceSession.create(reader);
+        const model = session.model;
+        const graph = model.graph;
+        graph.doc_string = model.graph_doc_string;
+        delete model.graph_doc_string;
+        this._graph(graph);
+        return model;
+    }
+
+    _graph(graph) {
+        if (this._graphs.has(graph)) {
+            return;
+        }
+        this._graphs.add(graph);
+        graph.name = this._graphs.size.toString();
+        graph.node = graph.nodes.map((node) => {
+            this._node(node);
+            return node;
+        });
+        delete graph.nodes;
+        graph.input = graph.inputs.map((input) => {
+            return { name: input };
+        });
+        delete graph.inputs;
+        graph.output = graph.outputs.map((output) => {
+            return { name: output };
+        });
+        delete graph.outputs;
+        graph.value_info = graph.node_args;
+        delete graph.node_args;
+        graph.initializer = graph.initializers.map((tensor) => {
+            tensor.data_location = onnx.DataLocation.DEFAULT;
+            return tensor;
+        });
+        delete graph.initializers;
+        graph.sparse_initializer = graph.sparse_initializers.map((tensor) => {
+            tensor.values.data_location = onnx.DataLocation.DEFAULT;
+            tensor.indices.data_location = onnx.DataLocation.DEFAULT;
+            return tensor;
+        });
+        delete graph.sparse_initializers;
+    }
+
+    _node(node) {
+        node.input = node.inputs;
+        node.output = node.outputs;
+        node.attribute = node.attributes.map((attribute) => {
+            switch (attribute.type) {
+                case onnx.AttributeType.GRAPH:
+                    this._graph(attribute.g);
+                    break;
+                case onnx.AttributeType.GRAPHS:
+                    for (const graph of attribute.graphs) {
+                        this._graph(graph);
+                    }
+                    break;
+            }
+            return attribute;
+        });
+        delete node.inputs;
+        delete node.outputs;
+        delete node.attributes;
+    }
+};
+
+onnx.Text = {};
+
+onnx.Text.Reader = class {
 
     static open(data) {
         try {
@@ -1766,7 +1827,7 @@ onnx.TextReader = class {
             const content = lines.join('\n');
             if (/^\s*<\s*ir_version\s*:/m.exec(content) ||
                 /^\s*[a-zA-Z][a-zA-Z0-9]*\s*\(.*\)\s=>\s\(/m.exec(content)) {
-                return new onnx.TextReader(data);
+                return new onnx.Text.Reader(data);
             }
         }
         catch (err) {

+ 15 - 0
test/models.json

@@ -3423,6 +3423,13 @@
     "format": "ONNX v3",
     "link":   "https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/test/testdata"
   },
+  {
+    "type":   "onnx",
+    "target": "ort_github_issue_4031.onnx.ort",
+    "source": "https://github.com/lutzroeder/netron/files/8191468/ort_github_issue_4031.onnx.ort.zip[ort_github_issue_4031.onnx.ort]",
+    "format": "ONNX Runtime v7",
+    "link":   "https://github.com/lutzroeder/netron/issues/767"
+  },
   {
     "type":   "onnx",
     "target": "resnet50_opset_9.onnx.zip",
@@ -3462,8 +3469,16 @@
     "type":   "onnx",
     "target": "shufflenet_opset_9.onnx",
     "source": "https://s3.amazonaws.com/download.onnx/models/opset_9/shufflenet.tar.gz[shufflenet/model.onnx]",
+    "format": "ONNX v3",
     "link":   "https://github.com/onnx/models/tree/main/shufflenet"
   },
+  {
+    "type":   "onnx",
+    "target": "sklearn_bin_voting_classifier_soft.ort",
+    "source": "https://github.com/lutzroeder/netron/files/8191675/sklearn_bin_voting_classifier_soft.ort.zip[sklearn_bin_voting_classifier_soft.ort]",
+    "format": "ONNX Runtime v6",
+    "link":   "https://github.com/lutzroeder/netron/issues/767"
+  },
   {
     "type":   "onnx",
     "target": "sparse_const.onnx",

+ 10 - 7
tools/onnx-script.py

@@ -23,6 +23,7 @@ categories = {
     'LSTM': 'Layer',
     'GRU': 'Layer',
     'Gemm': 'Layer',
+    'FusedConv': 'Layer',
 
     'Dropout': 'Dropout',
 
@@ -39,6 +40,7 @@ categories = {
     'Softmax': 'Activation',
     'Softplus': 'Activation',
     'Softsign': 'Activation',
+    'Clip': 'Activation',
 
     'BatchNormalization': 'Normalization',
     'InstanceNormalization': 'Normalization',
@@ -135,9 +137,11 @@ def format_description(description):
     description = re.sub("\\[(.+)\\]\\(([^ ]+?)( \"(.+)\")?\\)", replace_line, description)
     return description
 
-def generate_json(schemas, json_file):
+def metadata():
+    json_file = os.path.join(os.path.dirname(__file__), '../source/onnx-metadata.json')
     json_root = []
-    for schema in schemas:
+    all_schemas_with_history = onnx.defs.get_all_schemas_with_history()
+    for schema in all_schemas_with_history:
         json_schema = {}
         json_schema['name'] = schema.name
         if schema.domain:
@@ -222,6 +226,10 @@ def generate_json(schemas, json_file):
             json_schema['category'] = categories[schema.name]
         json_root.append(json_schema);
     json_root = sorted(json_root, key=lambda item: item['name'] + ':' + str(item['version'] if 'version' in item else 0).zfill(4))
+    with io.open(json_file, 'r') as file:
+        content = file.read();
+        items = json.loads(content)
+        json_root = json_root + list(filter(lambda item: item['module'] == "com.microsoft", items))
     with io.open(json_file, 'w', newline='') as fout:
         json_root = json.dumps(json_root, indent=2)
         for line in json_root.splitlines():
@@ -231,11 +239,6 @@ def generate_json(schemas, json_file):
             fout.write(line)
             fout.write('\n')
 
-def metadata():
-    json_file = os.path.join(os.path.dirname(__file__), '../source/onnx-metadata.json')
-    all_schemas_with_history = onnx.defs.get_all_schemas_with_history()
-    generate_json(all_schemas_with_history, json_file)
-
 def optimize():
     import onnx
     from onnx import optimizer