Bladeren bron

Add MLIR test files (#1044)

Lutz Roeder 2 maanden geleden
bovenliggende
commit
debc283f72
3 gewijzigde bestanden met toevoegingen van 277 en 203 verwijderingen
  1. 24 9
      source/mlir-metadata.json
  2. 222 194
      source/mlir.js
  3. 31 0
      test/models.json

+ 24 - 9
source/mlir-metadata.json

@@ -38653,12 +38653,12 @@
    "description": "Broadcasts memory load of bits of data for a cluster of workgroups.\n\n      Available on gfx1250+.",
     "operands": [
       { "name": "globalPtr", "type": "ROCDLGlobalBuffer" },
-      { "name": "ldsPtr", "type": "ROCDLBufferLDS" }
+      { "name": "ldsPtr", "type": "ROCDLBufferLDS" },
+      { "name": "mask", "type": "I32" }
     ],
     "attributes": [
       { "name": "offset", "type": "I32Attr" },
       { "name": "cpol", "type": "I32Attr" },
-      { "name": "mask", "type": "I32Attr" },
       { "name": "alias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "noalias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "tbaa", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_TBAATagAttr>>" }
@@ -38670,12 +38670,12 @@
    "description": "Broadcasts memory load of bits of data for a cluster of workgroups.\n\n      Available on gfx1250+.",
     "operands": [
       { "name": "globalPtr", "type": "ROCDLGlobalBuffer" },
-      { "name": "ldsPtr", "type": "ROCDLBufferLDS" }
+      { "name": "ldsPtr", "type": "ROCDLBufferLDS" },
+      { "name": "mask", "type": "I32" }
     ],
     "attributes": [
       { "name": "offset", "type": "I32Attr" },
       { "name": "cpol", "type": "I32Attr" },
-      { "name": "mask", "type": "I32Attr" },
       { "name": "alias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "noalias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "tbaa", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_TBAATagAttr>>" }
@@ -38687,12 +38687,12 @@
    "description": "Broadcasts memory load of bits of data for a cluster of workgroups.\n\n      Available on gfx1250+.",
     "operands": [
       { "name": "globalPtr", "type": "ROCDLGlobalBuffer" },
-      { "name": "ldsPtr", "type": "ROCDLBufferLDS" }
+      { "name": "ldsPtr", "type": "ROCDLBufferLDS" },
+      { "name": "mask", "type": "I32" }
     ],
     "attributes": [
       { "name": "offset", "type": "I32Attr" },
       { "name": "cpol", "type": "I32Attr" },
-      { "name": "mask", "type": "I32Attr" },
       { "name": "alias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "noalias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "tbaa", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_TBAATagAttr>>" }
@@ -38704,12 +38704,12 @@
    "description": "Broadcasts memory load of bits of data for a cluster of workgroups.\n\n      Available on gfx1250+.",
     "operands": [
       { "name": "globalPtr", "type": "ROCDLGlobalBuffer" },
-      { "name": "ldsPtr", "type": "ROCDLBufferLDS" }
+      { "name": "ldsPtr", "type": "ROCDLBufferLDS" },
+      { "name": "mask", "type": "I32" }
     ],
     "attributes": [
       { "name": "offset", "type": "I32Attr" },
       { "name": "cpol", "type": "I32Attr" },
-      { "name": "mask", "type": "I32Attr" },
       { "name": "alias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "noalias_scopes", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_AliasScopeAttr>>" },
       { "name": "tbaa", "type": "OptionalAttr<TypedArrayAttrBase<LLVM_TBAATagAttr>>" }
@@ -75729,7 +75729,7 @@
     "description": "Looks up ids in a list of embedding tensors.",
     "operands": [
       { "name": "lookup", "type": "TFL_TensorOf<[I32]>" },
-      { "name": "value", "type": "TFL_TensorOf<[F32, I8, UI8, QI8, QUI8, QI4]>" }
+      { "name": "value", "type": "TFL_TensorOf<[F32, I8, UI8, QI8, QUI8, QI4, QI2]>" }
     ],
     "results": [
       { "name": "output", "type": "TFL_TensorOf<[F32, I8, UI8]>" }
@@ -98075,6 +98075,21 @@
     ],
     "assemblyFormat": "operands attr-dict `:` functional-type(operands, results)"
   },
+  {
+    "name": "tosa.dim",
+    "summary": "Extract size of dimension from input tensor.",
+    "description": "Returns a length 1 shape_t of the size of the input tensor for the given axis.",
+    "operands": [
+      { "name": "input1", "type": "Tosa_TensorAtLeast1D" }
+    ],
+    "results": [
+      { "name": "output", "type": "Tosa_Shape" }
+    ],
+    "attributes": [
+      { "name": "axis", "type": "I32Attr" }
+    ],
+    "assemblyFormat": "operands attr-dict `:` functional-type(operands, results)"
+  },
   {
     "name": "tosa.div_ceil_shape",
     "summary": "Elementwise ceiling divide of shapes.",

+ 222 - 194
source/mlir.js

@@ -140,7 +140,7 @@ mlir.Model = class {
                 functions.set(name, { func, prefix, base, module });
             }
         }
-        const context = new _.Context(metadata, functions);
+        const context = new mlir.Context(metadata, functions);
         for (const [name, info] of functions) {
             const graph = context.graph(info.func, name);
             this.functions.push(graph);
@@ -194,7 +194,7 @@ mlir.Graph = class {
                 const input = inputs[i];
                 // args[i] is an _.Value with .name set by parseRegion
                 const name = args[i] && args[i].name ? args[i].name : `%arg${i}`;
-                const type = _.Utility.valueType(input.type || input);
+                const type = mlir.Utility.valueType(input.type || input);
                 const value = new mlir.Value(name, type, '', null);
                 const argument = new mlir.Argument(name, [value]);
                 this.inputs.push(argument);
@@ -202,7 +202,7 @@ mlir.Graph = class {
             for (let i = 0; i < results.length; i++) {
                 const output = results[i];
                 const name = output.value || i.toString();
-                const type = _.Utility.valueType(output.type);
+                const type = mlir.Utility.valueType(output.type);
                 const valueName = output.value || output.name || `%result${i}`;
                 const value = new mlir.Value(valueName, type, '', null);
                 const argument = new mlir.Argument(name, [value]);
@@ -248,9 +248,10 @@ mlir.Graph = class {
                         const input = op.operands[i];
                         // Determine operand name: use metadata if available, or variadic name if past metadata bounds
                         let inputName = null;
+                        const isVariadicOverflow = lastVariadicIndex >= 0 && i >= lastVariadicIndex;
                         if (opMetadata && opMetadata.operands && opMetadata.operands[i]) {
                             inputName = opMetadata.operands[i].name;
-                        } else if (lastVariadicIndex >= 0 && i >= lastVariadicIndex) {
+                        } else if (isVariadicOverflow) {
                             // Operand index exceeds metadata, use last variadic operand name
                             inputName = lastVariadicName;
                         } else {
@@ -261,8 +262,13 @@ mlir.Graph = class {
                         }
                         const value = values.map(input.name);
                         value.to.push(operation);
-                        const args = [{ name: input.name, type: input.type }];
-                        operation.operands.push({ name: inputName, value: args });
+                        const arg = { name: input.name, type: input.type };
+                        // Group variadic operands into single argument with multiple values
+                        if (isVariadicOverflow && operation.operands.length > 0 && operation.operands[operation.operands.length - 1].name === inputName) {
+                            operation.operands[operation.operands.length - 1].value.push(arg);
+                        } else {
+                            operation.operands.push({ name: inputName, value: [arg] });
+                        }
                     }
                     const results = op.results;
                     // Find the last variadic result in metadata (if any) for grouping
@@ -285,7 +291,7 @@ mlir.Graph = class {
                             continue;
                         }
                         const value = values.map(output.name);
-                        value.type = _.Utility.valueType(output.type);
+                        value.type = mlir.Utility.valueType(output.type);
                         value.from.push(operation);
                         // Determine result name: use metadata if available, or variadic name if past metadata bounds
                         let outputName = null;
@@ -320,6 +326,7 @@ mlir.Graph = class {
         for (const op of operations) {
             if (constantTypes.has(op.type) &&
                 op.operands.length === 0 &&
+                op.attributes.size === 1 &&
                 op.results.length === 1 &&
                 op.results[0].value.length === 1) {
                 const [result] = op.results[0].value;
@@ -328,7 +335,7 @@ mlir.Graph = class {
                     if ((valueAttr instanceof _.DenseElementsAttr || valueAttr instanceof _.DenseResourceElementsAttr) &&
                         valueAttr.value !== null &&
                         valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
-                        const type = _.Utility.valueType(valueAttr.type);
+                        const type = mlir.Utility.valueType(valueAttr.type);
                         if (type instanceof mlir.TensorType) {
                             constantMap.set(result.name, new mlir.Tensor(type, valueAttr.value));
                             op.delete = true;
@@ -340,7 +347,13 @@ mlir.Graph = class {
         const tensor = (arg) => {
             if (!tensors.has(arg.name)) {
                 const initializer = constantMap.get(arg.name) || null;
-                tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, initializer));
+                let type = null;
+                if (arg.type instanceof mlir.TensorType) {
+                    type = arg.type;
+                } else if (arg.type) {
+                    type = mlir.Utility.valueType(arg.type);
+                }
+                tensors.set(arg.name, new mlir.Value(arg.name, type, null, initializer));
             }
             return tensors.get(arg.name);
         };
@@ -360,7 +373,7 @@ mlir.Graph = class {
                     const [returnValue] = operand.value;
                     if (returnValue && typeof returnValue.name === 'string' && returnValue.name.startsWith('%')) {
                         const output = this.outputs[i];
-                        const returnType = _.Utility.valueType(returnValue.type);
+                        const returnType = mlir.Utility.valueType(returnValue.type);
                         output.value[0] = new mlir.Value(returnValue.name, returnType, '', null);
                     }
                 }
@@ -506,7 +519,7 @@ mlir.Node = class {
                 if (input.type) {
                     const typeStr = input.type instanceof _.Type ? input.type.toString() : input.type;
                     if (typeStr.startsWith('tensor<')) {
-                        const type = _.Utility.valueType(typeStr);
+                        const type = mlir.Utility.valueType(typeStr);
                         const value = new mlir.Tensor(type, input.value);
                         argument = new mlir.Argument(input.name, value, 'tensor');
                     } else {
@@ -548,10 +561,10 @@ mlir.Node = class {
                         type = 'function';
                     }
                 } else if (attr instanceof _.DenseElementsAttr && attr.value !== null) {
-                    value = new mlir.Tensor(_.Utility.valueType(attr.type), attr.value);
+                    value = new mlir.Tensor(mlir.Utility.valueType(attr.type), attr.value);
                     type = 'tensor';
                 } else if (attr instanceof _.DenseResourceElementsAttr) {
-                    value = new mlir.Tensor(_.Utility.valueType(attr.type), null);
+                    value = new mlir.Tensor(mlir.Utility.valueType(attr.type), null);
                     type = 'tensor';
                 } else if (attr instanceof _.ArrayAttr || attr instanceof _.DenseArrayAttr) {
                     value = attr.value;
@@ -568,8 +581,9 @@ mlir.Node = class {
                 const region = op.regions[i];
                 if (region.blocks && region.blocks.length > 0) {
                     const name = (opMetadata.regions && opMetadata.regions[i] ? opMetadata.regions[i].name : null) || i.toString();
+                    const blockName = region.blocks[0].name || '';
                     const func = { name: '', attributes: new Map(), regions: [region] };
-                    const graph = new mlir.Graph(metadata, func, context, '');
+                    const graph = new mlir.Graph(metadata, func, context, blockName);
                     const argument = new mlir.Argument(name, graph, 'graph');
                     this.blocks.push(argument);
                 }
@@ -590,7 +604,7 @@ mlir.Tensor = class {
 mlir.TensorType = class {
 
     constructor(dataType, shape) {
-        this.dataType = _.Utility.dataType(dataType); // string
+        this.dataType = mlir.Utility.dataType(dataType); // string
         this.shape = shape || new mlir.TensorShape([]);  // mlir.TensorShape
     }
 
@@ -613,7 +627,7 @@ mlir.TensorShape = class {
     }
 };
 
-_.Context = class {
+mlir.Context = class {
 
     constructor(metadata, functions) {
         this._metadata = metadata;
@@ -662,6 +676,179 @@ _.Context = class {
     }
 };
 
+mlir.Utility = class {
+
+    static dataType(value) {
+        if (value instanceof _.ComplexType) {
+            const elementType = mlir.Utility.dataType(value.elementType);
+            return `complex<${elementType}>`;
+        }
+        if (value instanceof _.Type) {
+            value = value.toString();
+        }
+        switch (value) {
+            case 'index': return 'int64';
+            case 'f16': return 'float16';
+            case 'f32': return 'float32';
+            case 'f64': return 'float64';
+            case 'f80': return 'float80';
+            case 'f128': return 'float128';
+            case 'bf16': return 'bfloat16';
+            case 'fp8': return 'float8';
+            case 'fp8e4m3': return 'float8e4m3';
+            case 'fp8_e4m3': return 'float8e4m3';
+            case 'fp8e4m3fn': return 'float8e4m3fn';
+            case 'fp8e5m2': return 'float8e5m2';
+            case 'fp8_e5m2': return 'float8e5m2';
+            case 'f4E2M1FN': return 'float4e2m1fn';
+            case 'f6E2M3FN': return 'float6e2m3fn';
+            case 'f6E3M2FN': return 'float6e3m2fn';
+            case 'f8E3M4': return 'float8e3m4';
+            case 'f8E4M3': return 'float8e4m3';
+            case 'f8E4M3B11FNUZ': return 'float8e4m3b11fnuz';
+            case 'f8E4M3FN': return 'float8e4m3fn';
+            case 'f8E4M3FNUZ': return 'float8e4m3fnuz';
+            case 'f8E5M2': return 'float8e5m2';
+            case 'f8E5M2FNUZ': return 'float8e5m2fnuz';
+            case 'f8E8M0FNU': return 'float8e8m0fnu';
+            case 'float8': return 'float8';
+            case 'tf32': return 'tensorfloat32';
+            case 'i1': return 'int1';
+            case 'i2': return 'int2';
+            case 'i4': return 'int4';
+            case 'i8': return 'int8';
+            case 'i16': return 'int16';
+            case 'i32': return 'int32';
+            case 'i48': return 'int48';
+            case 'i64': return 'int64';
+            case 'si8': return 'int8';
+            case 'si16': return 'int16';
+            case 'si32': return 'int32';
+            case 'si64': return 'int64';
+            case 'ui1': return 'uint1';
+            case 'ui2': return 'uint2';
+            case 'ui4': return 'uint4';
+            case 'ui8': return 'uint8';
+            case 'ui16': return 'uint16';
+            case 'ui32': return 'uint32';
+            case 'ui64': return 'uint64';
+            case 'b8': return 'int8';
+            case 'unk': return 'unk'; // torch dialect unknown dtype
+            case '!tf_type.string': return 'string';
+            default:
+                if (value && value.startsWith('!')) {
+                    return value;
+                }
+                if (value && value.startsWith('vector<') && value.endsWith('>')) {
+                    return value;
+                }
+                if (value && value.startsWith('memref<') && value.endsWith('>')) {
+                    return value;
+                }
+                if (value && value.startsWith('tuple<') && value.endsWith('>')) {
+                    return value;
+                }
+                if (value && value.startsWith('complex<') && value.endsWith('>')) {
+                    const elementTypeStr = value.substring(8, value.length - 1);
+                    const convertedElementType = mlir.Utility.dataType(elementTypeStr);
+                    return `complex<${convertedElementType}>`;
+                }
+                if (value && /^[su]?i[0-9]+$/.test(value)) {
+                    const match = value.match(/^(s|u)?i([0-9]+)$/);
+                    if (match) {
+                        const [, signed, widthStr] = match;
+                        const width = parseInt(widthStr, 10);
+                        if (signed === 'u') {
+                            return `uint${width}`;
+                        } else if (signed === 's') {
+                            return `int${width}`;
+                        }
+                        return `int${width}`;
+                    }
+                }
+                throw new mlir.Error(`Unknown data type '${value}'.`);
+        }
+    }
+
+    static valueType(type) {
+        if (type === undefined) {
+            return null;
+        }
+        const typeStr = type instanceof _.Type ? type.toString() : type;
+        if (typeStr.startsWith('!') && !typeStr.startsWith('!torch.vtensor<')) {
+            return typeStr;
+        }
+        if (typeStr.startsWith('tensor<') && typeStr.endsWith('>')) {
+            const spec = typeStr.substring(7, typeStr.length - 1).trim();
+            if (spec.startsWith('!')) {
+                return mlir.Utility.valueType(spec);
+            }
+            let i = 0;
+            const shape = [];
+            while (i < spec.length) {
+                if (spec[i] === '?' || spec[i] === '*') {
+                    shape.push('?');
+                    i++;
+                } else if (/[0-9]/.test(spec[i])) {
+                    let numStr = '';
+                    while (i < spec.length && /[0-9]/.test(spec[i])) {
+                        numStr += spec[i];
+                        i++;
+                    }
+                    const dim = parseInt(numStr, 10);
+                    if (isNaN(dim)) {
+                        shape.push('?');
+                    } else {
+                        shape.push(dim);
+                    }
+                } else {
+                    break;
+                }
+                if (i < spec.length && spec[i] === 'x') {
+                    i++;
+                } else {
+                    break;
+                }
+            }
+            let dataType = spec.substring(i);
+            const encodingIndex = dataType.indexOf(',');
+            if (encodingIndex !== -1) {
+                dataType = dataType.substring(0, encodingIndex).trim();
+            }
+            return new mlir.TensorType(dataType, new mlir.TensorShape(shape));
+        }
+        if (typeStr.startsWith('!torch.vtensor<') && typeStr.endsWith('>')) {
+            const spec = typeStr.substring(15, typeStr.length - 1);
+            let shape = null;
+            let dataType = null;
+            if (spec.startsWith('[')) {
+                const bracketEnd = spec.indexOf(']');
+                const shapeStr = spec.substring(0, bracketEnd + 1);
+                const jsonStr = shapeStr.replace(/\?/g, '"?"');
+                shape = JSON.parse(jsonStr);
+                const rest = spec.substring(bracketEnd + 1);
+                if (rest.startsWith(',')) {
+                    const parts = rest.substring(1).split(',');
+                    dataType = parts[0].trim();
+                }
+            } else if (spec.startsWith('*')) {
+                if (spec.includes(',')) {
+                    const parts = spec.split(',');
+                    dataType = parts[1].trim();
+                }
+            } else {
+                const parts = spec.split(',');
+                dataType = parts[0].trim();
+            }
+            return new mlir.TensorType(dataType, shape ? new mlir.TensorShape(shape) : null);
+        }
+        if (typeStr.startsWith('tuple<') && typeStr.endsWith('>')) {
+            return typeStr;
+        }
+        return typeStr;
+    }
+};
+
 _.OperationState = class {
 
     constructor(name) {
@@ -5003,179 +5190,6 @@ _.BufferReader = class {
     }
 };
 
-_.Utility = class {
-
-    static dataType(value) {
-        if (value instanceof _.ComplexType) {
-            const elementType = _.Utility.dataType(value.elementType);
-            return `complex<${elementType}>`;
-        }
-        if (value instanceof _.Type) {
-            value = value.toString();
-        }
-        switch (value) {
-            case 'index': return 'int64';
-            case 'f16': return 'float16';
-            case 'f32': return 'float32';
-            case 'f64': return 'float64';
-            case 'f80': return 'float80';
-            case 'f128': return 'float128';
-            case 'bf16': return 'bfloat16';
-            case 'fp8': return 'float8';
-            case 'fp8e4m3': return 'float8e4m3';
-            case 'fp8_e4m3': return 'float8e4m3';
-            case 'fp8e4m3fn': return 'float8e4m3fn';
-            case 'fp8e5m2': return 'float8e5m2';
-            case 'fp8_e5m2': return 'float8e5m2';
-            case 'f4E2M1FN': return 'float4e2m1fn';
-            case 'f6E2M3FN': return 'float6e2m3fn';
-            case 'f6E3M2FN': return 'float6e3m2fn';
-            case 'f8E3M4': return 'float8e3m4';
-            case 'f8E4M3': return 'float8e4m3';
-            case 'f8E4M3B11FNUZ': return 'float8e4m3b11fnuz';
-            case 'f8E4M3FN': return 'float8e4m3fn';
-            case 'f8E4M3FNUZ': return 'float8e4m3fnuz';
-            case 'f8E5M2': return 'float8e5m2';
-            case 'f8E5M2FNUZ': return 'float8e5m2fnuz';
-            case 'f8E8M0FNU': return 'float8e8m0fnu';
-            case 'float8': return 'float8';
-            case 'tf32': return 'tensorfloat32';
-            case 'i1': return 'int1';
-            case 'i2': return 'int2';
-            case 'i4': return 'int4';
-            case 'i8': return 'int8';
-            case 'i16': return 'int16';
-            case 'i32': return 'int32';
-            case 'i48': return 'int48';
-            case 'i64': return 'int64';
-            case 'si8': return 'int8';
-            case 'si16': return 'int16';
-            case 'si32': return 'int32';
-            case 'si64': return 'int64';
-            case 'ui1': return 'uint1';
-            case 'ui2': return 'uint2';
-            case 'ui4': return 'uint4';
-            case 'ui8': return 'uint8';
-            case 'ui16': return 'uint16';
-            case 'ui32': return 'uint32';
-            case 'ui64': return 'uint64';
-            case 'b8': return 'int8';
-            case 'unk': return 'unk'; // torch dialect unknown dtype
-            case '!tf_type.string': return 'string';
-            default:
-                if (value && value.startsWith('!')) {
-                    return value;
-                }
-                if (value && value.startsWith('vector<') && value.endsWith('>')) {
-                    return value;
-                }
-                if (value && value.startsWith('memref<') && value.endsWith('>')) {
-                    return value;
-                }
-                if (value && value.startsWith('tuple<') && value.endsWith('>')) {
-                    return value;
-                }
-                if (value && value.startsWith('complex<') && value.endsWith('>')) {
-                    const elementTypeStr = value.substring(8, value.length - 1);
-                    const convertedElementType = _.Utility.dataType(elementTypeStr);
-                    return `complex<${convertedElementType}>`;
-                }
-                if (value && /^[su]?i[0-9]+$/.test(value)) {
-                    const match = value.match(/^(s|u)?i([0-9]+)$/);
-                    if (match) {
-                        const [, signed, widthStr] = match;
-                        const width = parseInt(widthStr, 10);
-                        if (signed === 'u') {
-                            return `uint${width}`;
-                        } else if (signed === 's') {
-                            return `int${width}`;
-                        }
-                        return `int${width}`;
-                    }
-                }
-                throw new mlir.Error(`Unknown data type '${value}'.`);
-        }
-    }
-
-    static valueType(type) {
-        if (type === undefined) {
-            return null;
-        }
-        const typeStr = type instanceof _.Type ? type.toString() : type;
-        if (typeStr.startsWith('!') && !typeStr.startsWith('!torch.vtensor<')) {
-            return typeStr;
-        }
-        if (typeStr.startsWith('tensor<') && typeStr.endsWith('>')) {
-            const spec = typeStr.substring(7, typeStr.length - 1).trim();
-            if (spec.startsWith('!')) {
-                return _.Utility.valueType(spec);
-            }
-            let i = 0;
-            const shape = [];
-            while (i < spec.length) {
-                if (spec[i] === '?' || spec[i] === '*') {
-                    shape.push('?');
-                    i++;
-                } else if (/[0-9]/.test(spec[i])) {
-                    let numStr = '';
-                    while (i < spec.length && /[0-9]/.test(spec[i])) {
-                        numStr += spec[i];
-                        i++;
-                    }
-                    const dim = parseInt(numStr, 10);
-                    if (isNaN(dim)) {
-                        shape.push('?');
-                    } else {
-                        shape.push(dim);
-                    }
-                } else {
-                    break;
-                }
-                if (i < spec.length && spec[i] === 'x') {
-                    i++;
-                } else {
-                    break;
-                }
-            }
-            let dataType = spec.substring(i);
-            const encodingIndex = dataType.indexOf(',');
-            if (encodingIndex !== -1) {
-                dataType = dataType.substring(0, encodingIndex).trim();
-            }
-            return new mlir.TensorType(dataType, new mlir.TensorShape(shape));
-        }
-        if (typeStr.startsWith('!torch.vtensor<') && typeStr.endsWith('>')) {
-            const spec = typeStr.substring(15, typeStr.length - 1);
-            let shape = null;
-            let dataType = null;
-            if (spec.startsWith('[')) {
-                const bracketEnd = spec.indexOf(']');
-                const shapeStr = spec.substring(0, bracketEnd + 1);
-                const jsonStr = shapeStr.replace(/\?/g, '"?"');
-                shape = JSON.parse(jsonStr);
-                const rest = spec.substring(bracketEnd + 1);
-                if (rest.startsWith(',')) {
-                    const parts = rest.substring(1).split(',');
-                    dataType = parts[0].trim();
-                }
-            } else if (spec.startsWith('*')) {
-                if (spec.includes(',')) {
-                    const parts = spec.split(',');
-                    dataType = parts[1].trim();
-                }
-            } else {
-                const parts = spec.split(',');
-                dataType = parts[0].trim();
-            }
-            return new mlir.TensorType(dataType, shape ? new mlir.TensorShape(shape) : null);
-        }
-        if (typeStr.startsWith('tuple<') && typeStr.endsWith('>')) {
-            return typeStr;
-        }
-        return typeStr;
-    }
-};
-
 // Dialect Plugin System
 
 _.AssemblyFormatParser = class {
@@ -7549,15 +7563,19 @@ _.HLODialect = class extends _.Dialect {
         resultTypes.push(type);
     }
 
-    // custom<SelectOpType>(type($operands), type($result))
-    _parseSelectOpType(parser, op, operandTypes, resultTypes) {
+    // custom<SelectOpType>(type($pred), type($on_true), type($on_false), type($result))
+    _parseSelectOpType(parser, op, predTypes, onTrueTypes, onFalseTypes, resultTypes) {
         const firstType = parser.parseType();
         if (parser.accept(',')) {
             const secondType = parser.parseType();
-            operandTypes.push(firstType);
+            predTypes.push(firstType);
+            onTrueTypes.push(secondType);
+            onFalseTypes.push(secondType);
             resultTypes.push(secondType);
         } else {
-            operandTypes.push(firstType);
+            predTypes.push(firstType);
+            onTrueTypes.push(firstType);
+            onFalseTypes.push(firstType);
             resultTypes.push(firstType);
         }
     }
@@ -8039,7 +8057,17 @@ _.StableHLODialect = class extends _.HLODialect {
         if (parser.match('{')) {
             parser.parseAttributeDict(op.attributes);
         }
-        parser.resolveOperands(unresolvedOperands, parser.parseOptionalColonTypeList(), op.operands);
+        // Handle `: (operand-types) -> result-types` functional type format
+        if (parser.accept(':')) {
+            const type = parser.parseType();
+            if (type instanceof _.FunctionType) {
+                parser.resolveOperands(unresolvedOperands, type.inputs, op.operands);
+                op.addTypes(type.results);
+            } else {
+                const types = Array.isArray(type) ? type : [type];
+                parser.resolveOperands(unresolvedOperands, types, op.operands);
+            }
+        }
         if (parser.accept('->') || parser.accept('id', 'to')) {
             const types = parser.parseFunctionResultTypes();
             op.addTypes(types);

+ 31 - 0
test/models.json

@@ -3332,6 +3332,13 @@
     "assert":   "model.functions[0].nodes[2].outputs.length == 1",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "collatz.mlir",
+    "source":   "https://github.com/user-attachments/files/24442934/collatz.mlir.zip[collatz.mlir]",
+    "format":   "MLIR",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "conv-conversion.mlir",
@@ -3370,6 +3377,14 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "edge_detection.mlir",
+    "source":   "https://github.com/user-attachments/files/24442856/edge_detection.mlir.zip[edge_detection.mlir]",
+    "format":   "MLIR",
+    "assert":   "model.functions[0].nodes[2].attributes[0].value == '[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]'",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "embedding_bm1684x_f16_tpu.mlir",
@@ -3442,6 +3457,14 @@
     "assert":   "model.modules[0].nodes[0].attributes[0].value.nodes.length == 6",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "mnist_stablehlo.mlir",
+    "source":   "https://github.com/user-attachments/files/24441726/mnist_stablehlo.mlir.zip",
+    "format":   "MLIR",
+    "assert":   "model.functions[0].nodes.length == 23",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "model.mlirbc",
@@ -3477,6 +3500,14 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "lstm.mlir",
+    "source":   "https://github.com/user-attachments/files/24440170/lstm.mlir.zip[lstm.mlir]",
+    "format":   "MLIR",
+    "assert":   "model.functions[2].nodes[6].inputs[0].value[4].initializer.type.dataType == 'float32'",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "sample.mlir",