浏览代码

Add ONNX test file (#905)

Lutz Roeder 3 年之前
父节点
当前提交
e36b5858e1
共有 4 个文件被更改,包括 97 次插入13 次删除
  1. 50 0
      source/base.js
  2. 33 12
      source/onnx.js
  3. 7 1
      source/python.js
  4. 7 0
      test/models.json

+ 50 - 0
source/base.js

@@ -438,6 +438,22 @@ base.Uint64.zero = new base.Uint64(0, 0);
 base.Uint64.one = new base.Uint64(1, 0);
 base.Uint64.max = new base.Uint64(-1, -1);
 
+base.Complex = class Complex {
+
+    constructor(real, imaginary) {
+        this.real = real;
+        this.imaginary = imaginary;
+    }
+
+    static create(real, imaginary) {
+        return new base.Complex(real, imaginary);
+    }
+
+    toString(/* radix */) {
+        return this.real + ' + ' + this.imaginary + 'i';
+    }
+};
+
 if (!DataView.prototype.getFloat16) {
     DataView.prototype.getFloat16 = function(byteOffset, littleEndian) {
         const value = this.getUint16(byteOffset, littleEndian);
@@ -550,6 +566,40 @@ DataView.prototype.setUint64 = DataView.prototype.setUint64 || function(byteOffs
     }
 };
 
+DataView.prototype.getComplex64 = DataView.prototype.getComplex64 || function(byteOffset, littleEndian) {
+    const real = littleEndian ? this.getFloat32(byteOffset, littleEndian) : this.getFloat32(byteOffset + 4, littleEndian);
+    const imaginary = littleEndian ? this.getFloat32(byteOffset + 4, littleEndian) : this.getFloat32(byteOffset, littleEndian);
+    return base.Complex.create(real, imaginary);
+};
+
+DataView.prototype.setComplex64 = DataView.prototype.setComplex64 || function(byteOffset, value, littleEndian) {
+    if (littleEndian) {
+        this.setFloat32(byteOffset, value.real, littleEndian);
+        this.setFloat32(byteOffset + 4, value.imaginary, littleEndian);
+    }
+    else {
+        this.setFloat32(byteOffset + 4, value.real, littleEndian);
+        this.setFloat32(byteOffset, value.imaginary, littleEndian);
+    }
+};
+
+DataView.prototype.getComplex128 = DataView.prototype.getComplex128 || function(byteOffset, littleEndian) {
+    const real = littleEndian ? this.getFloat64(byteOffset, littleEndian) : this.getFloat64(byteOffset + 8, littleEndian);
+    const imaginary = littleEndian ? this.getFloat64(byteOffset + 8, littleEndian) : this.getFloat64(byteOffset, littleEndian);
+    return base.Complex.create(real, imaginary);
+};
+
+DataView.prototype.setComplex128 = DataView.prototype.setComplex128 || function(byteOffset, value, littleEndian) {
+    if (littleEndian) {
+        this.setFloat64(byteOffset, value.real, littleEndian);
+        this.setFloat64(byteOffset + 8, value.imaginary, littleEndian);
+    }
+    else {
+        this.setFloat64(byteOffset + 8, value.real, littleEndian);
+        this.setFloat64(byteOffset, value.imaginary, littleEndian);
+    }
+};
+
 DataView.prototype.getBits = DataView.prototype.getBits || function(offset, bits /*, signed */) {
     offset = offset * bits;
     const available = (this.byteLength << 3) - offset;

+ 33 - 12
source/onnx.js

@@ -848,6 +848,8 @@ onnx.Tensor = class {
                         data = tensor.string_data;
                         break;
                     case onnx.DataType.BFLOAT16:
+                    case onnx.DataType.COMPLEX64:
+                    case onnx.DataType.COMPLEX128:
                         break;
                     default:
                         throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
@@ -1013,6 +1015,18 @@ onnx.Tensor = class {
                         data[i] = view.getBfloat16(i << 1, true);
                     }
                     break;
+                case onnx.DataType.COMPLEX64:
+                    data = new Array(buffer.length >> 3);
+                    for (let i = 0; i < data.length; i++) {
+                        data[i] = view.getComplex64(i << 3, true);
+                    }
+                    break;
+                case onnx.DataType.COMPLEX128:
+                    data = new Array(buffer.length >> 4);
+                    for (let i = 0; i < data.length; i++) {
+                        data[i] = view.getComplex64(i << 4, true);
+                    }
+                    break;
                 default:
                     throw new onnx.Error("Unsupported tensor data type '" + type + "'.");
             }
@@ -1110,19 +1124,26 @@ onnx.Tensor = class {
             result.push(indentation + ']');
             return result.join('\n');
         }
-        if (typeof value == 'string') {
-            return indentation + value;
-        }
-        if (value == Infinity) {
-            return indentation + 'Infinity';
-        }
-        if (value == -Infinity) {
-            return indentation + '-Infinity';
-        }
-        if (isNaN(value)) {
-            return indentation + 'NaN';
+        switch (typeof value) {
+            case 'string':
+                return indentation + value;
+            case 'number':
+                if (value == Infinity) {
+                    return indentation + 'Infinity';
+                }
+                if (value == -Infinity) {
+                    return indentation + '-Infinity';
+                }
+                if (isNaN(value)) {
+                    return indentation + 'NaN';
+                }
+                return indentation + value.toString();
+            default:
+                if (value.toString) {
+                    return indentation + value.toString();
+                }
+                return indentation + '(undefined)';
         }
-        return indentation + value.toString();
     }
 };
 

+ 7 - 1
source/python.js

@@ -2694,7 +2694,7 @@ python.Execution = class {
             if (descr[0] !== '<' && descr[0] !== '>') {
                 throw new numpy.Error("Unsupported byte order '" + descr + "'.");
             }
-            if (descr.length !== 3 || (descr[1] !== 'f' && descr[1] !== 'i' && descr[1] !== 'u' && descr.substring(1) !== 'b1')) {
+            if (descr.length !== 3 || (descr[1] !== 'f' && descr[1] !== 'i' && descr[1] !== 'u' && descr[1] !== 'c' && descr.substring(1) !== 'b1')) {
                 throw new numpy.Error("Unsupported data type '" + descr + "'.");
             }
             let shape = '';
@@ -2756,6 +2756,12 @@ python.Execution = class {
                             case 'u8':
                                 context.view.setUint64(context.position, data[i], littleendian);
                                 break;
+                            case 'c8':
+                                context.view.setComplex64(context.position, data[i], littleendian);
+                                break;
+                            case 'c16':
+                                context.view.setComplex128(context.position, data[i], littleendian);
+                                break;
                             default:
                                 throw new numpy.Error("Unsupported tensor data type '" + context.dtype + "'.");
                         }

+ 7 - 0
test/models.json

@@ -3238,6 +3238,13 @@
     "action":   "skip-render",
     "link":     "https://github.com/lutzroeder/netron/issues/589"
   },
+  {
+    "type":     "onnx",
+    "target":   "complex_init.onnx",
+    "source":   "https://github.com/lutzroeder/netron/files/8582128/complex_init.onnx.zip[complex_init.onnx]",
+    "format":   "ONNX v4",
+    "link":     "https://github.com/lutzroeder/netron/issues/905"
+  },
   {
     "type":     "onnx",
     "target":   "conv_autopad.onnx",