Lutz Roeder 3 hónapja
szülő
commit
048580bba3
3 módosított fájl, 63 hozzáadás és 11 törlés
  1. 61 9
      source/mlir.js
  2. 2 2
      tools/mlir
  3. 0 0
      tools/mlir-script.js

+ 61 - 9
source/mlir.js

@@ -2052,6 +2052,7 @@ mlir.Parser = class {
     }
 
     parseCustomAttributeWithFallback(attrT, type) {
+        // Reference: if token is #, use generic parser; otherwise use custom parser
         if (this.match('#')) {
             return this.parseAttribute();
         }
@@ -2359,10 +2360,12 @@ mlir.Parser = class {
             this.parseTypeAndAttrList(types, attrs);
             this.expect(')');
         } else {
-            // Single type without parens (cannot have attributes in this case per reference impl)
-            const type = this.parseType();
-            types.push(type);
-            attrs.push(null);
+            // Parse comma-separated types without parens (cannot have attributes in this case per reference impl)
+            do {
+                const type = this.parseType();
+                types.push(type);
+                attrs.push(null);
+            } while (this.accept(','));
         }
     }
 
@@ -2656,7 +2659,7 @@ mlir.Parser = class {
         }
         if (this.accept('<')) {
             const parts = [];
-            while (!this.match('>')) {
+            while (!this.match('>') && !this.match('eof')) {
                 if (this.match('id')) {
                     const id = this.expect('id');
                     if (this.accept('=')) {
@@ -2665,6 +2668,12 @@ mlir.Parser = class {
                     } else {
                         parts.push(id);
                     }
+                } else if (this.match('int') || this.match('float') || this.match('string')) {
+                    // Parse literal values inside <...>
+                    parts.push(this.parseValue());
+                } else if (this.match('[') || this.match('#') || this.match('<')) {
+                    // Parse nested structures
+                    parts.push(this.parseAttribute());
                 } else {
                     break;
                 }
@@ -4044,6 +4053,13 @@ mlir.Dialect = class {
                     }
                 }
             }
+            if (Array.isArray(op.metadata.regions)) {
+                for (const region of op.metadata.regions) {
+                    if (region && region.type) {
+                        region.type = this._parseConstraint(region.type);
+                    }
+                }
+            }
             op.metadata._ = true;
         }
         return op || null;
@@ -4116,9 +4132,24 @@ mlir.Dialect = class {
                 break;
             case 'region_ref': {
                 // Parse region variable - matches reference implementation
-                const region = {};
-                parser.parseRegion(region);
-                op.regions.push(region);
+                // Check if this is a variadic region from metadata
+                const regionMeta = opInfo.metadata && opInfo.metadata.regions && opInfo.metadata.regions.find((r) => r.name === directive.name);
+                const isVariadicRegion = regionMeta && regionMeta.type && regionMeta.type.name === 'VariadicRegion';
+                if (isVariadicRegion) {
+                    // Parse multiple regions separated by commas
+                    // For variadic regions, may have zero or more regions
+                    if (parser.match('{')) {
+                        do {
+                            const region = {};
+                            parser.parseRegion(region);
+                            op.regions.push(region);
+                        } while (parser.accept(',') && parser.match('{'));
+                    }
+                } else {
+                    const region = {};
+                    parser.parseRegion(region);
+                    op.regions.push(region);
+                }
                 break;
             }
             case 'successor_ref': {
@@ -4725,7 +4756,13 @@ mlir.Dialect = class {
 
     _parseTypedAttrInterface(parser) {
         if (parser.match('#')) {
-            return parser._parser.parseAttribute();
+            const attr = parser.parseAttribute();
+            // Handle typed attribute with trailing : type
+            if (parser.accept(':')) {
+                const type = parser.parseType();
+                attr.type = type;
+            }
+            return attr;
         }
         const value = parser.parseValue();
         if (parser.accept(':')) {
@@ -4879,6 +4916,11 @@ mlir.Dialect = class {
 
             if (parser.accept('=')) {
                 const attr = parser.parseAttribute();
+                // Handle typed attribute with trailing : type (e.g., -2 : i8)
+                if (parser.accept(':')) {
+                    const attrType = parser.parseType();
+                    attr.type = attrType;
+                }
                 const attrIndex = op.attributes.findIndex((a) => a.name === attrArg);
                 if (attrIndex === -1) {
                     op.attributes.push({ name: attrArg, value: attr });
@@ -6557,6 +6599,7 @@ mlir.HALDialect = class extends mlir.IREEDialect {
     constructor(operations) {
         super('hal', operations);
         this.simpleTypes = new Set(['allocator', 'buffer', 'buffer_view', 'channel', 'command_buffer', 'descriptor_set', 'descriptor_set_layout', 'device', 'event', 'executable', 'executable_layout', 'fence', 'file', 'semaphore']);
+        this.registerCustomAttribute('HAL_PipelineLayoutAttr', this._parsePipelineLayoutAttr.bind(this));
     }
 
     parseType(parser, dialectName) {
@@ -6897,6 +6940,14 @@ mlir.HALDialect = class extends mlir.IREEDialect {
         }
         return super.parseOperation(parser, opName, op);
     }
+
+    _parsePipelineLayoutAttr(parser) {
+        // HAL_PipelineLayoutAttr format: <constants = N, bindings = [...], flags = ...>
+        if (parser.match('<')) {
+            return parser.parseAttribute();
+        }
+        return parser.parseOptionalAttribute();
+    }
 };
 
 mlir.UtilDialect = class extends mlir.IREEDialect {
@@ -12085,6 +12136,7 @@ mlir.MathDialect = class extends mlir.Dialect {
 
     constructor(operations) {
         super('math', operations);
+        this.registerCustomAttribute('Arith_FastMathAttr', this._parseEnumFlagsAngleBracketComma.bind(this));
     }
 };
 

+ 2 - 2
tools/mlir

@@ -54,12 +54,12 @@ sync() {
 
 schema() {
     echo "mlir schema"
-    node ./tools/mlir_script.js schema
+    node ./tools/mlir-script.js schema
 }
 
 test() {
     echo "mlir test"
-    node ./tools/mlir_script.js test "$@"
+    node ./tools/mlir-script.js test "$@"
 }
 
 while [ "$#" != 0 ]; do

+ 0 - 0
tools/mlir_script.js → tools/mlir-script.js