|
|
@@ -3350,40 +3350,44 @@ mlir.TensorLiteralParser = class {
|
|
|
// Handle complex types
|
|
|
// Reference: Complex types have N*2 elements or complex splat
|
|
|
if (isComplex && Array.isArray(this._storage)) {
|
|
|
- // Convert complex float pairs to binary format
|
|
|
- const convertComplexToBinary = (typeStr, numElements) => {
|
|
|
- const isComplex64 = typeStr.includes('complex<f32>') || typeStr.includes('complex64');
|
|
|
- const bytesPerFloat = isComplex64 ? 4 : 8;
|
|
|
- const buffer = new ArrayBuffer(numElements * 2 * bytesPerFloat);
|
|
|
- const view = new DataView(buffer);
|
|
|
- // For splat, expand single complex value
|
|
|
- const isSplat = this._shape.length === 0 && this._storage.length === 2;
|
|
|
- for (let i = 0; i < numElements; i++) {
|
|
|
- const srcIdx = isSplat ? 0 : i * 2;
|
|
|
- const real = typeof this._storage[srcIdx] === 'string' ? parseFloat(this._storage[srcIdx]) : this._storage[srcIdx];
|
|
|
- const imag = typeof this._storage[srcIdx + 1] === 'string' ? parseFloat(this._storage[srcIdx + 1]) : this._storage[srcIdx + 1];
|
|
|
- const offset = i * 2 * bytesPerFloat;
|
|
|
- if (isComplex64) {
|
|
|
- view.setFloat32(offset, real, true);
|
|
|
- view.setFloat32(offset + 4, imag, true);
|
|
|
- } else {
|
|
|
- view.setFloat64(offset, real, true);
|
|
|
- view.setFloat64(offset + 8, imag, true);
|
|
|
+ const isFloatComplex = typeStr.includes('complex<f32>') || typeStr.includes('complex<f64>') || typeStr.includes('complex64') || typeStr.includes('complex128');
|
|
|
+ if (isFloatComplex) {
|
|
|
+ // Convert complex float pairs to binary format
|
|
|
+ const convertComplexToBinary = (typeStr, numElements) => {
|
|
|
+ const isComplex64 = typeStr.includes('complex<f32>') || typeStr.includes('complex64');
|
|
|
+ const bytesPerFloat = isComplex64 ? 4 : 8;
|
|
|
+ const buffer = new ArrayBuffer(numElements * 2 * bytesPerFloat);
|
|
|
+ const view = new DataView(buffer);
|
|
|
+ // For splat, expand single complex value
|
|
|
+ const isSplat = this._shape.length === 0 && this._storage.length === 2;
|
|
|
+ for (let i = 0; i < numElements; i++) {
|
|
|
+ const srcIdx = isSplat ? 0 : i * 2;
|
|
|
+ const real = typeof this._storage[srcIdx] === 'string' ? parseFloat(this._storage[srcIdx]) : this._storage[srcIdx];
|
|
|
+ const imag = typeof this._storage[srcIdx + 1] === 'string' ? parseFloat(this._storage[srcIdx + 1]) : this._storage[srcIdx + 1];
|
|
|
+ const offset = i * 2 * bytesPerFloat;
|
|
|
+ if (isComplex64) {
|
|
|
+ view.setFloat32(offset, real, true);
|
|
|
+ view.setFloat32(offset + 4, imag, true);
|
|
|
+ } else {
|
|
|
+ view.setFloat64(offset, real, true);
|
|
|
+ view.setFloat64(offset + 8, imag, true);
|
|
|
+ }
|
|
|
}
|
|
|
- }
|
|
|
- return new Uint8Array(buffer);
|
|
|
- };
|
|
|
- const isSplat = this._shape.length === 0 && numElements !== 0;
|
|
|
- if (isSplat) {
|
|
|
- // Complex splat should have exactly 2 elements (real, imag)
|
|
|
- if (this._storage.length === 2 && numElements <= maxSplatExpansion) {
|
|
|
- // Convert to binary format for proper complex handling
|
|
|
+ return new Uint8Array(buffer);
|
|
|
+ };
|
|
|
+ const isSplat = this._shape.length === 0 && numElements !== 0;
|
|
|
+ if (isSplat) {
|
|
|
+ // Complex splat should have exactly 2 elements (real, imag)
|
|
|
+ if (this._storage.length === 2 && numElements <= maxSplatExpansion) {
|
|
|
+ // Convert to binary format for proper complex handling
|
|
|
+ return convertComplexToBinary(typeStr, numElements);
|
|
|
+ }
|
|
|
+ } else if (numElements > 0 && numElements <= maxSplatExpansion) {
|
|
|
+ // Non-splat should have numElements * 2 values
|
|
|
return convertComplexToBinary(typeStr, numElements);
|
|
|
}
|
|
|
- } else if (numElements > 0 && numElements <= maxSplatExpansion) {
|
|
|
- // Non-splat should have numElements * 2 values
|
|
|
- return convertComplexToBinary(typeStr, numElements);
|
|
|
}
|
|
|
+ // For non-float complex types (like complex<i32>), return the storage array directly
|
|
|
}
|
|
|
// Handle splats for non-complex types
|
|
|
// Reference: if shape.empty() and storage has elements, it's a splat
|
|
|
@@ -3829,8 +3833,8 @@ mlir.Utility = class {
|
|
|
case 'ui16': return 'uint16';
|
|
|
case 'ui32': return 'uint32';
|
|
|
case 'ui64': return 'uint64';
|
|
|
- case 'complex<f32>': return 'complex64';
|
|
|
- case 'complex<f64>': return 'complex128';
|
|
|
+ case 'complex<f32>': return 'complex<float32>';
|
|
|
+ case 'complex<f64>': return 'complex<float64>';
|
|
|
case 'b8': return 'int8';
|
|
|
case 'unk': return 'unk'; // torch dialect unknown dtype
|
|
|
default:
|
|
|
@@ -3845,6 +3849,24 @@ mlir.Utility = class {
|
|
|
if (value && value.startsWith('memref<') && value.endsWith('>')) {
|
|
|
return value;
|
|
|
}
|
|
|
+ // Handle complex types with arbitrary element types (complex<i32>, complex<f16>, etc.)
|
|
|
+ if (value && value.startsWith('complex<') && value.endsWith('>')) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ // Handle arbitrary integer types (i3, i6, i9, si7, ui13, etc.)
|
|
|
+ if (value && /^[su]?i[0-9]+$/.test(value)) {
|
|
|
+ const match = value.match(/^(s|u)?i([0-9]+)$/);
|
|
|
+ if (match) {
|
|
|
+ const [, signed, widthStr] = match;
|
|
|
+ const width = parseInt(widthStr, 10);
|
|
|
+ if (signed === 'u') {
|
|
|
+ return `uint${width}`;
|
|
|
+ } else if (signed === 's') {
|
|
|
+ return `int${width}`;
|
|
|
+ }
|
|
|
+ return width === 1 ? 'boolean' : `int${width}`;
|
|
|
+ }
|
|
|
+ }
|
|
|
throw new mlir.Error(`Unknown data type '${value}'.`);
|
|
|
}
|
|
|
}
|
|
|
@@ -11093,6 +11115,23 @@ mlir.SPIRVDialect = class extends mlir.Dialect {
|
|
|
}
|
|
|
return true;
|
|
|
}
|
|
|
+ // Reference: SPIRVOps.cpp parseArithmeticExtendedBinaryOp
|
|
|
+ // Format: spirv.IAddCarry %op1, %op2 : !spirv.struct<(i32, i32)>
|
|
|
+ const arithmeticExtendedOps = new Set([
|
|
|
+ 'spirv.IAddCarry', 'spv.IAddCarry',
|
|
|
+ 'spirv.ISubBorrow', 'spv.ISubBorrow',
|
|
|
+ 'spirv.SMulExtended', 'spv.SMulExtended',
|
|
|
+ 'spirv.UMulExtended', 'spv.UMulExtended'
|
|
|
+ ]);
|
|
|
+ if (arithmeticExtendedOps.has(opName)) {
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ op.operands = parser.parseArguments();
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ const resultType = parser.parseType();
|
|
|
+ op.results.push({ type: resultType });
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
return super.parseOperation(parser, opName, op);
|
|
|
}
|
|
|
};
|