Jelajahi Sumber

Update onnx.js (#767)

Lutz Roeder 4 tahun lalu
induk
melakukan
70e17d15cc
1 mengubah file dengan 35 tambahan dan 16 penghapusan
  1. 35 16
      source/onnx.js

+ 35 - 16
source/onnx.js

@@ -560,9 +560,19 @@ onnx.Node = class {
         }
         this._name = name || '';
         this._description = description || '';
+        this._chain = [];
         this._inputs = inputs;
         this._outputs = outputs;
         this._attributes = (attributes || []).map((attribute) => new onnx.Attribute(context, op_type, domain, attribute));
+        attributes = Object.fromEntries(attributes.map((entry) => [ entry.name, entry ]));
+        const identifier = domain ? domain + '.' + op_type : op_type;
+        switch (identifier) {
+            case 'com.microsoft.FusedConv':
+                if (attributes.activation) {
+                    this._chain.push(new onnx.Node(context, attributes.activation.s, '', '', '', [], [], []));
+                }
+                break;
+        }
     }
 
     get type() {
@@ -588,6 +598,10 @@ onnx.Node = class {
     get outputs() {
         return this._outputs;
     }
+
+    get chain() {
+        return this._chain;
+    }
 };
 
 onnx.Attribute = class {
@@ -830,6 +844,9 @@ onnx.Tensor = class {
                     case onnx.DataType.INT64:
                         data = tensor.int64_data;
                         break;
+                    case onnx.DataType.STRING:
+                        data = tensor.string_data;
+                        break;
                 }
                 if (data && (Array.isArray(data) || ArrayBuffer.isView(data)) && data.length === 0) {
                     data = undefined;
@@ -1813,21 +1830,23 @@ onnx.Text = {};
 
 onnx.Text.Reader = class {
 
-    static open(data) {
+    static open(stream) {
         try {
-            const reader = text.Reader.open(data);
-            const lines = [];
-            for (let i = 0; i < 32; i++) {
-                const line = reader.read();
-                if (line === undefined) {
-                    break;
+            if (stream.length > 0 && stream.peek(1)[0] < 0x80 || stream.peek(1)[0] >= 0xFE) {
+                const reader = text.Reader.open(stream);
+                const lines = [];
+                for (let i = 0; i < 32; i++) {
+                    const line = reader.read();
+                    if (line === undefined) {
+                        break;
+                    }
+                    lines.push(line);
+                }
+                const content = lines.join('\n');
+                if (/^\s*<\s*ir_version\s*:/m.exec(content) ||
+                    /^\s*[a-zA-Z][a-zA-Z0-9]*\s*\(.*\)\s=>\s\(/m.exec(content)) {
+                    return new onnx.Text.Reader(stream);
                 }
-                lines.push(line);
-            }
-            const content = lines.join('\n');
-            if (/^\s*<\s*ir_version\s*:/m.exec(content) ||
-                /^\s*[a-zA-Z][a-zA-Z0-9]*\s*\(.*\)\s=>\s\(/m.exec(content)) {
-                return new onnx.Text.Reader(data);
             }
         }
         catch (err) {
@@ -1836,8 +1855,8 @@ onnx.Text.Reader = class {
         return null;
     }
 
-    constructor(data) {
-        this._data = data;
+    constructor(stream) {
+        this._stream = stream;
         this._dataTypes = new Map([
             [ 'float', 1 ], [ 'uint8', 2 ], [ 'int8', 3 ], [ 'uint16', 4 ],
             [ 'int16', 5 ], [ 'int32', 6 ], [ 'int64', 7 ], [ 'string', 8 ],
@@ -1853,7 +1872,7 @@ onnx.Text.Reader = class {
     }
 
     read() {
-        const decoder = text.Decoder.open(this._data);
+        const decoder = text.Decoder.open(this._stream);
         this._decoder = decoder;
         this._position = 0;
         this._char = decoder.decode();