Explorar o código

Add MLIR test file (#1044)

Lutz Roeder hai 2 meses
pai
achega
3cd3e06766
Modificáronse 4 ficheiros con 81 adicións e 48 borrados
  1. 71 40
      source/mlir.js
  2. 1 1
      source/python.js
  3. 8 7
      test/models.json
  4. 1 0
      test/worker.js

+ 71 - 40
source/mlir.js

@@ -96,10 +96,12 @@ mlir.Model = class {
                 if (isFunc(op.name)) {
                     funcs.push(op);
                 } else if (isModule(op.name)) {
-                    const modulePath = [...path, `$${identifier++}`];
+                    let name = op.getAttr('sym_name');
+                    name = name ? name.value : `$${identifier++}`;
+                    const modulePath = [...path, name];
                     for (const region of op.regions || []) {
                         for (const blk of region.blocks || []) {
-                            collectModules(blk.operations || [], modulePath, op.attributes);
+                            collectModules(blk.operations || [], modulePath, op.getAttrDictionary());
                         }
                     }
                 } else {
@@ -126,13 +128,13 @@ mlir.Model = class {
             return '';
         };
         const functions = new Map();
-        let funcIndex = 0;
+        let identifier = 0;
         for (const module of modules) {
             const prefix = formatPrefix(module.path, module.symName);
             for (const func of module.funcs) {
-                const sym_name = func.attributes.get('sym_name');
-                const base = sym_name ? sym_name.value : `$${funcIndex}`;
-                funcIndex++;
+                const sym_name = func.getAttr('sym_name');
+                const base = sym_name ? sym_name.value : `$${identifier}`;
+                identifier++;
                 const name = prefix ? `${prefix}::@${base}` : `@${base}`;
                 functions.set(name, { func, prefix, base, module });
             }
@@ -145,11 +147,9 @@ mlir.Model = class {
         for (const module of modules) {
             if (module.ops.length > 0 || module.attributes.size > 0) {
                 const name = formatPrefix(module.path, module.symName) || '';
-                const op = {
-                    name: 'builtin.module',
-                    attributes: module.attributes || [],
-                    regions: [{ blocks: [{ operations: module.ops, arguments: [] }] }]
-                };
+                const op = new mlir.Operation('builtin.module');
+                op.attributes = module.attributes;
+                op.regions = [{ blocks: [{ operations: module.ops, arguments: [] }] }];
                 const graph = context.graph(op, name);
                 this.modules.push(graph);
             }
@@ -227,7 +227,7 @@ mlir.Graph = class {
                     const operation = {
                         type: op.kind || op.name,
                         identifier: op.name,
-                        attributes: op.attributes,
+                        attributes: op.getAttrDictionary(),
                         operands: [],
                         results: [],
                         regions: op.regions || [],
@@ -309,7 +309,7 @@ mlir.Graph = class {
                 op.results[0].value.length === 1) {
                 const [result] = op.results[0].value;
                 if (result.to && result.to.length === 1) {
-                    const valueAttr = op.attributes.get('value');
+                    const valueAttr = op.attributes.get('value') || op.attributes.get('values');
                     if ((valueAttr instanceof mlir.DenseElementsAttr || valueAttr instanceof mlir.DenseResourceElementsAttr) &&
                         valueAttr.value !== null &&
                         valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
@@ -345,7 +345,8 @@ mlir.Graph = class {
                     const [returnValue] = operand.value;
                     if (returnValue && typeof returnValue.name === 'string' && returnValue.name.startsWith('%')) {
                         const output = this.outputs[i];
-                        output.value[0] = new mlir.Value(returnValue.name, returnValue.type || output.value[0].type, '', null);
+                        const returnType = returnValue.type ? mlir.Utility.valueType(returnValue.type) : output.value[0].type;
+                        output.value[0] = new mlir.Value(returnValue.name, returnType, '', null);
                     }
                 }
             }
@@ -611,6 +612,7 @@ mlir.Operation = class {
         this.operands = [];
         this.regions = [];
         this.results = [];
+        this.propertiesAttr = null;
     }
 
     addAttribute(name, value) {
@@ -622,6 +624,27 @@ mlir.Operation = class {
         this.regions.push(region);
         return region;
     }
+
+    getAttr(name) {
+        if (this.propertiesAttr instanceof mlir.DictionaryAttr) {
+            const value = this.propertiesAttr.get(name);
+            if (value !== undefined) {
+                return value;
+            }
+        }
+        return this.attributes.get(name);
+    }
+
+    getAttrDictionary() {
+        if (this.propertiesAttr instanceof mlir.DictionaryAttr) {
+            const result = new Map(this.attributes);
+            for (const [name, value] of this.propertiesAttr.value) {
+                result.set(name, value);
+            }
+            return result;
+        }
+        return this.attributes;
+    }
 };
 
 mlir.Token = class {
@@ -1379,7 +1402,7 @@ mlir.Parser = class {
             this.parseSuccessors(op.successors);
         }
         if (this.accept('<')) {
-            op.properties = this.parseAttribute();
+            op.propertiesAttr = this.parseAttribute();
             this.expect('>');
         }
         if (this.accept('(')) {
@@ -2650,11 +2673,7 @@ mlir.Parser = class {
         if (this.match('{')) {
             const attributes = new Map();
             this.parseAttributeDict(attributes);
-            const dict = {};
-            for (const [name, value] of attributes) {
-                dict[name] = value;
-            }
-            return { value: dict };
+            return new mlir.DictionaryAttr(attributes);
         }
         if (this.match('#')) {
             const attr = this.parseExtendedAttr();
@@ -4505,6 +4524,28 @@ mlir.ArrayAttr = class extends mlir.Attribute {
     }
 };
 
+mlir.DictionaryAttr = class extends mlir.Attribute {
+
+    constructor(value) {
+        super();
+        this._value = value; // Map of name -> Attribute
+    }
+
+    get value() {
+        return this._value;
+    }
+
+    get(name) {
+        return this._value.get(name);
+    }
+
+    toString() {
+        const entries = Array.from(this._value.entries())
+            .map(([k, v]) => `${k} = ${v && v.toString ? v.toString() : String(v)}`);
+        return `{${entries.join(', ')}}`;
+    }
+};
+
 mlir.DenseArrayAttr = class extends mlir.Attribute {
 
     constructor(elements, type) {
@@ -5557,12 +5598,12 @@ mlir.Dialect = class {
         this.registerCustomAttribute('LevelAttr', this._parseIntegerAttr.bind(this, 'index'));
         this.registerCustomType('Optional', this._parseOptional.bind(this));
         for (const metadata of operations.get(name) || []) {
-            const op = { metadata };
+            const opInfo = { metadata };
             if (metadata.assemblyFormat) {
                 const parser = new mlir.AssemblyFormatParser(metadata);
-                op.directives = parser.parse();
+                opInfo.directives = parser.parse();
             }
-            this._operations.set(metadata.name, op);
+            this._operations.set(metadata.name, opInfo);
         }
     }
 
@@ -7482,7 +7523,7 @@ mlir.StableHLODialect = class extends mlir.HLODialect {
         if (opName === 'stablehlo.constant') {
             if (parser.accept('(') && parser.accept(')')) {
                 if (parser.accept('<')) {
-                    op.properties = parser.parseAttribute();
+                    op.propertiesAttr = parser.parseAttribute();
                     parser.expect('>');
                 }
                 parser.parseOptionalAttrDict(op.attributes);
@@ -11901,26 +11942,16 @@ mlir.MhloDialect = class extends mlir.HLODialect {
             }
             block.arguments.push({ value: '%lhs', type: elementType ? `tensor<${elementType}>` : null });
             block.arguments.push({ value: '%rhs', type: elementType ? `tensor<${elementType}>` : null });
-            const innerOp = {
-                name: innerOpName,
-                operands: [{ value: '%lhs' }, { value: '%rhs' }],
-                results: [{ value: '%0', type: elementType ? `tensor<${elementType}>` : null }],
-                attributes: [],
-                regions: []
-            };
+            const innerOp = new mlir.Operation(innerOpName);
+            innerOp.operands.push({ value: '%lhs' });
+            innerOp.operands.push({ value: '%rhs' });
+            innerOp.results.push({ value: '%0', type: elementType ? `tensor<${elementType}>` : null });
             block.operations.push(innerOp);
-            const returnOp = {
-                name: 'mhlo.return',
-                operands: [{ value: '%0' }],
-                results: [],
-                attributes: [],
-                regions: []
-            };
+            const returnOp = new mlir.Operation('mhlo.return');
+            returnOp.operands.push({ value: '%0' });
             block.operations.push(returnOp);
-
             region.blocks.push(block);
             op.regions.push(region);
-
             return true;
         }
 

+ 1 - 1
source/python.js

@@ -5230,7 +5230,7 @@ python.Execution = class {
                                 context.view.setInt32(context.position, data[i], littleendian);
                                 break;
                             case 'i8':
-                                context.view.setBigInt64(context.position, data[i], littleendian);
+                                context.view.setBigInt64(context.position, typeof data[i] === 'number' ? BigInt(data[i]) : data[i], littleendian);
                                 break;
                             case 'u1':
                                 context.view.setUint8(context.position, data[i], littleendian);

+ 8 - 7
test/models.json

@@ -3417,6 +3417,14 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "mlp_tosa-0.mlir",
+    "source":   "https://github.com/user-attachments/files/24387445/mlp_tosa-0.mlir.zip[mlp_tosa-0.mlir]",
+    "assert":   "model.functions[0].name == '@example::@mlp_invocation'",
+    "format":   "MLIR",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "mnist.onnx.mlir",
@@ -3538,13 +3546,6 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
-  {
-    "type":     "mlir",
-    "target":   "test_bf16.mlir",
-    "source":   "https://github.com/user-attachments/files/23023446/test_bf16.mlir.zip[test_bf16.mlir]",
-    "format":   "MLIR",
-    "link":     "https://github.com/lutzroeder/netron/issues/1044"
-  },
   {
     "type":     "mlir",
     "target":   "tfl_quant_conv2d.mlir",

+ 1 - 0
test/worker.js

@@ -387,6 +387,7 @@ export class Target {
                                             case 'int1': data_type = 'int8'; break;
                                             case 'int2': data_type = 'int8'; break;
                                             case 'int4': data_type = 'int8'; break;
+                                            case 'int48': data_type = 'int64'; break;
                                             case 'uint2': data_type = 'uint8'; break;
                                             case 'uint4': data_type = 'uint8'; break;
                                             default: data_type = tensor.type.dataType; break;