|
|
@@ -33,31 +33,13 @@ class OnnxModel {
|
|
|
|
|
|
constructor(model) {
|
|
|
this._model = model;
|
|
|
-
|
|
|
- var imports = {};
|
|
|
- if (this._model.opsetImport) {
|
|
|
- this._model.opsetImport.forEach((opsetImport) => {
|
|
|
- var domain = opsetImport.domain || '';
|
|
|
- if (domain == 'ai.onnx') {
|
|
|
- domain = '';
|
|
|
- }
|
|
|
- if (!imports[domain] || imports[domain] > opsetImport.version) {
|
|
|
- imports[domain] = opsetImport.version;
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
- if (Object.keys(imports).length == 0) {
|
|
|
- imports[''] = 1;
|
|
|
- imports['ai.onnx.ml'] = 1;
|
|
|
- }
|
|
|
-
|
|
|
+ this._graphs = [];
|
|
|
+ this._activeGraph = null;
|
|
|
if (this._model.graph) {
|
|
|
- this._graphs = [ new OnnxGraph(this, imports, this._model.graph, 0) ];
|
|
|
- this._activeGraph = this._graphs[0];
|
|
|
- }
|
|
|
- else {
|
|
|
- this._graphs = [];
|
|
|
- this._activeGraph = null;
|
|
|
+ var metadata = new OnnxGraphOperatorMetadata(this._model);
|
|
|
+ var graph = new OnnxGraph(this, metadata, this._model.graph, 0);
|
|
|
+ this._graphs.push(graph);
|
|
|
+ this._activeGraph = graph;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -139,9 +121,9 @@ class OnnxModel {
|
|
|
|
|
|
class OnnxGraph {
|
|
|
|
|
|
- constructor(model, imports, graph, index) {
|
|
|
+ constructor(model, metadata, graph, index) {
|
|
|
this._model = model;
|
|
|
- this._imports = imports;
|
|
|
+ this._metadata = metadata;
|
|
|
this._graph = graph;
|
|
|
|
|
|
if (this._graph) {
|
|
|
@@ -251,8 +233,8 @@ class OnnxGraph {
|
|
|
return initializer ? initializer : null;
|
|
|
}
|
|
|
|
|
|
- get imports() {
|
|
|
- return this._imports;
|
|
|
+ get metadata() {
|
|
|
+ return this._metadata;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -280,7 +262,7 @@ class OnnxNode {
|
|
|
}
|
|
|
|
|
|
get documentation() {
|
|
|
- return OnnxOperatorMetadata.operatorMetadata.getOperatorDocumentation(this);
|
|
|
+ return this._graph.metadata.getOperatorDocumentation(this);
|
|
|
}
|
|
|
|
|
|
get domain() {
|
|
|
@@ -288,7 +270,7 @@ class OnnxNode {
|
|
|
}
|
|
|
|
|
|
get category() {
|
|
|
- return OnnxOperatorMetadata.operatorMetadata.getOperatorCategory(this);
|
|
|
+ return this._graph.metadata.getOperatorCategory(this);
|
|
|
}
|
|
|
|
|
|
get group() {
|
|
|
@@ -297,7 +279,7 @@ class OnnxNode {
|
|
|
|
|
|
get inputs() {
|
|
|
if (this._node.input) {
|
|
|
- var inputs = OnnxOperatorMetadata.operatorMetadata.getInputs(this);
|
|
|
+ var inputs = this._graph.metadata.getInputs(this);
|
|
|
inputs.forEach((input) => {
|
|
|
input.connections.forEach((connection) => {
|
|
|
var initializer = this._graph.getInitializer(connection.id);
|
|
|
@@ -313,7 +295,7 @@ class OnnxNode {
|
|
|
}
|
|
|
|
|
|
get outputs() {
|
|
|
- return OnnxOperatorMetadata.operatorMetadata.getOutputs(this);
|
|
|
+ return this._graph.metadata.getOutputs(this);
|
|
|
}
|
|
|
|
|
|
get dependencies() {
|
|
|
@@ -363,7 +345,7 @@ class OnnxAttribute {
|
|
|
else if (this._attribute.hasOwnProperty('t')) {
|
|
|
return OnnxTensor.formatTensorType(this._attribute.t);
|
|
|
}
|
|
|
- return OnnxOperatorMetadata.operatorMetadata.getAttributeType(this._node, this._attribute.name);
|
|
|
+ return this._node.graph.metadata.getAttributeType(this._node, this._attribute.name);
|
|
|
}
|
|
|
|
|
|
get value() {
|
|
|
@@ -481,7 +463,7 @@ class OnnxTensor {
|
|
|
break;
|
|
|
case onnx.TensorProto.DataType.INT32:
|
|
|
if (this._tensor.int32Data && this._tensor.int32Data.length > 0) {
|
|
|
- this._data = tensor.int32Data;
|
|
|
+ this._data = this._tensor.int32Data;
|
|
|
}
|
|
|
else if (this._tensor.rawData && this._tensor.rawData.length > 0) {
|
|
|
this._rawData = new DataView(this._tensor.rawData.buffer, this._tensor.rawData.byteOffset, this._tensor.rawData.byteLength);
|
|
|
@@ -663,74 +645,35 @@ class OnnxTensor {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-class OnnxOperatorMetadata {
|
|
|
+class OnnxGraphOperatorMetadata {
|
|
|
|
|
|
- static open(host, callback) {
|
|
|
- if (OnnxOperatorMetadata.operatorMetadata) {
|
|
|
- callback(null, OnnxOperatorMetadata.operatorMetadata);
|
|
|
- }
|
|
|
- else {
|
|
|
- host.request('/onnx-operator.json', (err, data) => {
|
|
|
- OnnxOperatorMetadata.operatorMetadata = new OnnxOperatorMetadata(data);
|
|
|
- callback(null, OnnxOperatorMetadata.operatorMetadata);
|
|
|
+ constructor(model) {
|
|
|
+ this._cache = {};
|
|
|
+ this._imports = {};
|
|
|
+ if (model.opsetImport) {
|
|
|
+ model.opsetImport.forEach((opsetImport) => {
|
|
|
+ var domain = opsetImport.domain || '';
|
|
|
+ if (domain == 'ai.onnx') {
|
|
|
+ domain = '';
|
|
|
+ }
|
|
|
+ if (!this._imports[domain] || this._imports[domain] > opsetImport.version) {
|
|
|
+ this._imports[domain] = opsetImport.version;
|
|
|
+ }
|
|
|
});
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- constructor(data) {
|
|
|
- this._map = {};
|
|
|
- if (data) {
|
|
|
- var items = JSON.parse(data);
|
|
|
- if (items) {
|
|
|
- items.forEach((item) => {
|
|
|
- if (item.name && item.schema)
|
|
|
- {
|
|
|
- var name = item.name;
|
|
|
- var schema = item.schema;
|
|
|
- var domain = item.schema.domain || '';
|
|
|
- if (domain == 'ai.onnx') {
|
|
|
- domain = '';
|
|
|
- }
|
|
|
- var version = item.schema.since_version || 0;
|
|
|
- this._map[name] = this._map[name] || {};
|
|
|
- this._map[name][domain] = this._map[name][domain] || {};
|
|
|
- this._map[name][domain][version] = this._map[name][domain][version] || {};
|
|
|
- this._map[name][domain][version] = schema;
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
+ }
|
|
|
+ if (Object.keys(this._imports).length == 0) {
|
|
|
+ this._imports[''] = 1;
|
|
|
+ this._imports['ai.onnx.ml'] = 1;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
getSchema(node) {
|
|
|
- var schema = null;
|
|
|
var operator = node.operator;
|
|
|
- var imports = node.graph.imports;
|
|
|
- var domainMap = this._map[operator];
|
|
|
- if (domainMap) {
|
|
|
- var domainKeys = Object.keys(domainMap);
|
|
|
- for (var i = 0; i < domainKeys.length; i++) {
|
|
|
- var domain = domainKeys[i];
|
|
|
- var versionMap = domainMap[domain];
|
|
|
- var importVersion = imports[domain];
|
|
|
- schema = versionMap[importVersion];
|
|
|
- if (!schema) {
|
|
|
- var version = -1;
|
|
|
- var sinceVersionKeys = Object.keys(versionMap);
|
|
|
- for (var j = 0; j < sinceVersionKeys.length; j++) {
|
|
|
- var sinceVersion = sinceVersionKeys[j];
|
|
|
- if (importVersion >= sinceVersion && version < sinceVersion) {
|
|
|
- version = sinceVersion;
|
|
|
- schema = versionMap[sinceVersion];
|
|
|
- }
|
|
|
- }
|
|
|
- if (version >= 0) {
|
|
|
- versionMap[version] = schema;
|
|
|
- }
|
|
|
- }
|
|
|
- if (schema) {
|
|
|
- break;
|
|
|
- }
|
|
|
+ var schema = this._cache[operator];
|
|
|
+ if (!schema) {
|
|
|
+ schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports);
|
|
|
+ if (schema) {
|
|
|
+ this._cache[operator] = schema;
|
|
|
}
|
|
|
}
|
|
|
return schema;
|
|
|
@@ -918,6 +861,59 @@ class OnnxOperatorMetadata {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+class OnnxOperatorMetadata {
|
|
|
+
|
|
|
+ static open(host, callback) {
|
|
|
+ if (OnnxOperatorMetadata.operatorMetadata) {
|
|
|
+ callback(null, OnnxOperatorMetadata.operatorMetadata);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ host.request('/onnx-operator.json', (err, data) => {
|
|
|
+ OnnxOperatorMetadata.operatorMetadata = new OnnxOperatorMetadata(data);
|
|
|
+ callback(null, OnnxOperatorMetadata.operatorMetadata);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ constructor(data) {
|
|
|
+ this._map = {};
|
|
|
+ if (data) {
|
|
|
+ var items = JSON.parse(data);
|
|
|
+ if (items) {
|
|
|
+ items.forEach((item) => {
|
|
|
+ if (item.name && item.schema)
|
|
|
+ {
|
|
|
+ var name = item.name;
|
|
|
+ this._map[name] = this._map[name] || [];
|
|
|
+ this._map[name].push(item.schema);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ getSchema(operator, imports) {
|
|
|
+ var result = null;
|
|
|
+ var schemas = this._map[operator];
|
|
|
+ if (schemas) {
|
|
|
+ var version = -1;
|
|
|
+ schemas.forEach((schema) => {
|
|
|
+ var domain = schema.domain;
|
|
|
+ if (domain == 'ai.onnx') {
|
|
|
+ domain = '';
|
|
|
+ }
|
|
|
+ var importVersion = imports[domain];
|
|
|
+ var sinceVersion = schema.since_version;
|
|
|
+ if (importVersion >= sinceVersion && version < sinceVersion) {
|
|
|
+ version = sinceVersion;
|
|
|
+ result = schema;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
class OnnxError extends Error {
|
|
|
constructor(message) {
|
|
|
super(message);
|