فهرست منبع

Cache ONNX schemas per graph

Lutz Roeder 8 سال پیش
والد
کامیت
561cf52621
1فایلهای تغییر یافته به همراه91 افزوده شده و 95 حذف شده
  1. 91 95
      src/onnx-model.js

+ 91 - 95
src/onnx-model.js

@@ -33,31 +33,13 @@ class OnnxModel {
 
     constructor(model) {
         this._model = model;
-
-        var imports = {};
-        if (this._model.opsetImport) {
-            this._model.opsetImport.forEach((opsetImport) => {
-                var domain = opsetImport.domain || '';
-                if (domain == 'ai.onnx') {
-                    domain = '';
-                }
-                if (!imports[domain] || imports[domain] > opsetImport.version) {
-                    imports[domain] = opsetImport.version;
-                }
-            });
-        }
-        if (Object.keys(imports).length == 0) {
-            imports[''] = 1;
-            imports['ai.onnx.ml'] = 1;
-        }
-
+        this._graphs = [];
+        this._activeGraph = null;
         if (this._model.graph) {
-            this._graphs = [ new OnnxGraph(this, imports, this._model.graph, 0) ];
-            this._activeGraph = this._graphs[0];
-        }
-        else {
-            this._graphs = [];
-            this._activeGraph = null;
+            var metadata = new OnnxGraphOperatorMetadata(this._model);
+            var graph = new OnnxGraph(this, metadata, this._model.graph, 0);
+            this._graphs.push(graph);
+            this._activeGraph = graph;
         }
     }
 
@@ -139,9 +121,9 @@ class OnnxModel {
 
 class OnnxGraph {
 
-    constructor(model, imports, graph, index) {
+    constructor(model, metadata, graph, index) {
         this._model = model;
-        this._imports = imports;
+        this._metadata = metadata;
         this._graph = graph;
 
         if (this._graph) {
@@ -251,8 +233,8 @@ class OnnxGraph {
         return initializer ? initializer : null;
     }
 
-    get imports() {
-        return this._imports;
+    get metadata() {
+        return this._metadata;
     }
 }
 
@@ -280,7 +262,7 @@ class OnnxNode {
     }
 
     get documentation() {
-        return OnnxOperatorMetadata.operatorMetadata.getOperatorDocumentation(this);
+        return this._graph.metadata.getOperatorDocumentation(this);
     }
 
     get domain() {
@@ -288,7 +270,7 @@ class OnnxNode {
     }
 
     get category() {
-        return OnnxOperatorMetadata.operatorMetadata.getOperatorCategory(this);
+        return this._graph.metadata.getOperatorCategory(this);
     }
 
     get group() {
@@ -297,7 +279,7 @@ class OnnxNode {
 
     get inputs() {
         if (this._node.input) {
-            var inputs = OnnxOperatorMetadata.operatorMetadata.getInputs(this);
+            var inputs = this._graph.metadata.getInputs(this);
             inputs.forEach((input) => {
                 input.connections.forEach((connection) => {
                     var initializer = this._graph.getInitializer(connection.id);
@@ -313,7 +295,7 @@ class OnnxNode {
     }
 
     get outputs() {
-        return OnnxOperatorMetadata.operatorMetadata.getOutputs(this);
+        return this._graph.metadata.getOutputs(this);
     }
 
     get dependencies() {
@@ -363,7 +345,7 @@ class OnnxAttribute {
         else if (this._attribute.hasOwnProperty('t')) {
             return OnnxTensor.formatTensorType(this._attribute.t);
         }
-        return OnnxOperatorMetadata.operatorMetadata.getAttributeType(this._node, this._attribute.name);
+        return this._node.graph.metadata.getAttributeType(this._node, this._attribute.name);
     }
 
     get value() {
@@ -481,7 +463,7 @@ class OnnxTensor {
                 break;
             case onnx.TensorProto.DataType.INT32:
                 if (this._tensor.int32Data && this._tensor.int32Data.length > 0) {
-                    this._data = tensor.int32Data;
+                    this._data = this._tensor.int32Data;
                 }
                 else if (this._tensor.rawData && this._tensor.rawData.length > 0) {
                     this._rawData = new DataView(this._tensor.rawData.buffer, this._tensor.rawData.byteOffset, this._tensor.rawData.byteLength);
@@ -663,74 +645,35 @@ class OnnxTensor {
     }
 }
 
-class OnnxOperatorMetadata {
+class OnnxGraphOperatorMetadata {
 
-    static open(host, callback) {
-        if (OnnxOperatorMetadata.operatorMetadata) {
-            callback(null, OnnxOperatorMetadata.operatorMetadata);
-        }
-        else {
-            host.request('/onnx-operator.json', (err, data) => {
-                OnnxOperatorMetadata.operatorMetadata = new OnnxOperatorMetadata(data);
-                callback(null, OnnxOperatorMetadata.operatorMetadata);
+    constructor(model) {
+        this._cache = {};
+        this._imports = {};
+        if (model.opsetImport) {
+            model.opsetImport.forEach((opsetImport) => {
+                var domain = opsetImport.domain || '';
+                if (domain == 'ai.onnx') {
+                    domain = '';
+                }
+                if (!this._imports[domain] || this._imports[domain] > opsetImport.version) {
+                    this._imports[domain] = opsetImport.version;
+                }
             });
-        }    
-    }
-
-    constructor(data) {
-        this._map = {};
-        if (data) {
-            var items = JSON.parse(data);
-            if (items) {
-                items.forEach((item) => {
-                    if (item.name && item.schema)
-                    {
-                        var name = item.name;
-                        var schema = item.schema;
-                        var domain = item.schema.domain || '';
-                        if (domain == 'ai.onnx') {
-                            domain = '';
-                        }
-                        var version = item.schema.since_version || 0;
-                        this._map[name] = this._map[name] || {};
-                        this._map[name][domain] = this._map[name][domain] || {};
-                        this._map[name][domain][version] = this._map[name][domain][version] || {};
-                        this._map[name][domain][version] = schema;
-                    }
-                });
-            }
+        }
+        if (Object.keys(this._imports).length == 0) {
+            this._imports[''] = 1;
+            this._imports['ai.onnx.ml'] = 1;
         }
     }
 
     getSchema(node) {
-        var schema = null;
         var operator = node.operator;
-        var imports = node.graph.imports;
-        var domainMap = this._map[operator];
-        if (domainMap) {
-            var domainKeys = Object.keys(domainMap);
-            for (var i = 0; i < domainKeys.length; i++) {
-                var domain = domainKeys[i];
-                var versionMap = domainMap[domain];
-                var importVersion = imports[domain];
-                schema = versionMap[importVersion];
-                if (!schema) {
-                    var version = -1;
-                    var sinceVersionKeys = Object.keys(versionMap);
-                    for (var j = 0; j < sinceVersionKeys.length; j++) {
-                        var sinceVersion = sinceVersionKeys[j];
-                        if (importVersion >= sinceVersion && version < sinceVersion) {
-                            version = sinceVersion;
-                            schema = versionMap[sinceVersion];
-                        }
-                    }
-                    if (version >= 0) {
-                        versionMap[version] = schema;
-                    }
-                }
-                if (schema) {
-                    break;
-                }
+        var schema = this._cache[operator];
+        if (!schema) {
+            schema = OnnxOperatorMetadata.operatorMetadata.getSchema(operator, this._imports);
+            if (schema) {
+                this._cache[operator] = schema;
             }
         }
         return schema;
@@ -918,6 +861,59 @@ class OnnxOperatorMetadata {
     }
 }
 
+class OnnxOperatorMetadata {
+
+    static open(host, callback) {
+        if (OnnxOperatorMetadata.operatorMetadata) {
+            callback(null, OnnxOperatorMetadata.operatorMetadata);
+        }
+        else {
+            host.request('/onnx-operator.json', (err, data) => {
+                OnnxOperatorMetadata.operatorMetadata = new OnnxOperatorMetadata(data);
+                callback(null, OnnxOperatorMetadata.operatorMetadata);
+            });
+        }    
+    }
+
+    constructor(data) {
+        this._map = {};
+        if (data) {
+            var items = JSON.parse(data);
+            if (items) {
+                items.forEach((item) => {
+                    if (item.name && item.schema)
+                    {
+                        var name = item.name;
+                        this._map[name] = this._map[name] || [];
+                        this._map[name].push(item.schema);
+                    }
+                });
+            }
+        }
+    }
+
+    getSchema(operator, imports) {
+        var result = null;
+        var schemas = this._map[operator];
+        if (schemas) {
+            var version = -1;
+            schemas.forEach((schema) => {
+                var domain = schema.domain;
+                if (domain == 'ai.onnx') {
+                    domain = '';
+                }
+                var importVersion = imports[domain];
+                var sinceVersion = schema.since_version;
+                if (importVersion >= sinceVersion && version < sinceVersion) {
+                    version = sinceVersion;
+                    result = schema;
+                }
+            });
+        }
+        return result;
+    }
+}
+
 class OnnxError extends Error {
     constructor(message) {
         super(message);