Explorar o código

Show ONNX input and output shapes if available

Lutz Roeder %!s(int64=7) %!d(string=hai) anos
pai
achega
c1a58ffa7b
Modificáronse 1 ficheiros con 36 adicións e 6 borrados
  1. 36 6
      src/onnx-model.js

+ 36 - 6
src/onnx-model.js

@@ -136,18 +136,19 @@ class OnnxGraph {
         this._metadata = metadata;
         this._graph = graph;
         this._nodes = [];
+        this._initializerMap = [];
+        this._valueInfoMap = [];
+        this._outputMap = {};
 
         if (this._graph) {
             this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')');
-            
-            this._outputMap = {};
+
             this._graph.node.forEach((node) => {
                 node.output.forEach((output) => {
                     this._outputMap[output] = (this._outputMap[output] || 0) + 1;
                 });
             });
     
-            this._initializerMap = [];
             this._graph.initializer.forEach((tensor) => {
                 this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
             });
@@ -166,6 +167,10 @@ class OnnxGraph {
                     this._nodes.push(new OnnxNode(this, node));
                 }
             });
+
+            this._graph.valueInfo.forEach((valueInfo) => {
+                this._valueInfoMap[valueInfo.name] = valueInfo;
+            });
         }
     }
 
@@ -174,7 +179,7 @@ class OnnxGraph {
     }
 
     get name() {
-        return this._name;
+        return this._name || '';
     }
 
     get description() {
@@ -196,7 +201,7 @@ class OnnxGraph {
                 this._graph.initializer.forEach((tensor) => {
                     initializerMap[tensor.name] = true;
                 });
-                this._graph.input.forEach((valueInfo, index) => {
+                this._graph.input.forEach((valueInfo) => {
                     if (!initializerMap[valueInfo.name]) {
                         this._inputs.push({
                             id: valueInfo.name,
@@ -204,6 +209,7 @@ class OnnxGraph {
                             description: valueInfo.docString,
                             type: OnnxTensor.formatType(valueInfo.type)
                         });
+                        this._valueInfoMap[valueInfo.name] = valueInfo;
                     }
                 });
             }
@@ -216,6 +222,7 @@ class OnnxGraph {
             this._outputs = [];
             if (this._graph) {
                 this._outputs = this._graph.output.map((valueInfo) => {
+                    this._valueInfoMap[valueInfo.name] = valueInfo;
                     return {
                         id: valueInfo.name,
                         name: valueInfo.name,
@@ -237,6 +244,11 @@ class OnnxGraph {
         return initializer ? initializer : null;
     }
 
+    getValueInfo(input) {
+        var valueInfo = this._valueInfoMap[input];
+        return valueInfo ? valueInfo : null;
+    }
+
     get metadata() {
         return this._metadata;
     }
@@ -291,6 +303,12 @@ class OnnxNode {
                         connection.initializer = initializer;
                         connection.type = initializer.type;
                     }
+                    else {
+                        var valueInfo = this._graph.getValueInfo(connection.id);
+                        if (valueInfo) {
+                            connection.type = OnnxTensor.formatType(valueInfo.type);
+                        }
+                    }
                 });
             });          
             return inputs;
@@ -299,7 +317,19 @@ class OnnxNode {
     }
 
     get outputs() {
-        return this._graph.metadata.getOutputs(this);
+        if (this._node.output) {
+            var outputs = this._graph.metadata.getOutputs(this);
+            outputs.forEach((output) => {
+                output.connections.forEach((connection) => {
+                    var valueInfo = this._graph.getValueInfo(connection.id);
+                    if (valueInfo) {
+                        connection.type = OnnxTensor.formatType(valueInfo.type);
+                    }
+                });
+            });
+            return outputs;
+        }
+        return [];
     }
 
     get dependencies() {