|
|
@@ -132,10 +132,10 @@ onnx.ModelFactory = class {
|
|
|
if (tags.has('graph') && extension !== 'model') {
|
|
|
return 'onnx.pbtxt.ModelProto';
|
|
|
}
|
|
|
- if (context.tags('flatbuffers').get('file_identifier') === 'ORTM') {
|
|
|
+ if (onnx.Runtime.Reader.open(stream, extension)) {
|
|
|
return 'onnx.flatbuffers';
|
|
|
}
|
|
|
- if (onnx.TextReader.open(stream)) {
|
|
|
+ if (onnx.Text.Reader.open(stream)) {
|
|
|
return 'onnx.text';
|
|
|
}
|
|
|
return undefined;
|
|
|
@@ -230,43 +230,8 @@ onnx.ModelFactory = class {
|
|
|
try {
|
|
|
onnx.schema = flatbuffers.get('ort').onnxruntime.fbs;
|
|
|
const stream = context.stream;
|
|
|
- const reader = flatbuffers.BinaryReader.open(stream);
|
|
|
- const session = onnx.schema.InferenceSession.create(reader);
|
|
|
- const model = session.model;
|
|
|
- const graph = model.graph;
|
|
|
- graph.node = graph.nodes;
|
|
|
- graph.doc_string = model.graph_doc_string;
|
|
|
- graph.value_info = graph.node_args;
|
|
|
- graph.input = graph.inputs.map((input) => {
|
|
|
- return { name: input };
|
|
|
- });
|
|
|
- graph.output = graph.outputs.map((output) => {
|
|
|
- return { name: output };
|
|
|
- });
|
|
|
- graph.initializer = graph.initializers.map((tensor) => {
|
|
|
- tensor.data_location = onnx.DataLocation.DEFAULT;
|
|
|
- return tensor;
|
|
|
- });
|
|
|
- graph.sparse_initializer = graph.sparse_initializers.map((tensor) => {
|
|
|
- tensor.values.data_location = onnx.DataLocation.DEFAULT;
|
|
|
- tensor.indices.data_location = onnx.DataLocation.DEFAULT;
|
|
|
- return tensor;
|
|
|
- });
|
|
|
- delete graph.nodes;
|
|
|
- delete graph.node_args;
|
|
|
- delete graph.inputs;
|
|
|
- delete graph.outputs;
|
|
|
- delete graph.initializers;
|
|
|
- delete graph.sparse_initializers;
|
|
|
- delete model.graph_doc_string;
|
|
|
- for (const node of graph.node) {
|
|
|
- node.input = node.inputs;
|
|
|
- node.output = node.outputs;
|
|
|
- node.attribute = node.attributes;
|
|
|
- delete node.inputs;
|
|
|
- delete node.outputs;
|
|
|
- delete node.attributes;
|
|
|
- }
|
|
|
+ const reader = onnx.Runtime.Reader.open(stream, 'ort');
|
|
|
+ const model = reader.read();
|
|
|
const format = 'ONNX Runtime' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
|
|
|
return open(model, format);
|
|
|
}
|
|
|
@@ -281,7 +246,7 @@ onnx.ModelFactory = class {
|
|
|
try {
|
|
|
onnx.proto = protobuf.get('onnx').onnx;
|
|
|
const stream = context.stream;
|
|
|
- const reader = onnx.TextReader.open(stream);
|
|
|
+ const reader = onnx.Text.Reader.open(stream);
|
|
|
const model = reader.read();
|
|
|
const format = 'ONNX Text' + (model.ir_version ? ' v' + model.ir_version.toString() : '');
|
|
|
return open(model, format);
|
|
|
@@ -1750,7 +1715,103 @@ onnx.GraphContext = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-onnx.TextReader = class {
|
|
|
+onnx.Runtime = {};
|
|
|
+
|
|
|
+onnx.Runtime.Reader = class {
|
|
|
+
|
|
|
+ static open(stream, extension) {
|
|
|
+ if (stream.length >= 8) {
|
|
|
+ const buffer = stream.peek(Math.min(32, stream.length));
|
|
|
+ const reader = flatbuffers.BinaryReader.open(buffer);
|
|
|
+ const identifier = reader.identifier;
|
|
|
+ if (identifier === 'ORTM') {
|
|
|
+ return new onnx.Runtime.Reader(stream);
|
|
|
+ }
|
|
|
+ if (extension === 'ort') {
|
|
|
+ const signature = [ 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
|
|
|
+ if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
+ return new onnx.Runtime.Reader(stream);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
+ constructor(stream) {
|
|
|
+ this._stream = stream;
|
|
|
+ }
|
|
|
+
|
|
|
+ read() {
|
|
|
+ this._graphs = new Set();
|
|
|
+ const reader = flatbuffers.BinaryReader.open(this._stream);
|
|
|
+ const session = onnx.schema.InferenceSession.create(reader);
|
|
|
+ const model = session.model;
|
|
|
+ const graph = model.graph;
|
|
|
+ graph.doc_string = model.graph_doc_string;
|
|
|
+ delete model.graph_doc_string;
|
|
|
+ this._graph(graph);
|
|
|
+ return model;
|
|
|
+ }
|
|
|
+
|
|
|
+ _graph(graph) {
|
|
|
+ if (this._graphs.has(graph)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ this._graphs.add(graph);
|
|
|
+ graph.name = this._graphs.size.toString();
|
|
|
+ graph.node = graph.nodes.map((node) => {
|
|
|
+ this._node(node);
|
|
|
+ return node;
|
|
|
+ });
|
|
|
+ delete graph.nodes;
|
|
|
+ graph.input = graph.inputs.map((input) => {
|
|
|
+ return { name: input };
|
|
|
+ });
|
|
|
+ delete graph.inputs;
|
|
|
+ graph.output = graph.outputs.map((output) => {
|
|
|
+ return { name: output };
|
|
|
+ });
|
|
|
+ delete graph.outputs;
|
|
|
+ graph.value_info = graph.node_args;
|
|
|
+ delete graph.node_args;
|
|
|
+ graph.initializer = graph.initializers.map((tensor) => {
|
|
|
+ tensor.data_location = onnx.DataLocation.DEFAULT;
|
|
|
+ return tensor;
|
|
|
+ });
|
|
|
+ delete graph.initializers;
|
|
|
+ graph.sparse_initializer = graph.sparse_initializers.map((tensor) => {
|
|
|
+ tensor.values.data_location = onnx.DataLocation.DEFAULT;
|
|
|
+ tensor.indices.data_location = onnx.DataLocation.DEFAULT;
|
|
|
+ return tensor;
|
|
|
+ });
|
|
|
+ delete graph.sparse_initializers;
|
|
|
+ }
|
|
|
+
|
|
|
+ _node(node) {
|
|
|
+ node.input = node.inputs;
|
|
|
+ node.output = node.outputs;
|
|
|
+ node.attribute = node.attributes.map((attribute) => {
|
|
|
+ switch (attribute.type) {
|
|
|
+ case onnx.AttributeType.GRAPH:
|
|
|
+ this._graph(attribute.g);
|
|
|
+ break;
|
|
|
+ case onnx.AttributeType.GRAPHS:
|
|
|
+ for (const graph of attribute.graphs) {
|
|
|
+ this._graph(graph);
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ return attribute;
|
|
|
+ });
|
|
|
+ delete node.inputs;
|
|
|
+ delete node.outputs;
|
|
|
+ delete node.attributes;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+onnx.Text = {};
|
|
|
+
|
|
|
+onnx.Text.Reader = class {
|
|
|
|
|
|
static open(data) {
|
|
|
try {
|
|
|
@@ -1766,7 +1827,7 @@ onnx.TextReader = class {
|
|
|
const content = lines.join('\n');
|
|
|
if (/^\s*<\s*ir_version\s*:/m.exec(content) ||
|
|
|
/^\s*[a-zA-Z][a-zA-Z0-9]*\s*\(.*\)\s=>\s\(/m.exec(content)) {
|
|
|
- return new onnx.TextReader(data);
|
|
|
+ return new onnx.Text.Reader(data);
|
|
|
}
|
|
|
}
|
|
|
catch (err) {
|