Lutz Roeder 2 mesi fa
parent
commit
41cd15d5c0
4 ha cambiato i file con 152 aggiunte e 39 eliminazioni
  1. 12 2
      source/base.js
  2. 121 28
      source/mlir.js
  3. 1 8
      test/models.json
  4. 18 1
      test/worker.js

+ 12 - 2
source/base.js

@@ -285,6 +285,16 @@ DataView.prototype.getComplexFloat16 = DataView.prototype.getComplexFloat16 || f
     return new base.Complex(real, imaginary);
 };
 
+DataView.prototype.setComplexFloat16 = DataView.prototype.setComplexFloat16 || function(byteOffset, value, littleEndian) {
+    if (littleEndian) {
+        this.setFloat16(byteOffset, value.real, littleEndian);
+        this.setFloat16(byteOffset + 2, value.imaginary, littleEndian);
+    } else {
+        this.setFloat16(byteOffset + 2, value.real, littleEndian);
+        this.setFloat16(byteOffset, value.imaginary, littleEndian);
+    }
+};
+
 DataView.prototype.getComplexFloat32 = DataView.prototype.getComplexFloat32 || function(byteOffset, littleEndian) {
     const real = littleEndian ? this.getFloat32(byteOffset, littleEndian) : this.getFloat32(byteOffset + 4, littleEndian);
     const imaginary = littleEndian ? this.getFloat32(byteOffset + 4, littleEndian) : this.getFloat32(byteOffset, littleEndian);
@@ -656,11 +666,11 @@ base.Tensor = class {
             ['qint8', 1], ['qint16', 2], ['qint32', 4],
             ['quint8', 1], ['quint16', 2], ['quint32', 4],
             ['xint8', 1],
-            ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8],
+            ['int8', 1], ['int16', 2], ['int32', 4], ['int64', 8], ['int128', 16],
             ['uint8', 1], ['uint16', 2], ['uint32', 4,], ['uint64', 8],
             ['float16', 2], ['float32', 4], ['float64', 8], ['bfloat16', 2],
             ['complex<float32>', 8], ['complex<float64>', 16], ['complex<int32>', 8],
-            ['float8e4m3fn', 1], ['float8e4m3fnuz', 1], ['float8e5m2', 1], ['float8e5m2fnuz', 1], ['float8e3m4', 1], ['float8e4m3', 1],
+            ['float8e4m3fn', 1], ['float8e4m3fnuz', 1], ['float8e5m2', 1], ['float8e5m2fnuz', 1], ['float8e4m3b11fnuz', 1], ['float8e3m4', 1], ['float8e4m3', 1], ['float4e2m1fn', 1], ['float6e2m3fn', 1], ['float6e3m2fn', 1], ['float8e8m0fnu', 1]
         ]);
     }
 

+ 121 - 28
source/mlir.js

