瀏覽代碼

Update complex type notation

Lutz Roeder 2 月之前
父節點
當前提交
18707009e7
共有 12 個文件被更改,包括 117 次插入74 次删除
  1. 20 23
      source/base.js
  2. 3 3
      source/executorch.js
  3. 71 32
      source/mlir.js
  4. 2 2
      source/mnn.js
  5. 1 1
      source/om.js
  6. 2 0
      source/onnx.js
  7. 2 2
      source/paddle.js
  8. 5 5
      source/python.js
  9. 1 1
      source/rknn.js
  10. 4 2
      source/tf.js
  11. 4 1
      source/tflite.js
  12. 2 2
      test/models.json

+ 20 - 23
source/base.js

@@ -1,21 +1,7 @@
 
 const base = {};
 
-base.Complex64 = class Complex64 {
-
-    constructor(real, imaginary) {
-        this.real = real;
-        this.imaginary = imaginary;
-    }
-
-    toString(/* radix */) {
-        const sign = this.imaginary < 0 ? '-' : '+';
-        const imaginary = Math.abs(this.imaginary);
-        return `${this.real} ${sign} ${imaginary}i`;
-    }
-};
-
-base.Complex128 = class Complex128 {
+base.Complex = class Complex {
 
     constructor(real, imaginary) {
         this.real = real;
@@ -240,10 +226,16 @@ DataView.prototype.getUintBits = DataView.prototype.getUintBits || function(offs
     return value & ((1 << bits) - 1);
 };
 
+DataView.prototype.getComplex32 = DataView.prototype.getComplex32 || function(byteOffset, littleEndian) {
+    const real = littleEndian ? this.getFloat16(byteOffset, littleEndian) : this.getFloat16(byteOffset + 4, littleEndian);
+    const imaginary = littleEndian ? this.getFloat16(byteOffset + 4, littleEndian) : this.getFloat16(byteOffset, littleEndian);
+    return new base.Complex(real, imaginary);
+};
+
 DataView.prototype.getComplex64 = DataView.prototype.getComplex64 || function(byteOffset, littleEndian) {
     const real = littleEndian ? this.getFloat32(byteOffset, littleEndian) : this.getFloat32(byteOffset + 4, littleEndian);
     const imaginary = littleEndian ? this.getFloat32(byteOffset + 4, littleEndian) : this.getFloat32(byteOffset, littleEndian);
-    return new base.Complex64(real, imaginary);
+    return new base.Complex(real, imaginary);
 };
 
 DataView.prototype.setComplex64 = DataView.prototype.setComplex64 || function(byteOffset, value, littleEndian) {
@@ -259,7 +251,7 @@ DataView.prototype.setComplex64 = DataView.prototype.setComplex64 || function(by
 DataView.prototype.getComplex128 = DataView.prototype.getComplex128 || function(byteOffset, littleEndian) {
     const real = littleEndian ? this.getFloat64(byteOffset, littleEndian) : this.getFloat64(byteOffset + 8, littleEndian);
     const imaginary = littleEndian ? this.getFloat64(byteOffset + 8, littleEndian) : this.getFloat64(byteOffset, littleEndian);
-    return new base.Complex128(real, imaginary);
+    return new base.Complex(real, imaginary);
 };
 
 DataView.prototype.setComplex128 = DataView.prototype.setComplex128 || function(byteOffset, value, littleEndian) {
@@ -614,7 +606,7 @@ base.Tensor = class {
             ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
             ['uint8', 1], ['uint16', 2], ['uint32', 4,], ['uint64', 8],
             ['float16', 2], ['float32', 4], ['float64', 8], ['bfloat16', 2],
-            ['complex64', 8], ['complex128', 16],
+            ['complex<float32>', 8], ['complex<float64>', 16],
             ['float8e4m3fn', 1], ['float8e4m3fnuz', 1], ['float8e5m2', 1], ['float8e5m2fnuz', 1]
         ]);
     }
@@ -789,7 +781,8 @@ base.Tensor = class {
                     }
                     case '|': {
                         context.data = this.values;
-                        if (!base.Tensor._dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object' && dataType !== 'datetime' && dataType !== 'void') {
+                        const integer = (dataType.startsWith('int') && !isNaN(parseInt(dataType.substring(3), 10))) || (dataType.startsWith('uint') && !isNaN(parseInt(dataType.substring(4), 10)));
+                        if (!base.Tensor._dataTypes.has(dataType) && dataType !== 'string' && dataType !== 'object' && dataType !== 'datetime' && dataType !== 'void' && !integer) {
                             throw new Error(`Tensor data type '${dataType}' is not implemented.`);
                         }
                         const size = context.dimensions.reduce((a, v) => a * v, 1);
@@ -929,12 +922,17 @@ base.Tensor = class {
                         results.push(view.getBfloat16(offset, this._littleEndian));
                     }
                     break;
-                case 'complex64':
+                case 'complex<float16>':
+                    for (; offset < max; offset += stride) {
+                        results.push(view.getComplex32(offset, this._littleEndian));
+                    }
+                    break;
+                case 'complex<float32>':
                     for (; offset < max; offset += stride) {
                         results.push(view.getComplex64(offset, this._littleEndian));
                     }
                     break;
-                case 'complex128':
+                case 'complex<float64>':
                     for (; offset < max; offset += stride) {
                         results.push(view.getComplex128(offset, this._littleEndian));
                     }
@@ -1279,8 +1277,7 @@ base.Metadata = class {
     }
 };
 
-export const Complex64 = base.Complex64;
-export const Complex128 = base.Complex128;
+export const Complex = base.Complex;
 export const BinaryStream = base.BinaryStream;
 export const BinaryReader = base.BinaryReader;
 export const Tensor = base.Tensor;

+ 3 - 3
source/executorch.js

@@ -246,9 +246,9 @@ executorch.TensorType = class {
             case ScalarType.HALF: this.dataType = 'float16'; break;
             case ScalarType.FLOAT: this.dataType = 'float32'; break;
             case ScalarType.DOUBLE: this.dataType = 'float64'; break;
-            case 8: this.dataType = 'complex32'; break;
-            case 9: this.dataType = 'complex64'; break;
-            case 10: this.dataType = 'complex128'; break;
+            case 8: this.dataType = 'complex<float16>'; break;
+            case 9: this.dataType = 'complex<float32>'; break;
+            case 10: this.dataType = 'complex<float64>'; break;
             case ScalarType.BOOL: this.dataType = 'boolean'; break;
             case ScalarType.QINT8: this.dataType = 'qint8'; break;
             case ScalarType.QUINT8: this.dataType = 'quint8'; break;

+ 71 - 32
source/mlir.js

@@ -3350,40 +3350,44 @@ mlir.TensorLiteralParser = class {
         // Handle complex types
         // Reference: Complex types have N*2 elements or complex splat
         if (isComplex && Array.isArray(this._storage)) {
-            // Convert complex float pairs to binary format
-            const convertComplexToBinary = (typeStr, numElements) =>  {
-                const isComplex64 = typeStr.includes('complex<f32>') || typeStr.includes('complex64');
-                const bytesPerFloat = isComplex64 ? 4 : 8;
-                const buffer = new ArrayBuffer(numElements * 2 * bytesPerFloat);
-                const view = new DataView(buffer);
-                // For splat, expand single complex value
-                const isSplat = this._shape.length === 0 && this._storage.length === 2;
-                for (let i = 0; i < numElements; i++) {
-                    const srcIdx = isSplat ? 0 : i * 2;
-                    const real = typeof this._storage[srcIdx] === 'string' ? parseFloat(this._storage[srcIdx]) : this._storage[srcIdx];
-                    const imag = typeof this._storage[srcIdx + 1] === 'string' ? parseFloat(this._storage[srcIdx + 1]) : this._storage[srcIdx + 1];
-                    const offset = i * 2 * bytesPerFloat;
-                    if (isComplex64) {
-                        view.setFloat32(offset, real, true);
-                        view.setFloat32(offset + 4, imag, true);
-                    } else {
-                        view.setFloat64(offset, real, true);
-                        view.setFloat64(offset + 8, imag, true);
+            const isFloatComplex = typeStr.includes('complex<f32>') || typeStr.includes('complex<f64>') || typeStr.includes('complex64') || typeStr.includes('complex128');
+            if (isFloatComplex) {
+                // Convert complex float pairs to binary format
+                const convertComplexToBinary = (typeStr, numElements) =>  {
+                    const isComplex64 = typeStr.includes('complex<f32>') || typeStr.includes('complex64');
+                    const bytesPerFloat = isComplex64 ? 4 : 8;
+                    const buffer = new ArrayBuffer(numElements * 2 * bytesPerFloat);
+                    const view = new DataView(buffer);
+                    // For splat, expand single complex value
+                    const isSplat = this._shape.length === 0 && this._storage.length === 2;
+                    for (let i = 0; i < numElements; i++) {
+                        const srcIdx = isSplat ? 0 : i * 2;
+                        const real = typeof this._storage[srcIdx] === 'string' ? parseFloat(this._storage[srcIdx]) : this._storage[srcIdx];
+                        const imag = typeof this._storage[srcIdx + 1] === 'string' ? parseFloat(this._storage[srcIdx + 1]) : this._storage[srcIdx + 1];
+                        const offset = i * 2 * bytesPerFloat;
+                        if (isComplex64) {
+                            view.setFloat32(offset, real, true);
+                            view.setFloat32(offset + 4, imag, true);
+                        } else {
+                            view.setFloat64(offset, real, true);
+                            view.setFloat64(offset + 8, imag, true);
+                        }
                     }
-                }
-                return new Uint8Array(buffer);
-            };
-            const isSplat = this._shape.length === 0 && numElements !== 0;
-            if (isSplat) {
-                // Complex splat should have exactly 2 elements (real, imag)
-                if (this._storage.length === 2 && numElements <= maxSplatExpansion) {
-                    // Convert to binary format for proper complex handling
+                    return new Uint8Array(buffer);
+                };
+                const isSplat = this._shape.length === 0 && numElements !== 0;
+                if (isSplat) {
+                    // Complex splat should have exactly 2 elements (real, imag)
+                    if (this._storage.length === 2 && numElements <= maxSplatExpansion) {
+                        // Convert to binary format for proper complex handling
+                        return convertComplexToBinary(typeStr, numElements);
+                    }
+                } else if (numElements > 0 && numElements <= maxSplatExpansion) {
+                    // Non-splat should have numElements * 2 values
                     return convertComplexToBinary(typeStr, numElements);
                 }
-            } else if (numElements > 0 && numElements <= maxSplatExpansion) {
-                // Non-splat should have numElements * 2 values
-                return convertComplexToBinary(typeStr, numElements);
             }
+            // For non-float complex types (like complex<i32>), return the storage array directly
         }
         // Handle splats for non-complex types
         // Reference: if shape.empty() and storage has elements, it's a splat
@@ -3829,8 +3833,8 @@ mlir.Utility = class {
             case 'ui16': return 'uint16';
             case 'ui32': return 'uint32';
             case 'ui64': return 'uint64';
-            case 'complex<f32>': return 'complex64';
-            case 'complex<f64>': return 'complex128';
+            case 'complex<f32>': return 'complex<float32>';
+            case 'complex<f64>': return 'complex<float64>';
             case 'b8': return 'int8';
             case 'unk': return 'unk'; // torch dialect unknown dtype
             default:
@@ -3845,6 +3849,24 @@ mlir.Utility = class {
                 if (value && value.startsWith('memref<') && value.endsWith('>')) {
                     return value;
                 }
+                // Handle complex types with arbitrary element types (complex<i32>, complex<f16>, etc.)
+                if (value && value.startsWith('complex<') && value.endsWith('>')) {
+                    return value;
+                }
+                // Handle arbitrary integer types (i3, i6, i9, si7, ui13, etc.)
+                if (value && /^[su]?i[0-9]+$/.test(value)) {
+                    const match = value.match(/^(s|u)?i([0-9]+)$/);
+                    if (match) {
+                        const [, signed, widthStr] = match;
+                        const width = parseInt(widthStr, 10);
+                        if (signed === 'u') {
+                            return `uint${width}`;
+                        } else if (signed === 's') {
+                            return `int${width}`;
+                        }
+                        return width === 1 ? 'boolean' : `int${width}`;
+                    }
+                }
                 throw new mlir.Error(`Unknown data type '${value}'.`);
         }
     }
@@ -11093,6 +11115,23 @@ mlir.SPIRVDialect = class extends mlir.Dialect {
             }
             return true;
         }
+        // Reference: SPIRVOps.cpp parseArithmeticExtendedBinaryOp
+        // Format: spirv.IAddCarry %op1, %op2 : !spirv.struct<(i32, i32)>
+        const arithmeticExtendedOps = new Set([
+            'spirv.IAddCarry', 'spv.IAddCarry',
+            'spirv.ISubBorrow', 'spv.ISubBorrow',
+            'spirv.SMulExtended', 'spv.SMulExtended',
+            'spirv.UMulExtended', 'spv.UMulExtended'
+        ]);
+        if (arithmeticExtendedOps.has(opName)) {
+            parser.parseOptionalAttrDict(op.attributes);
+            op.operands = parser.parseArguments();
+            if (parser.accept(':')) {
+                const resultType = parser.parseType();
+                op.results.push({ type: resultType });
+            }
+            return true;
+        }
         return super.parseOperation(parser, opName, op);
     }
 };

+ 2 - 2
source/mnn.js

@@ -334,7 +334,7 @@ mnn.Utility = class {
             case mnn.schema.DataType.DT_INT16: return 'int16';
             case mnn.schema.DataType.DT_INT8: return 'int8';
             case mnn.schema.DataType.DT_STRING: return 'string';
-            case mnn.schema.DataType.DT_COMPLEX64: return 'complex64';
+            case mnn.schema.DataType.DT_COMPLEX64: return 'complex<float32>';
             case mnn.schema.DataType.DT_INT64: return 'int64';
             case mnn.schema.DataType.DT_BOOL: return 'boolean';
             case mnn.schema.DataType.DT_QINT8: return 'qint8';
@@ -344,7 +344,7 @@ mnn.Utility = class {
             case mnn.schema.DataType.DT_QINT16: return 'qint16';
             case mnn.schema.DataType.DT_QUINT16: return 'quint16';
             case mnn.schema.DataType.DT_UINT16: return 'uint16';
-            case mnn.schema.DataType.DT_COMPLEX128: return 'complex128';
+            case mnn.schema.DataType.DT_COMPLEX128: return 'complex<float64>';
             case mnn.schema.DataType.DT_HALF: return 'float16';
             case mnn.schema.DataType.DT_RESOURCE: return 'resource';
             case mnn.schema.DataType.DT_VARIANT: return 'variant';

+ 1 - 1
source/om.js

@@ -496,7 +496,7 @@ om.Utility = class {
         om.Utility._types = om.Utility._types || [
             'undefined', 'float32', 'float16', 'int8', 'uint8', 'int16', 'uint16', 'int32',
             'int64', 'uint32', 'uint64', 'boolean', 'float64', 'string', 'dual_sub_int8', 'dual_sub_uint8',
-            'complex64', 'complex128', 'qint8', 'qint16', 'qint32', 'quint8', 'quint16', 'resource',
+            'complex<float32>', 'complex<float64>', 'qint8', 'qint16', 'qint32', 'quint8', 'quint16', 'resource',
             'stringref', 'dual', 'variant', 'bfloat16', 'int4', 'uint1', 'int2', 'uint2'
         ];
         if (value >= om.Utility._types.length) {

+ 2 - 0
source/onnx.js

@@ -821,6 +821,8 @@ onnx.Context.Model = class {
         this._dataTypes.set(onnx.DataType.BOOL, 'boolean');
         this._dataTypes.set(onnx.DataType.FLOAT, 'float32');
         this._dataTypes.set(onnx.DataType.DOUBLE, 'float64');
+        this._dataTypes.set(onnx.DataType.COMPLEX64, 'complex<float32>');
+        this._dataTypes.set(onnx.DataType.COMPLEX128, 'complex<float64>');
         this._imageFormat = imageFormat;
         this._imports = imports;
         this._types = new Map();

+ 2 - 2
source/paddle.js

@@ -1247,8 +1247,8 @@ paddle.IR = class {
             case 'i16': return 'int16';
             case 'i32': return 'int32';
             case 'i64': return 'int64';
-            case 'c64': return 'complex64';
-            case 'c128': return 'complex128';
+            case 'c64': return 'complex<float32>';
+            case 'c128': return 'complex<float64>';
             case 'str': return 'string';
             default: return type;
         }

+ 5 - 5
source/python.js

@@ -2748,8 +2748,8 @@ python.Execution = class {
                         }
                     case 'c':
                         switch (this.itemsize) {
-                            case 8: return 'complex64';
-                            case 16: return 'complex128';
+                            case 8: return 'complex<float32>';
+                            case 16: return 'complex<float64>';
                             default: throw new python.Error(`Unsupported complex itemsize '${this.itemsize}'.`);
                         }
                     case 'S':
@@ -20574,9 +20574,9 @@ python.Execution = class {
         torch.float16 = torch.HalfStorage.dtype = new torch.dtype(5, 'float16', 2);
         torch.float32 = torch.FloatStorage.dtype = new torch.dtype(6, 'float32', 4);
         torch.float64 = torch.DoubleStorage.dtype = new torch.dtype(7, 'float64', 8);
-        torch.complex32 = torch.ComplexHalfStorage.dtype = new torch.dtype(8, 'complex32', 4);
-        torch.complex64 = torch.ComplexFloatStorage.dtype = new torch.dtype(9, 'complex64', 8);
-        torch.complex128 = torch.ComplexDoubleStorage.dtype = new torch.dtype(10, 'complex128', 16);
+        torch.complex32 = torch.ComplexHalfStorage.dtype = new torch.dtype(8, 'complex<float16>', 4);
+        torch.complex64 = torch.ComplexFloatStorage.dtype = new torch.dtype(9, 'complex<float32>', 8);
+        torch.complex128 = torch.ComplexDoubleStorage.dtype = new torch.dtype(10, 'complex<float64>', 16);
         torch.bool = torch.BoolStorage.dtype = new torch.dtype(11, 'boolean', 1);
         torch.qint8 = torch.QInt8Storage.dtype = new torch.dtype(12, 'qint8', 1);
         torch.quint8 = torch.QUInt8Storage.dtype = new torch.dtype(13, 'quint8', 1);

+ 1 - 1
source/rknn.js

@@ -183,7 +183,7 @@ rknn.Graph = class {
             }
             case 'flatbuffers': {
                 const graph = obj;
-                const dataTypes = ['undefined', 'float32', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'boolean', 'float16', 'float64', 'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16'];
+                const dataTypes = ['undefined', 'float32', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'boolean', 'float16', 'float64', 'uint32', 'uint64', 'complex<float32>', 'complex<float64>', 'bfloat16'];
                 const args = graph.tensors.map((tensor) => {
                     const shape = new rknn.TensorShape(Array.from(tensor.shape));
                     const dataType = tensor.data_type < dataTypes.length ? dataTypes[tensor.data_type] : '?';

+ 4 - 2
source/tf.js

@@ -1195,7 +1195,7 @@ tf.Tensor = class {
                         const values = tensor.scomplex_val || null;
                         this._values = new Array(values.length >> 1);
                         for (let i = 0; i < values.length; i += 2) {
-                            this._values[i >> 1] = new base.Complex64(values[i], values[i + 1]);
+                            this._values[i >> 1] = new base.Complex(values[i], values[i + 1]);
                         }
                         this.encoding = '|';
                         break;
@@ -1204,7 +1204,7 @@ tf.Tensor = class {
                         const values = tensor.dcomplex_val || null;
                         this._values = new Array(values.length >> 1);
                         for (let i = 0; i < values.length; i += 2) {
-                            this._values[i >> 1] = new base.Complex128(values[i], values[i + 1]);
+                            this._values[i >> 1] = new base.Complex(values[i], values[i + 1]);
                         }
                         this.encoding = '|';
                         break;
@@ -2179,6 +2179,8 @@ tf.Utility = class {
             dataTypes.set(DataType.DT_FLOAT, 'float32');
             dataTypes.set(DataType.DT_DOUBLE, 'float64');
             dataTypes.set(DataType.DT_BOOL, 'boolean');
+            dataTypes.set(DataType.DT_COMPLEX64, 'complex<float32>');
+            dataTypes.set(DataType.DT_COMPLEX128, 'complex<float64>');
             tf.Utility._dataTypes = dataTypes;
         }
         return tf.Utility._dataTypes.has(type) ? tf.Utility._dataTypes.get(type) : '?';

+ 4 - 1
source/tflite.js

@@ -520,8 +520,11 @@ tflite.Utility = class {
 
     static dataType(type) {
         if (!tflite.Utility._tensorTypes) {
+            const TensorType = tflite.schema.TensorType;
             tflite.Utility._tensorTypes = new Map(Object.entries(tflite.schema.TensorType).map(([key, value]) => [value, key.toLowerCase()]));
-            tflite.Utility._tensorTypes.set(6, 'boolean');
+            tflite.Utility._tensorTypes.set(TensorType.BOOL, 'boolean');
+            tflite.Utility._tensorTypes.set(tflite.schema.TensorType.COMPLEX64, 'complex<float32>');
+            tflite.Utility._tensorTypes.set(tflite.schema.TensorType.COMPLEX128, 'complex<float64>');
         }
         return tflite.Utility._tensorTypes.has(type) ? tflite.Utility._tensorTypes.get(type) : '?';
     }

+ 2 - 2
test/models.json

@@ -4087,7 +4087,7 @@
     "target":   "complex128.npy",
     "source":   "https://github.com/user-attachments/files/16554725/complex128.npy.zip[complex128.npy]",
     "format":   "NumPy Array",
-    "assert":   "model.modules[0].nodes[0].inputs[0].value[0].type.dataType == 'complex128'",
+    "assert":   "model.modules[0].nodes[0].inputs[0].value[0].type.dataType == 'complex<float64>'",
     "link":     "https://github.com/lutzroeder/netron/issues/711"
   },
   {
@@ -5745,7 +5745,7 @@
     "target":   "complex_tensor.pt",
     "source":   "https://github.com/lutzroeder/netron/files/9108149/complex_tensor.pt.zip[complex_tensor.pt]",
     "format":   "PyTorch v1.6",
-    "assert":   "model.modules[0].nodes[0].inputs[0].value[0].type.dataType == 'complex64'",
+    "assert":   "model.modules[0].nodes[0].inputs[0].value[0].type.dataType == 'complex<float32>'",
     "link":     "https://github.com/lutzroeder/netron/issues/720"
   },
   {