|
@@ -43,30 +43,26 @@ class OnnxModel {
|
|
|
|
|
|
|
|
constructor(model) {
|
|
constructor(model) {
|
|
|
this._model = model;
|
|
this._model = model;
|
|
|
- this._graphs = [];
|
|
|
|
|
- this._activeGraph = null;
|
|
|
|
|
- if (this._model.graph) {
|
|
|
|
|
- var metadata = new OnnxGraphOperatorMetadata(this._model);
|
|
|
|
|
- var graph = new OnnxGraph(this, metadata, this._model.graph, 0);
|
|
|
|
|
- this._graphs.push(graph);
|
|
|
|
|
- this._activeGraph = graph;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ this._irVersion = model.irVersion;
|
|
|
|
|
+ this._opsetImport = model.opsetImport;
|
|
|
|
|
+ this._producerName = model.producerName;
|
|
|
|
|
+ this._producerVersion = model.producerVersion;
|
|
|
|
|
+ this._domain = model.domain;
|
|
|
|
|
+ this._modelVersion = model.modelVersion;
|
|
|
|
|
+ this._docString = model.docString;
|
|
|
|
|
+ this._metadataProps = model.metadataProps;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get properties() {
|
|
get properties() {
|
|
|
var results = [];
|
|
var results = [];
|
|
|
var format = 'ONNX';
|
|
var format = 'ONNX';
|
|
|
- if (this._model.irVersion) {
|
|
|
|
|
- format = format + ' v' + this._model.irVersion.toString();
|
|
|
|
|
- // var major = (this._model.irVersion >> 16) & 0x0f;
|
|
|
|
|
- // var minor = (this._model.irVersion >> 8) & 0x0f;
|
|
|
|
|
- // var revision = (this._model.irVersion) & 0x0f;
|
|
|
|
|
- // format = format + ' v' + major.toString() + '.' + minor.toString() + '.' + revision.toString();
|
|
|
|
|
|
|
+ if (this._irVersion) {
|
|
|
|
|
+ format = format + ' v' + this._irVersion.toString();
|
|
|
}
|
|
}
|
|
|
results.push({ name: 'Format', value: format });
|
|
results.push({ name: 'Format', value: format });
|
|
|
- if (this._model.opsetImport && this._model.opsetImport.length > 0) {
|
|
|
|
|
|
|
+ if (this._opsetImport && this._opsetImport.length > 0) {
|
|
|
var opsetImports = [];
|
|
var opsetImports = [];
|
|
|
- this._model.opsetImport.forEach((opsetImport) => {
|
|
|
|
|
|
|
+ this._opsetImport.forEach((opsetImport) => {
|
|
|
var domain = opsetImport.domain ? opsetImport.domain : 'ai.onnx';
|
|
var domain = opsetImport.domain ? opsetImport.domain : 'ai.onnx';
|
|
|
var result = domain + ' v' + opsetImport.version;
|
|
var result = domain + ' v' + opsetImport.version;
|
|
|
if (!opsetImports.includes(result)) {
|
|
if (!opsetImports.includes(result)) {
|
|
@@ -76,28 +72,28 @@ class OnnxModel {
|
|
|
results.push({ name: 'Imports', value: opsetImports.join(', ') });
|
|
results.push({ name: 'Imports', value: opsetImports.join(', ') });
|
|
|
}
|
|
}
|
|
|
var producer = [];
|
|
var producer = [];
|
|
|
- if (this._model.producerName) {
|
|
|
|
|
- producer.push(this._model.producerName);
|
|
|
|
|
|
|
+ if (this._producerName) {
|
|
|
|
|
+ producer.push(this._producerName);
|
|
|
}
|
|
}
|
|
|
- if (this._model.producerVersion && this._model.producerVersion.length > 0) {
|
|
|
|
|
- producer.push(this._model.producerVersion);
|
|
|
|
|
|
|
+ if (this._producerVersion && this._producerVersion.length > 0) {
|
|
|
|
|
+ producer.push(this._producerVersion);
|
|
|
}
|
|
}
|
|
|
if (producer.length > 0) {
|
|
if (producer.length > 0) {
|
|
|
results.push({ 'name': 'Producer', 'value': producer.join(' ') });
|
|
results.push({ 'name': 'Producer', 'value': producer.join(' ') });
|
|
|
}
|
|
}
|
|
|
- if (this._model.domain) {
|
|
|
|
|
- results.push({ name: 'Domain', value: this._model.domain });
|
|
|
|
|
|
|
+ if (this._domain) {
|
|
|
|
|
+ results.push({ name: 'Domain', value: this._domain });
|
|
|
}
|
|
}
|
|
|
- if (this._model.modelVersion) {
|
|
|
|
|
- results.push({ name: 'Version', value: this._model.modelVersion });
|
|
|
|
|
|
|
+ if (this._modelVersion) {
|
|
|
|
|
+ results.push({ name: 'Version', value: this._modelVersion });
|
|
|
}
|
|
}
|
|
|
- if (this._model.docString) {
|
|
|
|
|
- results.push({ name: 'Description', value: this._model.docString });
|
|
|
|
|
|
|
+ if (this._docString) {
|
|
|
|
|
+ results.push({ name: 'Description', value: this._docString });
|
|
|
}
|
|
}
|
|
|
var metadata = {};
|
|
var metadata = {};
|
|
|
- if (this._model.metadataProps)
|
|
|
|
|
|
|
+ if (this._metadataProps)
|
|
|
{
|
|
{
|
|
|
- this._model.metadataProps.forEach((metadataProp) => {
|
|
|
|
|
|
|
+ this._metadataProps.forEach((metadataProp) => {
|
|
|
metadata[metadataProp.key] = metadataProp.value;
|
|
metadata[metadataProp.key] = metadataProp.value;
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
@@ -125,68 +121,112 @@ class OnnxModel {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get graphs() {
|
|
get graphs() {
|
|
|
|
|
+ if (this._model) {
|
|
|
|
|
+ this._graphs = [];
|
|
|
|
|
+ if (this._model.graph) {
|
|
|
|
|
+ var metadata = new OnnxGraphOperatorMetadata(this._opsetImport);
|
|
|
|
|
+ var graph = new OnnxGraph(metadata, this._model.graph, 0);
|
|
|
|
|
+ this._graphs.push(graph);
|
|
|
|
|
+ }
|
|
|
|
|
+ delete this._model;
|
|
|
|
|
+ }
|
|
|
return this._graphs;
|
|
return this._graphs;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
class OnnxGraph {
|
|
class OnnxGraph {
|
|
|
|
|
|
|
|
- constructor(model, metadata, graph, index) {
|
|
|
|
|
- this._model = model;
|
|
|
|
|
|
|
+ constructor(metadata, graph, index) {
|
|
|
this._metadata = metadata;
|
|
this._metadata = metadata;
|
|
|
- this._graph = graph;
|
|
|
|
|
|
|
+ this._node = '';
|
|
|
|
|
+ this._description = '';
|
|
|
this._nodes = [];
|
|
this._nodes = [];
|
|
|
- this._initializerMap = [];
|
|
|
|
|
- this._valueInfoMap = [];
|
|
|
|
|
- this._outputMap = {};
|
|
|
|
|
|
|
|
|
|
- if (this._graph) {
|
|
|
|
|
- this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')');
|
|
|
|
|
|
|
+ if (graph) {
|
|
|
|
|
+ this._name = graph.name || ('(' + index.toString() + ')');
|
|
|
|
|
+ this._description = graph.docString || '';
|
|
|
|
|
|
|
|
- this._graph.node.forEach((node) => {
|
|
|
|
|
|
|
+ this._initializerMap = {};
|
|
|
|
|
+ this._connectionMap = {};
|
|
|
|
|
+ graph.initializer.forEach((tensor) => {
|
|
|
|
|
+ this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
|
|
|
|
|
+ });
|
|
|
|
|
+ graph.valueInfo.forEach((valueInfo) => {
|
|
|
|
|
+ this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ var nodes = [];
|
|
|
|
|
+ var outputCountMap = {};
|
|
|
|
|
+ graph.node.forEach((node) => {
|
|
|
node.output.forEach((output) => {
|
|
node.output.forEach((output) => {
|
|
|
- this._outputMap[output] = (this._outputMap[output] || 0) + 1;
|
|
|
|
|
|
|
+ outputCountMap[output] = (outputCountMap[output] || 0) + 1;
|
|
|
});
|
|
});
|
|
|
});
|
|
});
|
|
|
-
|
|
|
|
|
- this._graph.initializer.forEach((tensor) => {
|
|
|
|
|
- this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
|
|
|
|
|
- });
|
|
|
|
|
- this._graph.node.forEach((node) => {
|
|
|
|
|
- var add = true;
|
|
|
|
|
- if (node.opType == 'Constant' && node.output && node.output.length == 1 && this._outputMap[node.output[0]] == 1) {
|
|
|
|
|
- node.attribute.forEach((attribute) => {
|
|
|
|
|
- if (attribute.name == 'value' && attribute.t) {
|
|
|
|
|
- var name = node.output[0];
|
|
|
|
|
|
|
+ graph.node.forEach((node) => {
|
|
|
|
|
+ var initializerNode = false;
|
|
|
|
|
+ if (node.opType == 'Constant' && node.output && node.output.length == 1) {
|
|
|
|
|
+ var name = node.output[0];
|
|
|
|
|
+ if (outputCountMap[name] == 1) {
|
|
|
|
|
+ var attribute = node.attribute.find((attribute) => { return attribute.name == 'value' && attribute.t; });
|
|
|
|
|
+ if (attribute) {
|
|
|
this._initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant');
|
|
this._initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant');
|
|
|
- add = false;
|
|
|
|
|
|
|
+ initializerNode = true;
|
|
|
}
|
|
}
|
|
|
- });
|
|
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- if (add) {
|
|
|
|
|
- this._nodes.push(new OnnxNode(this, node));
|
|
|
|
|
|
|
+ if (!initializerNode) {
|
|
|
|
|
+ nodes.push(node);
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- this._graph.valueInfo.forEach((valueInfo) => {
|
|
|
|
|
- this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
|
|
|
|
+ this._inputs = [];
|
|
|
|
|
+ graph.input.forEach((valueInfo) => {
|
|
|
|
|
+ if (!this._initializerMap[valueInfo.name]) {
|
|
|
|
|
+ var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
|
|
|
|
|
+ connection.name = valueInfo.name;
|
|
|
|
|
+ this._inputs.push(connection);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ this._outputs = [];
|
|
|
|
|
+ graph.output.map((valueInfo) => {
|
|
|
|
|
+ var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
|
|
|
|
|
+ connection.name = valueInfo.name;
|
|
|
|
|
+ this._outputs.push(connection);
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ nodes.forEach((node) => {
|
|
|
|
|
+ var inputs = [];
|
|
|
|
|
+ if (node.input) {
|
|
|
|
|
+ inputs = this._metadata.getInputs(node.opType, node.input);
|
|
|
|
|
+ inputs.forEach((input) => {
|
|
|
|
|
+ input.connections = input.connections.map((connection) => {
|
|
|
|
|
+ return this._connection(connection.id);
|
|
|
|
|
+ });
|
|
|
|
|
+ });
|
|
|
|
|
+ }
|
|
|
|
|
+ var outputs = [];
|
|
|
|
|
+ if (node.output) {
|
|
|
|
|
+ outputs = this._metadata.getOutputs(node.opType, node.output);
|
|
|
|
|
+ outputs.forEach((output) => {
|
|
|
|
|
+ output.connections = output.connections.map((connection) => {
|
|
|
|
|
+ return this._connection(connection.id);
|
|
|
|
|
+ });
|
|
|
|
|
+ });
|
|
|
|
|
+ }
|
|
|
|
|
+ this._nodes.push(new OnnxNode(this, node.opType, node.domain, node.name, node.docString, node.attribute, inputs, outputs));
|
|
|
});
|
|
});
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
|
|
|
- get model() {
|
|
|
|
|
- return this._model;
|
|
|
|
|
|
|
+ delete this._initializerMap;
|
|
|
|
|
+ delete this._connectionMap;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get name() {
|
|
get name() {
|
|
|
- return this._name || '';
|
|
|
|
|
|
|
+ return this._name;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get description() {
|
|
get description() {
|
|
|
- if (this._graph && this._graph.docString) {
|
|
|
|
|
- return this._graph.docString;
|
|
|
|
|
- }
|
|
|
|
|
- return '';
|
|
|
|
|
|
|
+ return this._description;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get groups() {
|
|
get groups() {
|
|
@@ -194,44 +234,10 @@ class OnnxGraph {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get inputs() {
|
|
get inputs() {
|
|
|
- if (!this._inputs) {
|
|
|
|
|
- this._inputs = [];
|
|
|
|
|
- if (this._graph) {
|
|
|
|
|
- var initializerMap = {};
|
|
|
|
|
- this._graph.initializer.forEach((tensor) => {
|
|
|
|
|
- initializerMap[tensor.name] = true;
|
|
|
|
|
- });
|
|
|
|
|
- this._graph.input.forEach((valueInfo) => {
|
|
|
|
|
- if (!initializerMap[valueInfo.name]) {
|
|
|
|
|
- this._inputs.push({
|
|
|
|
|
- id: valueInfo.name,
|
|
|
|
|
- name: valueInfo.name,
|
|
|
|
|
- description: valueInfo.docString,
|
|
|
|
|
- type: OnnxTensor.formatType(valueInfo.type)
|
|
|
|
|
- });
|
|
|
|
|
- this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
|
|
- }
|
|
|
|
|
- });
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
return this._inputs;
|
|
return this._inputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get outputs() {
|
|
get outputs() {
|
|
|
- if (!this._outputs) {
|
|
|
|
|
- this._outputs = [];
|
|
|
|
|
- if (this._graph) {
|
|
|
|
|
- this._outputs = this._graph.output.map((valueInfo) => {
|
|
|
|
|
- this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
|
|
- return {
|
|
|
|
|
- id: valueInfo.name,
|
|
|
|
|
- name: valueInfo.name,
|
|
|
|
|
- description: valueInfo.docString,
|
|
|
|
|
- type: OnnxTensor.formatType(valueInfo.type)
|
|
|
|
|
- };
|
|
|
|
|
- });
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
return this._outputs;
|
|
return this._outputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -239,38 +245,66 @@ class OnnxGraph {
|
|
|
return this._nodes;
|
|
return this._nodes;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getInitializer(input) {
|
|
|
|
|
- var initializer = this._initializerMap[input];
|
|
|
|
|
- return initializer ? initializer : null;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- getValueInfo(input) {
|
|
|
|
|
- var valueInfo = this._valueInfoMap[input];
|
|
|
|
|
- return valueInfo ? valueInfo : null;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
get metadata() {
|
|
get metadata() {
|
|
|
return this._metadata;
|
|
return this._metadata;
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ _connection(name, type, docString) {
|
|
|
|
|
+ var connection = this._connectionMap[name];
|
|
|
|
|
+ if (!connection) {
|
|
|
|
|
+ connection = {};
|
|
|
|
|
+ connection.id = name;
|
|
|
|
|
+ var initializer = this._initializerMap[name];
|
|
|
|
|
+ if (initializer) {
|
|
|
|
|
+ connection.initializer = initializer;
|
|
|
|
|
+ connection.type = initializer.type;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (type) {
|
|
|
|
|
+ connection.type = OnnxTensor.formatType(type);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (docString) {
|
|
|
|
|
+ connection.description = docString;
|
|
|
|
|
+ }
|
|
|
|
|
+ this._connectionMap[name] = connection;
|
|
|
|
|
+ }
|
|
|
|
|
+ return connection;
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
class OnnxNode {
|
|
class OnnxNode {
|
|
|
|
|
|
|
|
- constructor(graph, node) {
|
|
|
|
|
|
|
+ constructor(graph, operator, domain, name, description, attributes, inputs, outputs) {
|
|
|
this._graph = graph;
|
|
this._graph = graph;
|
|
|
- this._node = node;
|
|
|
|
|
|
|
+ this._operator = operator;
|
|
|
|
|
+ if (domain) {
|
|
|
|
|
+ this._domain = domain;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (name) {
|
|
|
|
|
+ this._name = name;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (description) {
|
|
|
|
|
+ this._description = description;
|
|
|
|
|
+ }
|
|
|
|
|
+ this._attributes = [];
|
|
|
|
|
+ if (attributes && attributes.length > 0) {
|
|
|
|
|
+ attributes.forEach((attribute) => {
|
|
|
|
|
+ this._attributes.push(new OnnxAttribute(this, attribute));
|
|
|
|
|
+ });
|
|
|
|
|
+ }
|
|
|
|
|
+ this._inputs = inputs;
|
|
|
|
|
+ this._outputs = outputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get operator() {
|
|
get operator() {
|
|
|
- return this._node.opType;
|
|
|
|
|
|
|
+ return this._operator;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get name() {
|
|
get name() {
|
|
|
- return this._node.name ? this._node.name : null;
|
|
|
|
|
|
|
+ return this._name || null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get description() {
|
|
get description() {
|
|
|
- return this._node.docString ? this._node.docString : null;
|
|
|
|
|
|
|
+ return this._description || null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get primitive() {
|
|
get primitive() {
|
|
@@ -278,86 +312,44 @@ class OnnxNode {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get documentation() {
|
|
get documentation() {
|
|
|
- return this._graph.metadata.getOperatorDocumentation(this);
|
|
|
|
|
|
|
+ return this._graph.metadata.getOperatorDocumentation(this._operator);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get domain() {
|
|
get domain() {
|
|
|
- return this._node.domain ? this._node.domain : null;
|
|
|
|
|
|
|
+ return this._domain || null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get category() {
|
|
get category() {
|
|
|
- return this._graph.metadata.getOperatorCategory(this);
|
|
|
|
|
|
|
+ return this._graph.metadata.getOperatorCategory(this._operator);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get group() {
|
|
get group() {
|
|
|
return null;
|
|
return null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ get attributes() {
|
|
|
|
|
+ return this._attributes;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
get inputs() {
|
|
get inputs() {
|
|
|
- if (this._node.input) {
|
|
|
|
|
- var inputs = this._graph.metadata.getInputs(this);
|
|
|
|
|
- inputs.forEach((input) => {
|
|
|
|
|
- input.connections.forEach((connection) => {
|
|
|
|
|
- var initializer = this._graph.getInitializer(connection.id);
|
|
|
|
|
- if (initializer) {
|
|
|
|
|
- connection.initializer = initializer;
|
|
|
|
|
- connection.type = initializer.type;
|
|
|
|
|
- }
|
|
|
|
|
- else {
|
|
|
|
|
- var valueInfo = this._graph.getValueInfo(connection.id);
|
|
|
|
|
- if (valueInfo) {
|
|
|
|
|
- connection.type = OnnxTensor.formatType(valueInfo.type);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- });
|
|
|
|
|
- });
|
|
|
|
|
- return inputs;
|
|
|
|
|
- }
|
|
|
|
|
- return [];
|
|
|
|
|
|
|
+ return this._inputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get outputs() {
|
|
get outputs() {
|
|
|
- if (this._node.output) {
|
|
|
|
|
- var outputs = this._graph.metadata.getOutputs(this);
|
|
|
|
|
- outputs.forEach((output) => {
|
|
|
|
|
- output.connections.forEach((connection) => {
|
|
|
|
|
- var valueInfo = this._graph.getValueInfo(connection.id);
|
|
|
|
|
- if (valueInfo) {
|
|
|
|
|
- connection.type = OnnxTensor.formatType(valueInfo.type);
|
|
|
|
|
- }
|
|
|
|
|
- });
|
|
|
|
|
- });
|
|
|
|
|
- return outputs;
|
|
|
|
|
- }
|
|
|
|
|
- return [];
|
|
|
|
|
|
|
+ return this._outputs;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get dependencies() {
|
|
get dependencies() {
|
|
|
return [];
|
|
return [];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- get attributes() {
|
|
|
|
|
- var result = null;
|
|
|
|
|
- var node = this._node;
|
|
|
|
|
- if (node.attribute && node.attribute.length > 0) {
|
|
|
|
|
- result = [];
|
|
|
|
|
- node.attribute.forEach((attribute) => {
|
|
|
|
|
- result.push(new OnnxAttribute(this, attribute));
|
|
|
|
|
- });
|
|
|
|
|
- }
|
|
|
|
|
- return result;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
get graph() {
|
|
get graph() {
|
|
|
return this._graph;
|
|
return this._graph;
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- get data() {
|
|
|
|
|
- return this._node;
|
|
|
|
|
- }
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
class OnnxAttribute {
|
|
class OnnxAttribute {
|
|
|
|
|
+
|
|
|
constructor(node, attribute) {
|
|
constructor(node, attribute) {
|
|
|
this._node = node;
|
|
this._node = node;
|
|
|
this._attribute = attribute;
|
|
this._attribute = attribute;
|
|
@@ -379,7 +371,7 @@ class OnnxAttribute {
|
|
|
else if (this._attribute.hasOwnProperty('t')) {
|
|
else if (this._attribute.hasOwnProperty('t')) {
|
|
|
return OnnxTensor.formatTensorType(this._attribute.t);
|
|
return OnnxTensor.formatTensorType(this._attribute.t);
|
|
|
}
|
|
}
|
|
|
- return this._node.graph.metadata.getAttributeType(this._node, this._attribute.name);
|
|
|
|
|
|
|
+ return this._node.graph.metadata.getAttributeType(this._node.operator, this._attribute.name);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get value() {
|
|
get value() {
|
|
@@ -430,7 +422,7 @@ class OnnxAttribute {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get visible() {
|
|
get visible() {
|
|
|
- return this._node.graph.metadata.getAttributeVisible(this._node, this);
|
|
|
|
|
|
|
+ return this._node.graph.metadata.getAttributeVisible(this._node.operator, this);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get tensor() {
|
|
get tensor() {
|
|
@@ -443,9 +435,7 @@ class OnnxTensor {
|
|
|
constructor(tensor, id, kind) {
|
|
constructor(tensor, id, kind) {
|
|
|
this._tensor = tensor;
|
|
this._tensor = tensor;
|
|
|
this._id = id;
|
|
this._id = id;
|
|
|
- if (kind) {
|
|
|
|
|
- this._kind = kind;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ this._kind = kind || null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get id() {
|
|
get id() {
|
|
@@ -457,7 +447,7 @@ class OnnxTensor {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get kind() {
|
|
get kind() {
|
|
|
- return this._kind ? this._kind : null;
|
|
|
|
|
|
|
+ return this._kind;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get type() {
|
|
get type() {
|
|
@@ -686,11 +676,11 @@ class OnnxTensor {
|
|
|
|
|
|
|
|
class OnnxGraphOperatorMetadata {
|
|
class OnnxGraphOperatorMetadata {
|
|
|
|
|
|
|
|
- constructor(model) {
|
|
|
|
|
|
|
+ constructor(opsetImport) {
|
|
|
this._cache = {};
|
|
this._cache = {};
|
|
|
this._imports = {};
|
|
this._imports = {};
|
|
|
- if (model.opsetImport) {
|
|
|
|
|
- model.opsetImport.forEach((opsetImport) => {
|
|
|
|
|
|
|
+ if (opsetImport) {
|
|
|
|
|
+ opsetImport.forEach((opsetImport) => {
|
|
|
var domain = opsetImport.domain || '';
|
|
var domain = opsetImport.domain || '';
|
|
|
if (domain == 'ai.onnx') {
|
|
if (domain == 'ai.onnx') {
|
|
|
domain = '';
|
|
domain = '';
|
|
@@ -706,8 +696,7 @@ class OnnxGraphOperatorMetadata {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getSchema(node) {
|
|
|
|
|
- var operator = node.operator;
|
|
|
|
|
|
|
+ getSchema(operator) {
|
|
|
var schema = this._cache[operator];
|
|
var schema = this._cache[operator];
|
|
|
if (!schema) {
|
|
if (!schema) {
|
|
|
schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports);
|
|
schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports);
|
|
@@ -718,8 +707,8 @@ class OnnxGraphOperatorMetadata {
|
|
|
return schema;
|
|
return schema;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getAttributeSchema(node, name) {
|
|
|
|
|
- var schema = this.getSchema(node);
|
|
|
|
|
|
|
+ getAttributeSchema(operator, name) {
|
|
|
|
|
+ var schema = this.getSchema(operator);
|
|
|
if (schema) {
|
|
if (schema) {
|
|
|
var attributeMap = schema.attributeMap;
|
|
var attributeMap = schema.attributeMap;
|
|
|
if (!attributeMap) {
|
|
if (!attributeMap) {
|
|
@@ -739,32 +728,31 @@ class OnnxGraphOperatorMetadata {
|
|
|
return null;
|
|
return null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getInputs(node) {
|
|
|
|
|
- var inputs = [];
|
|
|
|
|
|
|
+ getInputs(operator, inputs) {
|
|
|
|
|
+ var results = [];
|
|
|
var index = 0;
|
|
var index = 0;
|
|
|
- var schema = this.getSchema(node);
|
|
|
|
|
- var data = node.data;
|
|
|
|
|
|
|
+ var schema = this.getSchema(operator);
|
|
|
if (schema && schema.inputs) {
|
|
if (schema && schema.inputs) {
|
|
|
schema.inputs.forEach((inputDef) => {
|
|
schema.inputs.forEach((inputDef) => {
|
|
|
- if (index < data.input.length || inputDef.option != 'optional') {
|
|
|
|
|
|
|
+ if (index < inputs.length || inputDef.option != 'optional') {
|
|
|
var input = {};
|
|
var input = {};
|
|
|
input.name = inputDef.name;
|
|
input.name = inputDef.name;
|
|
|
input.type = inputDef.type;
|
|
input.type = inputDef.type;
|
|
|
- var count = (inputDef.option == 'variadic') ? (data.input.length - index) : 1;
|
|
|
|
|
|
|
+ var count = (inputDef.option == 'variadic') ? (inputs.length - index) : 1;
|
|
|
input.connections = [];
|
|
input.connections = [];
|
|
|
- data.input.slice(index, index + count).forEach((id) => {
|
|
|
|
|
|
|
+ inputs.slice(index, index + count).forEach((id) => {
|
|
|
if (id != '' || inputDef.option != 'optional') {
|
|
if (id != '' || inputDef.option != 'optional') {
|
|
|
input.connections.push({ id: id});
|
|
input.connections.push({ id: id});
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
index += count;
|
|
index += count;
|
|
|
- inputs.push(input);
|
|
|
|
|
|
|
+ results.push(input);
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
|
else {
|
|
else {
|
|
|
- data.input.slice(index).forEach((input) => {
|
|
|
|
|
- inputs.push({
|
|
|
|
|
|
|
+ inputs.slice(index).forEach((input) => {
|
|
|
|
|
+ results.push({
|
|
|
name: '(' + index.toString() + ')',
|
|
name: '(' + index.toString() + ')',
|
|
|
connections: [ { id: input } ]
|
|
connections: [ { id: input } ]
|
|
|
});
|
|
});
|
|
@@ -772,31 +760,30 @@ class OnnxGraphOperatorMetadata {
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
}
|
|
}
|
|
|
- return inputs;
|
|
|
|
|
|
|
+ return results;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getOutputs(node) {
|
|
|
|
|
- var outputs = [];
|
|
|
|
|
|
|
+ getOutputs(operator, outputs) {
|
|
|
|
|
+ var results = [];
|
|
|
var index = 0;
|
|
var index = 0;
|
|
|
- var schema = this.getSchema(node);
|
|
|
|
|
- var data = node.data;
|
|
|
|
|
|
|
+ var schema = this.getSchema(operator);
|
|
|
if (schema && schema.outputs) {
|
|
if (schema && schema.outputs) {
|
|
|
schema.outputs.forEach((outputDef) => {
|
|
schema.outputs.forEach((outputDef) => {
|
|
|
- if (index < data.output.length || outputDef.option != 'optional') {
|
|
|
|
|
|
|
+ if (index < outputs.length || outputDef.option != 'optional') {
|
|
|
var output = {};
|
|
var output = {};
|
|
|
output.name = outputDef.name;
|
|
output.name = outputDef.name;
|
|
|
var count = (outputDef.option == 'variadic') ? (data.output.length - index) : 1;
|
|
var count = (outputDef.option == 'variadic') ? (data.output.length - index) : 1;
|
|
|
- output.connections = data.output.slice(index, index + count).map((id) => {
|
|
|
|
|
|
|
+ output.connections = outputs.slice(index, index + count).map((id) => {
|
|
|
return { id: id };
|
|
return { id: id };
|
|
|
});
|
|
});
|
|
|
index += count;
|
|
index += count;
|
|
|
- outputs.push(output);
|
|
|
|
|
|
|
+ results.push(output);
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
|
else {
|
|
else {
|
|
|
- data.output.slice(index).forEach((output) => {
|
|
|
|
|
- outputs.push({
|
|
|
|
|
|
|
+ outputs.slice(index).forEach((output) => {
|
|
|
|
|
+ results.push({
|
|
|
name: '(' + index.toString() + ')',
|
|
name: '(' + index.toString() + ')',
|
|
|
connections: [ { id: output } ]
|
|
connections: [ { id: output } ]
|
|
|
});
|
|
});
|
|
@@ -804,19 +791,19 @@ class OnnxGraphOperatorMetadata {
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
}
|
|
}
|
|
|
- return outputs;
|
|
|
|
|
|
|
+ return results;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getAttributeType(node, name) {
|
|
|
|
|
- var schema = this.getAttributeSchema(node, name);
|
|
|
|
|
|
|
+ getAttributeType(operator, name) {
|
|
|
|
|
+ var schema = this.getAttributeSchema(operator, name);
|
|
|
if (schema && schema.type) {
|
|
if (schema && schema.type) {
|
|
|
return schema.type;
|
|
return schema.type;
|
|
|
}
|
|
}
|
|
|
return '';
|
|
return '';
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getAttributeVisible(node, attribute) {
|
|
|
|
|
- var schema = this.getAttributeSchema(node, attribute.name);
|
|
|
|
|
|
|
+ getAttributeVisible(operator, attribute) {
|
|
|
|
|
+ var schema = this.getAttributeSchema(operator, attribute.name);
|
|
|
if (schema && schema.hasOwnProperty('default') && schema.default) {
|
|
if (schema && schema.hasOwnProperty('default') && schema.default) {
|
|
|
if (attribute.value == schema.default.toString()) {
|
|
if (attribute.value == schema.default.toString()) {
|
|
|
return false;
|
|
return false;
|
|
@@ -825,19 +812,19 @@ class OnnxGraphOperatorMetadata {
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getOperatorCategory(node) {
|
|
|
|
|
- var schema = this.getSchema(node);
|
|
|
|
|
|
|
+ getOperatorCategory(operator) {
|
|
|
|
|
+ var schema = this.getSchema(operator);
|
|
|
if (schema && schema.category) {
|
|
if (schema && schema.category) {
|
|
|
return schema.category;
|
|
return schema.category;
|
|
|
}
|
|
}
|
|
|
return null;
|
|
return null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- getOperatorDocumentation(node) {
|
|
|
|
|
- var schema = this.getSchema(node);
|
|
|
|
|
|
|
+ getOperatorDocumentation(operator) {
|
|
|
|
|
+ var schema = this.getSchema(operator);
|
|
|
if (schema) {
|
|
if (schema) {
|
|
|
schema = JSON.parse(JSON.stringify(schema));
|
|
schema = JSON.parse(JSON.stringify(schema));
|
|
|
- schema.name = node.operator;
|
|
|
|
|
|
|
+ schema.name = operator;
|
|
|
if (schema.description) {
|
|
if (schema.description) {
|
|
|
var input = schema.description.split('\n');
|
|
var input = schema.description.split('\n');
|
|
|
var output = [];
|
|
var output = [];
|