@@ -184,12 +184,6 @@ mlir.Graph = class {
         this.nodes = [];
         this.metadata = [];
         const tensors = new Map();
-        const tensor = (arg) => {
-            if (!tensors.has(arg.name)) {
-                tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, arg.value));
-            }
-            return tensors.get(arg.name);
-        };
         // Handle function inputs/outputs if function_type exists
         if (func.attributes.has('function_type')) {
             const function_type = func.attributes.get('function_type');
@@ -245,7 +239,7 @@ mlir.Graph = class {
                     const operands = op.operands || [];
                     for (let i = 0; i < operands.length; i++) {
                         const input = op.operands[i];
-                        const inputName = input.name || (opMetadata && opMetadata.operands && opMetadata.operands[i] ? opMetadata.operands[i].name : null) || i.toString();
+                        const inputName = (opMetadata && opMetadata.operands && opMetadata.operands[i] ? opMetadata.operands[i].name : null) || input.name || i.toString();
                         if (input.value instanceof Uint8Array) {
                             operation.operands.push({
                                 name: inputName,
@@ -270,7 +264,7 @@ mlir.Graph = class {
                                 value: input.value
                             });
                         } else if (typeof input.value === 'string' && input.value) {
-                            const value = values.map(input);
+                            const value = values.map(input.value);
                             value.to.push(operation);
                             const args = [{ name: input.value, type: input.type }];
                             operation.operands.push({
@@ -294,7 +288,7 @@ mlir.Graph = class {
                         const value = values.map(output.value);
                         value.type = mlir.Utility.valueType(output.type);
                         value.from.push(operation);
-                        const outputName = output.name || (opMetadata && opMetadata.results && opMetadata.results[i] ? opMetadata.results[i].name : null) || i.toString();
+                        const outputName = (opMetadata && opMetadata.results && opMetadata.results[i] ? opMetadata.results[i].name : null) || output.name || i.toString();
                         operation.results.push({
                             name: outputName,
                             value: [value]
@@ -304,6 +298,39 @@ mlir.Graph = class {
                 }
             }
         }
+        // Build map of single-use constant tensors to convert to initializers
+        const constantMap = new Map();
+        const constantTypes = new Set([
+            'tosa.const', 'stablehlo.constant', 'arith.constant',
+            'mhlo.constant', 'torch.constant.tensor'
+        ]);
+        for (const op of operations) {
+            if (constantTypes.has(op.type) &&
+                op.operands.length === 0 &&
+                op.results.length === 1 &&
+                op.results[0].value.length === 1) {
+                const [result] = op.results[0].value;
+                if (result.to && result.to.length === 1) {
+                    const valueAttr = op.attributes.get('value');
+                    if ((valueAttr instanceof mlir.DenseElementsAttr || valueAttr instanceof mlir.DenseResourceElementsAttr) &&
+                        valueAttr.value !== null &&
+                        valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
+                        const type = mlir.Utility.valueType(valueAttr.type);
+                        if (type instanceof mlir.TensorType) {
+                            constantMap.set(result.name, new mlir.Tensor(type, valueAttr.value));
+                            op.delete = true;
+                        }
+                    }
+                }
+            }
+        }
+        const tensor = (arg) => {
+            if (!tensors.has(arg.name)) {
+                const initializer = constantMap.get(arg.name) || null;
+                tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, initializer));
+            }
+            return tensors.get(arg.name);
+        };
         for (const input of this.inputs) {
             for (const arg of input.value) {
                 if (!tensors.has(arg.name)) {
@@ -366,7 +393,7 @@ mlir.Argument = class {
                 case 'i32': case 'si32': this.type = 'int32'; break;
                 case 'i16': case 'si16': this.type = 'int16'; break;
                 case 'i8': case 'si8': this.type = 'int8'; break;
-                case 'i1': this.type = 'boolean'; break;
+                case 'i1': this.type = 'int1'; break;
                 case 'f32': case 'float32': this.type = 'float32'; break;
                 case 'f64': case 'float64': this.type = 'float64'; break;
                 case 'f16': this.type = 'float16'; break;
@@ -463,9 +490,12 @@ mlir.Node = class {
                         value = graph;
                         type = 'function';
                     }
-                } else if (attr instanceof mlir.DenseElementsAttr) {
+                } else if (attr instanceof mlir.DenseElementsAttr && attr.value !== null) {
                     value = new mlir.Tensor(mlir.Utility.valueType(attr.type), attr.value);
                     type = 'tensor';
+                } else if (attr instanceof mlir.DenseResourceElementsAttr) {
+                    value = new mlir.Tensor(mlir.Utility.valueType(attr.type), null);
+                    type = 'tensor';
                 } else if (attr instanceof mlir.ArrayAttr) {
                     value = attr.value;
                 } else if (attr instanceof mlir.DenseArrayAttr) {
@@ -1441,6 +1471,17 @@ mlir.Parser = class {
     parseGenericOperation() {
         const name = this.expect('string');
         const op = new mlir.OperationState(name);
+        const index = name.indexOf('.');
+        if (index !== -1) {
+            const dialectName = name.substring(0, index);
+            const dialect = this._context.getDialect(dialectName);
+            if (dialect) {
+                const opInfo = dialect.getOperation(name);
+                if (opInfo) {
+                    op.metadata = opInfo.metadata;
+                }
+            }
+        }
         return this.parseGenericOperationAfterOpName(op);
     }
 
@@ -1957,6 +1998,10 @@ mlir.Parser = class {
             this._token.text = prefix;
             return this.parseType();
         }
+        // Return as PrimitiveType for known primitive types
+        if (/^[su]?i[0-9]+$/.test(prefix) || /^[fb]f?[0-9]+/.test(prefix) || prefix === 'index') {
+            return new mlir.PrimitiveType(prefix);
+        }
         return prefix;
     }
 
@@ -2760,7 +2805,7 @@ mlir.Parser = class {
             this.expect(':');
             type = this.parseType();
         }
-        return { value: handle, type };
+        return new mlir.DenseResourceElementsAttr(handle, type);
     }
 
     parseDenseArrayAttr(/* attrType */) {
@@ -3049,31 +3094,60 @@ mlir.TensorLiteralParser = class {
         const elementType = type.getElementType ? type.getElementType() : null;
         const numElements = type.getNumElements ? type.getNumElements() : 0;
         const isComplex = elementType instanceof mlir.ComplexType;
+        const baseElemType = isComplex && elementType.elementType ? elementType.elementType : elementType;
+        // Determine conversion function once based on element type
+        let convert = (v) => v;
+        if (baseElemType) {
+            const typeStr = baseElemType.toString();
+            const intMatch = typeStr.match(/^[su]?i(\d+)$/);
+            if (intMatch) {
+                const bitWidth = parseInt(intMatch[1], 10);
+                if (bitWidth >= 64) {
+                    convert = (v) => typeof v === 'bigint' ? v : BigInt(v);
+                }
+                // For smaller ints, values are already numbers from tokenizer
+            } else if (typeStr === 'index') {
+                convert = (v) => typeof v === 'bigint' ? v : BigInt(v);
+            }
+            // For floats and other types, values are already correct from tokenizer
+        }
+        // Handle zero-element tensors (e.g., tensor<2x0x3xi4>)
+        if (numElements === 0) {
+            return [];
+        }
         // Limit splat expansion to avoid memory issues with huge tensors
         const maxSplatExpansion = 10000000;
         // Handle splats - Reference: if shape.empty() and storage has elements, it's a splat
         const isSplat = this._shape.length === 0 && this._storage.length > 0;
-        if (isSplat && numElements > 1 && numElements <= maxSplatExpansion) {
+        if (isSplat && numElements > 1) {
+            if (numElements > maxSplatExpansion) {
+                // Too large to expand - return null to indicate we can't provide the data
+                return null;
+            }
             if (isComplex && this._storage.length === 2) {
                 // Complex splat: storage has 2 elements (real, imag)
                 const result = [];
+                const real = convert(this._storage[0]);
+                const imag = convert(this._storage[1]);
                 for (let i = 0; i < numElements; i++) {
-                    result.push(new base.Complex(this._storage[0], this._storage[1]));
+                    result.push(new base.Complex(real, imag));
                 }
                 return result;
             }
             // Regular splat: replicate single value
-            return new Array(numElements).fill(this._storage[0]);
+            const converted = convert(this._storage[0]);
+            return new Array(numElements).fill(converted);
         }
         // Non-splat complex: convert pairs to base.Complex objects
         if (isComplex && Array.isArray(this._storage)) {
             const result = [];
             for (let i = 0; i < this._storage.length; i += 2) {
-                result.push(new base.Complex(this._storage[i], this._storage[i + 1]));
+                result.push(new base.Complex(convert(this._storage[i]), convert(this._storage[i + 1])));
             }
             return result;
         }
-        return this._storage;
+        // Convert all values
+        return this._storage.map(convert);
     }
 };
 
@@ -3301,7 +3375,7 @@ mlir.AttrTypeReader = class {
                 const typeIdx = reader.varint().toNumber();
                 const type = this.readType(typeIdx);
                 const handleIdx = reader.varint().toNumber();
-                return { name: 'dense_resource', value: `resource<${handleIdx}>`, type };
+                return new mlir.DenseResourceElementsAttr(`resource<${handleIdx}>`, type);
             }
             default: {
                 return { name: 'builtin', value: `<builtin code ${typeCode}>` };
@@ -3440,17 +3514,15 @@ mlir.AttrTypeReader = class {
                 const shape = this._readShape(reader);
                 const elementTypeIdx = reader.varint().toNumber();
                 const elementType = this.readType(elementTypeIdx);
-                const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
-                return new mlir.Type(`tensor<${shapeStr}x${elementType.name}>`);
+                return new mlir.RankedTensorType(shape, elementType, null);
             }
             case 14: { // RankedTensorTypeWithEncoding
                 const encodingAttrIdx = reader.varint().toNumber();
-                this.readAttribute(encodingAttrIdx); // encoding
+                const encoding = this.readAttribute(encodingAttrIdx);
                 const shape = this._readShape(reader);
                 const elementTypeIdx = reader.varint().toNumber();
                 const elementType = this.readType(elementTypeIdx);
-                const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
-                return new mlir.Type(`tensor<${shapeStr}x${elementType.name}>`);
+                return new mlir.RankedTensorType(shape, elementType, encoding);
             }
             case 15: { // TupleType
                 const numTypes = reader.varint().toNumber();
@@ -3475,8 +3547,7 @@ mlir.AttrTypeReader = class {
                 const shape = this._readShape(reader);
                 const elementTypeIdx = reader.varint().toNumber();
                 const elementType = this.readType(elementTypeIdx);
-                const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
-                return new mlir.Type(`vector<${shapeStr}x${elementType.name}>`);
+                return new mlir.VectorType(shape, elementType);
             }
             case 19: { // VectorTypeWithScalableDims - simplified
                 return new mlir.Type('vector<?>');
@@ -3900,6 +3971,14 @@ mlir.BytecodeReader = class {
         const kHasProperties = 0x40;
 
         const op = new mlir.OperationState(fullName);
+        const [dialectName] = fullName.split('.');
+        const dialect = this._context.getDialect(dialectName);
+        if (dialect) {
+            const opInfo = dialect.getOperation(fullName);
+            if (opInfo) {
+                op.metadata = opInfo.metadata;
+            }
+        }
 
         // Parse location
         const locIdx = reader.varint().toNumber();
@@ -4371,6 +4450,19 @@ mlir.DenseElementsAttr = class extends mlir.Attribute {
     }
 };
 
+mlir.DenseResourceElementsAttr = class extends mlir.Attribute {
+
+    constructor(handle, type) {
+        super();
+        this.handle = handle;
+        this.type = type;
+    }
+
+    toString() {
+        return `dense_resource<${this.handle}>`;
+    }
+};
+
 mlir.ArrayAttr = class extends mlir.Attribute {
 
     constructor(elements) {
@@ -4573,7 +4665,7 @@ mlir.Utility = class {
             value = value.toString();
         }
         switch (value) {
-            case 'index': return 'index';
+            case 'index': return 'int64';
             case 'f16': return 'float16';
             case 'f32': return 'float32';
             case 'f64': return 'float64';
@@ -4599,7 +4691,7 @@ mlir.Utility = class {
             case 'f8E8M0FNU': return 'float8e8m0fnu';
             case 'float8': return 'float8';
             case 'tf32': return 'tensorfloat32';
-            case 'i1': return 'boolean';
+            case 'i1': return 'int1';
             case 'i2': return 'int2';
             case 'i4': return 'int4';
             case 'i8': return 'int8';
@@ -4620,6 +4712,7 @@ mlir.Utility = class {
             case 'ui64': return 'uint64';
             case 'b8': return 'int8';
             case 'unk': return 'unk'; // torch dialect unknown dtype
+            case '!tf_type.string': return 'string';
             default:
                 if (value && value.startsWith('!')) {
                     return value;
@@ -4648,7 +4741,7 @@ mlir.Utility = class {
                         } else if (signed === 's') {
                             return `int${width}`;
                         }
-                        return width === 1 ? 'boolean' : `int${width}`;
+                        return `int${width}`;
                     }
                 }
                 throw new mlir.Error(`Unknown data type '${value}'.`);

+ 1 - 8
test/models.json

@@ -3319,7 +3319,7 @@
     "type":     "mlir",
     "target":   "broadcast_to_dynamic.mlir",
     "source":   "https://github.com/user-attachments/files/23221752/broadcast_to_dynamic.mlir.zip[broadcast_to_dynamic.mlir]",
-    "assert":   "model.functions[0].nodes[4].inputs.length == 0",
+    "assert":   "model.functions[0].nodes.length == 6",
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
@@ -3573,13 +3573,6 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
-  {
-    "type":     "mlir",
-    "target":   "versioned-op-2.0.mlirbc",
-    "source":   "https://github.com/user-attachments/files/17174958/versioned-op-2.0.mlirbc.zip[versioned-op-2.0.mlirbc]",
-    "format":   "MLIR Bytecode v1",
-    "link":     "https://github.com/lutzroeder/netron/issues/1044"
-  },
   {
     "type":     "mlir",
     "target":   "wcr.mlir",

+ 18 - 1
test/worker.js

@@ -362,16 +362,33 @@ export class Target {
                                 tensor.toString();
                                 if (this.tags.has('validation')) {
                                     const size = tensor.type.shape.dimensions.reduce((a, b) => a * b, 1);
-                                    if (tensor.type && tensor.type.dataType !== '?' && size < 8192) {
+                                    if (size < 8192 && tensor.type &&
+                                        tensor.type.dataType !== '?' &&
+                                        tensor.type.dataType !== 'string' &&
+                                        tensor.type.dataType !== 'int128' &&
+                                        tensor.type.dataType !== 'complex<int32>') {
                                         let data_type = '?';
                                         switch (tensor.type.dataType) {
                                             case 'boolean': data_type = 'bool'; break;
                                             case 'bfloat16': data_type = 'float32'; break;
+                                            case 'float4e2m1fn': data_type = 'float16'; break;
+                                            case 'float6e2m3fn': data_type = 'float16'; break;
+                                            case 'float6e3m2fn': data_type = 'float16'; break;
                                             case 'float8e5m2': data_type = 'float16'; break;
                                             case 'float8e5m2fnuz': data_type = 'float16'; break;
+                                            case 'float8e3m4': data_type = 'float16'; break;
+                                            case 'float8e4m3': data_type = 'float16'; break;
                                             case 'float8e4m3fn': data_type = 'float16'; break;
                                             case 'float8e4m3fnuz': data_type = 'float16'; break;
+                                            case 'float8e4m3b11fnuz': data_type = 'float16'; break;
+                                            case 'float8e8m0fnu': data_type = 'float16'; break;
+                                            case 'complex<float32>': data_type = 'complex64'; break;
+                                            case 'complex<float64>': data_type = 'complex128'; break;
+                                            case 'int1': data_type = 'int8'; break;
+                                            case 'int2': data_type = 'int8'; break;
                                             case 'int4': data_type = 'int8'; break;
+                                            case 'uint2': data_type = 'uint8'; break;
+                                            case 'uint4': data_type = 'uint8'; break;
                                             default: data_type = tensor.type.dataType; break;
                                         }
                                         Target.execution = Target.execution || new python.Execution();