|
@@ -136,18 +136,19 @@ class OnnxGraph {
|
|
|
this._metadata = metadata;
|
|
this._metadata = metadata;
|
|
|
this._graph = graph;
|
|
this._graph = graph;
|
|
|
this._nodes = [];
|
|
this._nodes = [];
|
|
|
|
|
+ this._initializerMap = [];
|
|
|
|
|
+ this._valueInfoMap = [];
|
|
|
|
|
+ this._outputMap = {};
|
|
|
|
|
|
|
|
if (this._graph) {
|
|
if (this._graph) {
|
|
|
this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')');
|
|
this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')');
|
|
|
-
|
|
|
|
|
- this._outputMap = {};
|
|
|
|
|
|
|
+
|
|
|
this._graph.node.forEach((node) => {
|
|
this._graph.node.forEach((node) => {
|
|
|
node.output.forEach((output) => {
|
|
node.output.forEach((output) => {
|
|
|
this._outputMap[output] = (this._outputMap[output] || 0) + 1;
|
|
this._outputMap[output] = (this._outputMap[output] || 0) + 1;
|
|
|
});
|
|
});
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
- this._initializerMap = [];
|
|
|
|
|
this._graph.initializer.forEach((tensor) => {
|
|
this._graph.initializer.forEach((tensor) => {
|
|
|
this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
|
|
this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
|
|
|
});
|
|
});
|
|
@@ -166,6 +167,10 @@ class OnnxGraph {
|
|
|
this._nodes.push(new OnnxNode(this, node));
|
|
this._nodes.push(new OnnxNode(this, node));
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
|
|
+
|
|
|
|
|
+ this._graph.valueInfo.forEach((valueInfo) => {
|
|
|
|
|
+ this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
|
|
+ });
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -174,7 +179,7 @@ class OnnxGraph {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get name() {
|
|
get name() {
|
|
|
- return this._name;
|
|
|
|
|
|
|
+ return this._name || '';
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get description() {
|
|
get description() {
|
|
@@ -196,7 +201,7 @@ class OnnxGraph {
|
|
|
this._graph.initializer.forEach((tensor) => {
|
|
this._graph.initializer.forEach((tensor) => {
|
|
|
initializerMap[tensor.name] = true;
|
|
initializerMap[tensor.name] = true;
|
|
|
});
|
|
});
|
|
|
- this._graph.input.forEach((valueInfo, index) => {
|
|
|
|
|
|
|
+ this._graph.input.forEach((valueInfo) => {
|
|
|
if (!initializerMap[valueInfo.name]) {
|
|
if (!initializerMap[valueInfo.name]) {
|
|
|
this._inputs.push({
|
|
this._inputs.push({
|
|
|
id: valueInfo.name,
|
|
id: valueInfo.name,
|
|
@@ -204,6 +209,7 @@ class OnnxGraph {
|
|
|
description: valueInfo.docString,
|
|
description: valueInfo.docString,
|
|
|
type: OnnxTensor.formatType(valueInfo.type)
|
|
type: OnnxTensor.formatType(valueInfo.type)
|
|
|
});
|
|
});
|
|
|
|
|
+ this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
}
|
|
}
|
|
|
});
|
|
});
|
|
|
}
|
|
}
|
|
@@ -216,6 +222,7 @@ class OnnxGraph {
|
|
|
this._outputs = [];
|
|
this._outputs = [];
|
|
|
if (this._graph) {
|
|
if (this._graph) {
|
|
|
this._outputs = this._graph.output.map((valueInfo) => {
|
|
this._outputs = this._graph.output.map((valueInfo) => {
|
|
|
|
|
+ this._valueInfoMap[valueInfo.name] = valueInfo;
|
|
|
return {
|
|
return {
|
|
|
id: valueInfo.name,
|
|
id: valueInfo.name,
|
|
|
name: valueInfo.name,
|
|
name: valueInfo.name,
|
|
@@ -237,6 +244,11 @@ class OnnxGraph {
|
|
|
return initializer ? initializer : null;
|
|
return initializer ? initializer : null;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ getValueInfo(input) {
|
|
|
|
|
+ var valueInfo = this._valueInfoMap[input];
|
|
|
|
|
+ return valueInfo ? valueInfo : null;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
get metadata() {
|
|
get metadata() {
|
|
|
return this._metadata;
|
|
return this._metadata;
|
|
|
}
|
|
}
|
|
@@ -291,6 +303,12 @@ class OnnxNode {
|
|
|
connection.initializer = initializer;
|
|
connection.initializer = initializer;
|
|
|
connection.type = initializer.type;
|
|
connection.type = initializer.type;
|
|
|
}
|
|
}
|
|
|
|
|
+ else {
|
|
|
|
|
+ var valueInfo = this._graph.getValueInfo(connection.id);
|
|
|
|
|
+ if (valueInfo) {
|
|
|
|
|
+ connection.type = OnnxTensor.formatType(valueInfo.type);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
});
|
|
});
|
|
|
});
|
|
});
|
|
|
return inputs;
|
|
return inputs;
|
|
@@ -299,7 +317,19 @@ class OnnxNode {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get outputs() {
|
|
get outputs() {
|
|
|
- return this._graph.metadata.getOutputs(this);
|
|
|
|
|
|
|
+ if (this._node.output) {
|
|
|
|
|
+ var outputs = this._graph.metadata.getOutputs(this);
|
|
|
|
|
+ outputs.forEach((output) => {
|
|
|
|
|
+ output.connections.forEach((connection) => {
|
|
|
|
|
+ var valueInfo = this._graph.getValueInfo(connection.id);
|
|
|
|
|
+ if (valueInfo) {
|
|
|
|
|
+ connection.type = OnnxTensor.formatType(valueInfo.type);
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ });
|
|
|
|
|
+ return outputs;
|
|
|
|
|
+ }
|
|
|
|
|
+ return [];
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
get dependencies() {
|
|
get dependencies() {
|