Quellcode durchsuchen

ONNX loader shared edge objects (#71)

Lutz Roeder vor 7 Jahren
Ursprung
Commit
929254ec27
1 geänderte Dateien mit 194 neuen und 207 gelöschten Zeilen
  1. 194 207
      src/onnx-model.js

+ 194 - 207
src/onnx-model.js

@@ -43,30 +43,26 @@ class OnnxModel {
 
     constructor(model) {
         this._model = model;
-        this._graphs = [];
-        this._activeGraph = null;
-        if (this._model.graph) {
-            var metadata = new OnnxGraphOperatorMetadata(this._model);
-            var graph = new OnnxGraph(this, metadata, this._model.graph, 0);
-            this._graphs.push(graph);
-            this._activeGraph = graph;
-        }
+        this._irVersion = model.irVersion;
+        this._opsetImport = model.opsetImport;
+        this._producerName = model.producerName;
+        this._producerVersion = model.producerVersion;
+        this._domain = model.domain;
+        this._modelVersion = model.modelVersion;
+        this._docString = model.docString;
+        this._metadataProps = model.metadataProps;
     }
 
     get properties() {
         var results = [];
         var format = 'ONNX';
-        if (this._model.irVersion) {
-            format = format + ' v' + this._model.irVersion.toString();
-            // var major = (this._model.irVersion >> 16) & 0x0f;
-            // var minor = (this._model.irVersion >> 8) & 0x0f;
-            // var revision = (this._model.irVersion) & 0x0f;
-            // format = format + ' v' + major.toString() + '.' + minor.toString() + '.' + revision.toString();
+        if (this._irVersion) {
+            format = format + ' v' + this._irVersion.toString();
         }
         results.push({ name: 'Format', value: format });
-        if (this._model.opsetImport && this._model.opsetImport.length > 0) {
+        if (this._opsetImport && this._opsetImport.length > 0) {
             var opsetImports = [];
-            this._model.opsetImport.forEach((opsetImport) => {
+            this._opsetImport.forEach((opsetImport) => {
                 var domain = opsetImport.domain ? opsetImport.domain : 'ai.onnx';
                 var result = domain + ' v' + opsetImport.version;
                 if (!opsetImports.includes(result)) {
@@ -76,28 +72,28 @@ class OnnxModel {
             results.push({ name: 'Imports', value: opsetImports.join(', ') });
         }
         var producer = [];
-        if (this._model.producerName) {
-            producer.push(this._model.producerName);
+        if (this._producerName) {
+            producer.push(this._producerName);
         }
-        if (this._model.producerVersion && this._model.producerVersion.length > 0) {
-            producer.push(this._model.producerVersion);
+        if (this._producerVersion && this._producerVersion.length > 0) {
+            producer.push(this._producerVersion);
         }
         if (producer.length > 0) {
             results.push({ 'name': 'Producer', 'value': producer.join(' ') });
         }
-        if (this._model.domain) {
-            results.push({ name: 'Domain', value: this._model.domain });
+        if (this._domain) {
+            results.push({ name: 'Domain', value: this._domain });
         }
-        if (this._model.modelVersion) {
-            results.push({ name: 'Version', value: this._model.modelVersion });
+        if (this._modelVersion) {
+            results.push({ name: 'Version', value: this._modelVersion });
         }
-        if (this._model.docString) {
-            results.push({ name: 'Description', value: this._model.docString });
+        if (this._docString) {
+            results.push({ name: 'Description', value: this._docString });
         }
         var metadata = {};
-        if (this._model.metadataProps)
+        if (this._metadataProps)
         {
-            this._model.metadataProps.forEach((metadataProp) => {
+            this._metadataProps.forEach((metadataProp) => {
                 metadata[metadataProp.key] = metadataProp.value;
             });
         }
@@ -125,68 +121,112 @@ class OnnxModel {
     }
 
     get graphs() {
+        if (this._model) {
+            this._graphs = [];
+            if (this._model.graph) {
+                var metadata = new OnnxGraphOperatorMetadata(this._opsetImport);
+                var graph = new OnnxGraph(metadata, this._model.graph, 0);
+                this._graphs.push(graph);
+            }
+            delete this._model;
+        }
         return this._graphs;
     }
 }
 
 class OnnxGraph {
 
-    constructor(model, metadata, graph, index) {
-        this._model = model;
+    constructor(metadata, graph, index) {
         this._metadata = metadata;
-        this._graph = graph;
+        this._node = '';
+        this._description = '';
         this._nodes = [];
-        this._initializerMap = [];
-        this._valueInfoMap = [];
-        this._outputMap = {};
 
-        if (this._graph) {
-            this._name = this._graph.name ? this._graph.name : ('(' + index.toString() + ')');
+        if (graph) {
+            this._name = graph.name || ('(' + index.toString() + ')');
+            this._description = graph.docString || '';
 
-            this._graph.node.forEach((node) => {
+            this._initializerMap = {};
+            this._connectionMap = {};
+            graph.initializer.forEach((tensor) => {
+                this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
+            });
+            graph.valueInfo.forEach((valueInfo) => {
+                this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
+            });
+
+            var nodes = [];
+            var outputCountMap = {};
+            graph.node.forEach((node) => {
                 node.output.forEach((output) => {
-                    this._outputMap[output] = (this._outputMap[output] || 0) + 1;
+                    outputCountMap[output] = (outputCountMap[output] || 0) + 1;
                 });
             });
-    
-            this._graph.initializer.forEach((tensor) => {
-                this._initializerMap[tensor.name] = new OnnxTensor(tensor, tensor.name, 'Initializer');
-            });
-            this._graph.node.forEach((node) => {
-                var add = true;
-                if (node.opType == 'Constant' && node.output && node.output.length == 1 && this._outputMap[node.output[0]] == 1) {
-                    node.attribute.forEach((attribute) => {
-                        if (attribute.name == 'value' && attribute.t) {
-                            var name = node.output[0];
+            graph.node.forEach((node) => {
+                var initializerNode = false;
+                if (node.opType == 'Constant' && node.output && node.output.length == 1) {
+                    var name = node.output[0];
+                    if (outputCountMap[name] == 1) {
+                        var attribute = node.attribute.find((attribute) => { return attribute.name == 'value' && attribute.t; }); 
+                        if (attribute) {
                             this._initializerMap[name] = new OnnxTensor(attribute.t, name, 'Constant');
-                            add = false;
+                            initializerNode = true;
                         }
-                    });
+                    }
                 }
-                if (add) {
-                    this._nodes.push(new OnnxNode(this, node));
+                if (!initializerNode) {
+                    nodes.push(node);
                 }
             });
 
-            this._graph.valueInfo.forEach((valueInfo) => {
-                this._valueInfoMap[valueInfo.name] = valueInfo;
+            this._inputs = [];
+            graph.input.forEach((valueInfo) => {
+                if (!this._initializerMap[valueInfo.name]) {
+                    var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
+                    connection.name = valueInfo.name;
+                    this._inputs.push(connection);
+                }
+            });
+            this._outputs = [];
+            graph.output.map((valueInfo) => {
+                var connection = this._connection(valueInfo.name, valueInfo.type, valueInfo.docString);
+                connection.name = valueInfo.name;
+                this._outputs.push(connection);
+            });
+    
+            nodes.forEach((node) => {
+                var inputs = [];
+                if (node.input) {
+                    inputs = this._metadata.getInputs(node.opType, node.input);
+                    inputs.forEach((input) => {
+                        input.connections = input.connections.map((connection) => {
+                            return this._connection(connection.id);
+                        });
+                    });          
+                }
+                var outputs = [];
+                if (node.output) {
+                    outputs = this._metadata.getOutputs(node.opType, node.output);
+                    outputs.forEach((output) => {
+                        output.connections = output.connections.map((connection) => {
+                            return this._connection(connection.id);
+                        });
+                    });
+                }
+                this._nodes.push(new OnnxNode(this, node.opType, node.domain, node.name, node.docString, node.attribute, inputs, outputs));
             });
-        }
-    }
 
-    get model() {
-        return this._model;
+            delete this._initializerMap;
+            delete this._connectionMap;
+        }
     }
 
     get name() {
-        return this._name || '';
+        return this._name;
     }
 
     get description() {
-        if (this._graph && this._graph.docString) {
-            return this._graph.docString;
-        }
-        return '';
+        return this._description;
     }
 
     get groups() {
@@ -194,44 +234,10 @@ class OnnxGraph {
     }
 
     get inputs() {
-        if (!this._inputs) {
-            this._inputs = [];
-            if (this._graph) {
-                var initializerMap = {};
-                this._graph.initializer.forEach((tensor) => {
-                    initializerMap[tensor.name] = true;
-                });
-                this._graph.input.forEach((valueInfo) => {
-                    if (!initializerMap[valueInfo.name]) {
-                        this._inputs.push({
-                            id: valueInfo.name,
-                            name: valueInfo.name,
-                            description: valueInfo.docString,
-                            type: OnnxTensor.formatType(valueInfo.type)
-                        });
-                        this._valueInfoMap[valueInfo.name] = valueInfo;
-                    }
-                });
-            }
-        }
         return this._inputs;
     }
 
     get outputs() {
-        if (!this._outputs) {
-            this._outputs = [];
-            if (this._graph) {
-                this._outputs = this._graph.output.map((valueInfo) => {
-                    this._valueInfoMap[valueInfo.name] = valueInfo;
-                    return {
-                        id: valueInfo.name,
-                        name: valueInfo.name,
-                        description: valueInfo.docString,
-                        type: OnnxTensor.formatType(valueInfo.type)
-                    };
-                });
-            }
-        }
         return this._outputs;
     }
 
@@ -239,38 +245,66 @@ class OnnxGraph {
         return this._nodes;
     }
 
-    getInitializer(input) {
-        var initializer = this._initializerMap[input];
-        return initializer ? initializer : null;
-    }
-
-    getValueInfo(input) {
-        var valueInfo = this._valueInfoMap[input];
-        return valueInfo ? valueInfo : null;
-    }
-
     get metadata() {
         return this._metadata;
     }
+
+    _connection(name, type, docString) {
+        var connection = this._connectionMap[name];
+        if (!connection) {
+            connection = {};
+            connection.id = name;
+            var initializer = this._initializerMap[name];
+            if (initializer) {
+                connection.initializer = initializer;
+                connection.type = initializer.type;
+            }
+            if (type) {
+                connection.type = OnnxTensor.formatType(type);
+            }
+            if (docString) {
+                connection.description = docString;
+            }
+            this._connectionMap[name] = connection;
+        }
+        return connection;
+    }
 }
 
 class OnnxNode {
 
-    constructor(graph, node) {
+    constructor(graph, operator, domain, name, description, attributes, inputs, outputs) {
         this._graph = graph;
-        this._node = node;
+        this._operator = operator;
+        if (domain) {
+            this._domain = domain;
+        }
+        if (name) {
+            this._name = name;
+        }
+        if (description) {
+            this._description = description;
+        }
+        this._attributes = [];
+        if (attributes && attributes.length > 0) {
+            attributes.forEach((attribute) => { 
+                this._attributes.push(new OnnxAttribute(this, attribute));
+            });
+        }            
+        this._inputs = inputs;
+        this._outputs = outputs;
     }
 
     get operator() {
-        return this._node.opType;
+        return this._operator;
     }
 
     get name() {
-        return this._node.name ? this._node.name : null;
+        return this._name || null;
     }
 
     get description() {
-        return this._node.docString ? this._node.docString : null;
+        return this._description || null;
     }
 
     get primitive() {
@@ -278,86 +312,44 @@ class OnnxNode {
     }
 
     get documentation() {
-        return this._graph.metadata.getOperatorDocumentation(this);
+        return this._graph.metadata.getOperatorDocumentation(this._operator);
     }
 
     get domain() {
-        return this._node.domain ? this._node.domain : null;
+        return this._domain || null;
     }
 
     get category() {
-        return this._graph.metadata.getOperatorCategory(this);
+        return this._graph.metadata.getOperatorCategory(this._operator);
     }
 
     get group() {
         return null;
     }
 
+    get attributes() {
+        return this._attributes;
+    }
+
     get inputs() {
-        if (this._node.input) {
-            var inputs = this._graph.metadata.getInputs(this);
-            inputs.forEach((input) => {
-                input.connections.forEach((connection) => {
-                    var initializer = this._graph.getInitializer(connection.id);
-                    if (initializer) {
-                        connection.initializer = initializer;
-                        connection.type = initializer.type;
-                    }
-                    else {
-                        var valueInfo = this._graph.getValueInfo(connection.id);
-                        if (valueInfo) {
-                            connection.type = OnnxTensor.formatType(valueInfo.type);
-                        }
-                    }
-                });
-            });          
-            return inputs;
-        }
-        return [];
+        return this._inputs;
     }
 
     get outputs() {
-        if (this._node.output) {
-            var outputs = this._graph.metadata.getOutputs(this);
-            outputs.forEach((output) => {
-                output.connections.forEach((connection) => {
-                    var valueInfo = this._graph.getValueInfo(connection.id);
-                    if (valueInfo) {
-                        connection.type = OnnxTensor.formatType(valueInfo.type);
-                    }
-                });
-            });
-            return outputs;
-        }
-        return [];
+        return this._outputs;
     }
 
     get dependencies() {
         return [];
     }
 
-    get attributes() {
-        var result = null;
-        var node = this._node;
-        if (node.attribute && node.attribute.length > 0) {
-            result = [];
-            node.attribute.forEach((attribute) => { 
-                result.push(new OnnxAttribute(this, attribute));
-            });
-        }
-        return result;
-    }
-
     get graph() {
         return this._graph;
     }
-
-    get data() {
-        return this._node;
-    }
 }
 
 class OnnxAttribute {
+
     constructor(node, attribute) {
         this._node = node;
         this._attribute = attribute;
@@ -379,7 +371,7 @@ class OnnxAttribute {
         else if (this._attribute.hasOwnProperty('t')) {
             return OnnxTensor.formatTensorType(this._attribute.t);
         }
-        return this._node.graph.metadata.getAttributeType(this._node, this._attribute.name);
+        return this._node.graph.metadata.getAttributeType(this._node.operator, this._attribute.name);
     }
 
     get value() {
@@ -430,7 +422,7 @@ class OnnxAttribute {
     }
 
     get visible() {
-        return this._node.graph.metadata.getAttributeVisible(this._node, this);
+        return this._node.graph.metadata.getAttributeVisible(this._node.operator, this);
     }
 
     get tensor() {
@@ -443,9 +435,7 @@ class OnnxTensor {
     constructor(tensor, id, kind) {
         this._tensor = tensor;
         this._id = id;
-        if (kind) {
-            this._kind = kind;
-        }
+        this._kind = kind || null;
     }
 
     get id() {
@@ -457,7 +447,7 @@ class OnnxTensor {
     }
 
     get kind() {
-        return this._kind ? this._kind : null;
+        return this._kind;
     }
 
     get type() {
@@ -686,11 +676,11 @@ class OnnxTensor {
 
 class OnnxGraphOperatorMetadata {
 
-    constructor(model) {
+    constructor(opsetImport) {
         this._cache = {};
         this._imports = {};
-        if (model.opsetImport) {
-            model.opsetImport.forEach((opsetImport) => {
+        if (opsetImport) {
+            opsetImport.forEach((opsetImport) => {
                 var domain = opsetImport.domain || '';
                 if (domain == 'ai.onnx') {
                     domain = '';
@@ -706,8 +696,7 @@ class OnnxGraphOperatorMetadata {
         }
     }
 
-    getSchema(node) {
-        var operator = node.operator;
+    getSchema(operator) {
         var schema = this._cache[operator];
         if (!schema) {
             schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports);
@@ -718,8 +707,8 @@ class OnnxGraphOperatorMetadata {
         return schema;
     }
 
-    getAttributeSchema(node, name) {
-        var schema = this.getSchema(node);
+    getAttributeSchema(operator, name) {
+        var schema = this.getSchema(operator);
         if (schema) {
             var attributeMap = schema.attributeMap;
             if (!attributeMap) {
@@ -739,32 +728,31 @@ class OnnxGraphOperatorMetadata {
         return null;
     }
 
-    getInputs(node) {
-        var inputs = [];
+    getInputs(operator, inputs) {
+        var results = [];
         var index = 0;
-        var schema = this.getSchema(node);
-        var data = node.data;
+        var schema = this.getSchema(operator);
         if (schema && schema.inputs) {
             schema.inputs.forEach((inputDef) => {
-                if (index < data.input.length || inputDef.option != 'optional') {
+                if (index < inputs.length || inputDef.option != 'optional') {
                     var input = {};
                     input.name = inputDef.name;
                     input.type = inputDef.type;
-                    var count = (inputDef.option == 'variadic') ? (data.input.length - index) : 1;
+                    var count = (inputDef.option == 'variadic') ? (inputs.length - index) : 1;
                     input.connections = [];
-                    data.input.slice(index, index + count).forEach((id) => {
+                    inputs.slice(index, index + count).forEach((id) => {
                         if (id != '' || inputDef.option != 'optional') {
                             input.connections.push({ id: id});
                         }
                     });
                     index += count;
-                    inputs.push(input);
+                    results.push(input);
                 }
             });
         }
         else {
-            data.input.slice(index).forEach((input) => {
-                inputs.push({
+            inputs.slice(index).forEach((input) => {
+                results.push({
                     name: '(' + index.toString() + ')',
                     connections: [ { id: input } ]
                 });
@@ -772,31 +760,30 @@ class OnnxGraphOperatorMetadata {
             });
 
         }
-        return inputs;
+        return results;
     }
 
-    getOutputs(node) {
-        var outputs = [];
+    getOutputs(operator, outputs) {
+        var results = [];
         var index = 0;
-        var schema = this.getSchema(node);
-        var data = node.data;
+        var schema = this.getSchema(operator);
         if (schema && schema.outputs) {
             schema.outputs.forEach((outputDef) => {
-                if (index < data.output.length || outputDef.option != 'optional') {
+                if (index < outputs.length || outputDef.option != 'optional') {
                     var output = {};
                     output.name = outputDef.name;
                     var count = (outputDef.option == 'variadic') ? (data.output.length - index) : 1;
-                    output.connections = data.output.slice(index, index + count).map((id) => {
+                    output.connections = outputs.slice(index, index + count).map((id) => {
                         return { id: id };
                     });
                     index += count;
-                    outputs.push(output);
+                    results.push(output);
                 }
             });
         }
         else {
-            data.output.slice(index).forEach((output) => {
-                outputs.push({
+            outputs.slice(index).forEach((output) => {
+                results.push({
                     name: '(' + index.toString() + ')',
                     connections: [ { id: output } ]
                 });
@@ -804,19 +791,19 @@ class OnnxGraphOperatorMetadata {
             });
 
         }
-        return outputs;
+        return results;
     }
 
-    getAttributeType(node, name) {
-        var schema = this.getAttributeSchema(node, name);
+    getAttributeType(operator, name) {
+        var schema = this.getAttributeSchema(operator, name);
         if (schema && schema.type) {
             return schema.type;
         }
         return '';
     }
 
-    getAttributeVisible(node, attribute) {
-        var schema = this.getAttributeSchema(node, attribute.name);
+    getAttributeVisible(operator, attribute) {
+        var schema = this.getAttributeSchema(operator, attribute.name);
         if (schema && schema.hasOwnProperty('default') && schema.default) {
             if (attribute.value == schema.default.toString()) {
                 return false;
@@ -825,19 +812,19 @@ class OnnxGraphOperatorMetadata {
         return true;     
     }
 
-    getOperatorCategory(node) {
-        var schema = this.getSchema(node);
+    getOperatorCategory(operator) {
+        var schema = this.getSchema(operator);
         if (schema && schema.category) {
             return schema.category;
         }
         return null;
     }
 
-    getOperatorDocumentation(node) {
-        var schema = this.getSchema(node);
+    getOperatorDocumentation(operator) {
+        var schema = this.getSchema(operator);
         if (schema) {
             schema = JSON.parse(JSON.stringify(schema));
-            schema.name = node.operator;
+            schema.name = operator;
             if (schema.description) {
                 var input = schema.description.split('\n');
                 var output = [];