Lutz Roeder пре 4 месеци
родитељ
комит
466111c75e
4 измењених фајлова са 414 додато и 45 уклоњено
  1. 16 0
      source/mlir-metadata.json
  2. 323 12
      source/mlir.js
  3. 25 2
      tools/mlir_script.js
  4. 50 31
      tools/tablegen.js

Разлика између датотеке није приказан због своје велике величине
+ 16 - 0
source/mlir-metadata.json


+ 323 - 12
source/mlir.js

@@ -1,6 +1,5 @@
 
 // Experimental
-// contributor @tucan9389
 
 const mlir = {};
 
@@ -258,6 +257,7 @@ mlir.Argument = class {
         if (this.type) {
             switch (this.type) {
                 case 'i64': case 'si64': this.type = 'int64'; break;
+                case 'i48': case 'si48': this.type = 'int48'; break;
                 case 'i32': case 'si32': this.type = 'int32'; break;
                 case 'i16': case 'si16': this.type = 'int16'; break;
                 case 'i8': case 'si8': this.type = 'int8'; break;
@@ -545,6 +545,21 @@ mlir.Tokenizer = class {
             }
             return;
         }
+        if (this._current === '*') {
+            this._read();
+            while (this._current) {
+                if (this._current === '*') {
+                    this._read();
+                    if (this._current === '/') {
+                        this._read();
+                        return;
+                    }
+                } else {
+                    this._read();
+                }
+            }
+            return;
+        }
         throw new mlir.Error('Invalid comment.');
     }
 
@@ -844,7 +859,7 @@ mlir.Parser = class {
         this._dialects.set('tfr', new mlir.TFRDialect(operations));
         this._dialects.set('tfrt', new mlir.TFRTDialect(operations));
         this._dialects.set('tfrt_fallback', new mlir.Dialect('tfrt_fallback', operations));
-        this._dialects.set('tfl', new mlir.Dialect('tfl', operations));
+        this._dialects.set('tfl', new mlir.TFLDialect(operations));
         this._dialects.set('stdx', new mlir.StdxDialect(operations));
         this._dialects.set('vm', new mlir.VMDialect(operations));
         this._dialects.set('math', new mlir.MathDialect(operations));
@@ -1498,8 +1513,38 @@ mlir.Parser = class {
         }
         const open = this.accept('(');
         // eslint-disable-next-line no-unmodified-loop-condition
-        while (!this.match(')') && !this.match('->') && !this.match('{') && !this.match('}') && !this.match('[') && !this.match('=') && !this.match('^') && !(this.match(':') && !open)) {
+        while (!this.match(')') && !this.match('->') && !this.match('{') && !this.match('}') && !this.match('=') && !this.match('^') && !(this.match(':') && !open)) {
             const input = {};
+            if (this.match('[')) {
+                this.expect('[');
+                const array = [];
+                while (!this.match(']')) {
+                    if (this.match('%')) {
+                        array.push(this.expect());
+                    } else if (this.match('int')) {
+                        array.push(parseInt(this.expect('int'), 10));
+                    } else if (this.match('-')) {
+                        this.expect('-');
+                        if (this.match('int')) {
+                            array.push(-parseInt(this.expect('int'), 10));
+                        } else {
+                            throw new mlir.Error(`Expected integer after '-' in array literal ${this.location()}`);
+                        }
+                    } else {
+                        break;
+                    }
+                    if (!this.accept(',')) {
+                        break;
+                    }
+                }
+                this.expect(']');
+                input.value = array;
+                inputs.push(input);
+                if (!this.accept(',')) {
+                    break;
+                }
+                continue;
+            }
             if (this._token.kind === 'id' && this._token.value !== 'dense' && this._token.value !== 'dense_resource') {
                 const identifier = this.expect('id');
                 if (this.accept('(')) {
@@ -2607,6 +2652,7 @@ mlir.Utility = class {
             case 'i8': return 'int8';
             case 'i16': return 'int16';
             case 'i32': return 'int32';
+            case 'i48': return 'int48';
             case 'i64': return 'int64';
             case 'si8': return 'int8';
             case 'si16': return 'int16';
@@ -2872,9 +2918,27 @@ mlir.AssemblyFormatParser = class {
             return { type: 'custom', parser, args };
         }
         if (remaining.startsWith('oilist(')) {
-            this._pos += 'oilist'.length;
-            this._parseParenList();
-            return null;
+            this._pos += 'oilist('.length;
+            let content = '';
+            let depth = 1;
+            while (this._pos < this._format.length && depth > 0) {
+                const ch = this._format[this._pos];
+                if (ch === '(') {
+                    depth++;
+                    content += ch;
+                    this._pos++;
+                } else if (ch === ')') {
+                    depth--;
+                    if (depth > 0) {
+                        content += ch;
+                    }
+                    this._pos++;
+                } else {
+                    content += ch;
+                    this._pos++;
+                }
+            }
+            return { type: 'oilist', content };
         }
         if (/^[:()[\]{}<>,=|]/.test(ch)) {
             this._pos++;
@@ -2984,7 +3048,10 @@ mlir.Dialect = class {
         this.registerCustomParser('CustomCallTarget', this._parseCustomCallTarget.bind(this));
         this.registerCustomParser('VariadicOperandWithAttribute', this._parseVariadicOperandWithAttribute.bind(this));
         this.registerCustomParser('DynamicIndexList', this._parseDynamicIndexList.bind(this));
+        this.registerCustomParser('Offsets', this._parseOffsets.bind(this));
         this.registerCustomParser('SymbolVisibility', this._parseSymbolVisibility.bind(this));
+        this.registerCustomParser('SymbolAlias', this._parseSymbolAlias.bind(this));
+        this.registerCustomParser('WorkgroupCountRegion', this._parseWorkgroupCountRegion.bind(this));
         this.registerCustomParser('OptionalUnitAttr', this._parseOptionalUnitAttr.bind(this));
         for (const op of operations.get(name) || []) {
             if (op.assemblyFormat) {
@@ -3077,9 +3144,11 @@ mlir.Dialect = class {
                     } else if (isAttribute) {
                         const attrValue = parser.parseValue();
                         if (attrValue) {
+                            // Check if there's a type annotation (`: type`) after the attribute value
+                            // This handles typed attributes like AnyAttr, TypedAttrInterface, etc.
                             if ((attrValue.type === 'int64' || attrValue.type === 'float32' || attrValue.type === 'boolean' || attrValue.type === 'dense') &&
                                 parser.accept(':')) {
-                                parser.parseType();
+                                attrValue.attrType = parser.parseType();
                             }
                             // For attributes, we only store the value, not the internal "type" field
                             // The type field here is just metadata about how the value was parsed (e.g., 'dense')
@@ -3232,7 +3301,19 @@ mlir.Dialect = class {
                     }
                     const result = fn(parser, directive.args);
                     if (result) {
-                        if (result.kind === 'SameOperandsAndResultType' && result.type) {
+                        if (result.kind === 'SymbolVisibility' && result.visibility && directive.args && directive.args.length > 0) {
+                            const attrName = directive.args[0].replace(/^\$/, '');
+                            op.attributes.push({ name: attrName, value: result.visibility });
+                        } else if (result.kind === 'SymbolAlias' && result.symbolName) {
+                            if (result.symNameArg) {
+                                op.attributes.push({ name: result.symNameArg, value: result.symbolName });
+                            }
+                            if (result.functionRefArg) {
+                                op.attributes.push({ name: result.functionRefArg, value: result.symbolName });
+                            }
+                        } else if (result.kind === 'WorkgroupCountRegion' && result.region) {
+                            op.regions.push(result.region);
+                        } else if (result.kind === 'SameOperandsAndResultType' && result.type) {
                             for (const operand of op.operands) {
                                 if (!operand.type) {
                                     operand.type = result.type;
@@ -3271,6 +3352,9 @@ mlir.Dialect = class {
                     }
                     break;
                 }
+                case 'oilist': {
+                    break;
+                }
                 case 'optional_group': {
                     let shouldParse = false;
 
@@ -3286,7 +3370,21 @@ mlir.Dialect = class {
                             if (firstElem.name === 'overflowFlags') {
                                 shouldParse = parser.match('id', 'overflow');
                             } else {
-                                shouldParse = parser.match('%');
+                                // Check if this is an attribute or an operand
+                                let isFirstAttribute = false;
+                                if (opInfo.metadata && opInfo.metadata.attributes) {
+                                    const attrInfo = opInfo.metadata.attributes.find((attr) => attr.name === firstElem.name);
+                                    if (attrInfo) {
+                                        isFirstAttribute = true;
+                                    }
+                                }
+                                if (isFirstAttribute) {
+                                    // For attributes, check if there's a value present (id, int, float, etc.)
+                                    shouldParse = parser.match('id') || parser.match('int') || parser.match('float') || parser.match('[') || parser.match('{');
+                                } else {
+                                    // For operands, check for %
+                                    shouldParse = parser.match('%');
+                                }
                             }
                         } else if (firstElem.type === 'operands') {
                             shouldParse = parser.match('(') || parser.match('%');
@@ -3778,6 +3876,29 @@ mlir.Dialect = class {
         return result;
     }
 
+    _parseOffsets(parser, args) {
+        const values = [];
+        while (parser.match('int') || parser.match('-')) {
+            if (parser.accept('-')) {
+                if (parser.match('int')) {
+                    values.push(-parseInt(parser.expect('int'), 10));
+                } else {
+                    throw new mlir.Error(`Expected integer after '-' in offsets ${parser.location()}`);
+                }
+            } else {
+                values.push(parseInt(parser.expect('int'), 10));
+            }
+            if (!parser.accept(',')) {
+                break;
+            }
+        }
+        if (args && args.length > 0) {
+            const [attrName] = args;
+            return { name: attrName, value: values };
+        }
+        return values;
+    }
+
     _parseVariadicOperandWithAttribute(parser, /*, args */) {
         const result = {
             kind: 'VariadicOperandWithAttribute',
@@ -3819,6 +3940,30 @@ mlir.Dialect = class {
         return { kind: 'SymbolVisibility', visibility: null };
     }
 
+    _parseSymbolAlias(parser, args) {
+        if (!parser.match('@')) {
+            return { kind: 'SymbolAlias' };
+        }
+        const symbolName = parser.expect('@');
+        const result = { kind: 'SymbolAlias', symbolName };
+        if (args && args.length >= 1) {
+            result.symNameArg = args[0].replace(/^\$/, '');
+        }
+        if (args && args.length >= 2) {
+            result.functionRefArg = args[1].replace(/^\$/, '');
+        }
+        return result;
+    }
+
+    _parseWorkgroupCountRegion(parser /*, args */) {
+        if (parser.match('{')) {
+            const region = {};
+            parser.parseRegion(region);
+            return { kind: 'WorkgroupCountRegion', region };
+        }
+        return { kind: 'WorkgroupCountRegion', region: null };
+    }
+
     _parseOptionalUnitAttr(parser /*, args */) {
         if (parser.match('id') || parser.match('%') || parser.match('(')) {
             return { kind: 'OptionalUnitAttr', present: true };
@@ -4481,12 +4626,32 @@ mlir.VectorDialect = class extends mlir.Dialect {
             }
             return true;
         }
+        if (name === 'vector.mask') {
+            if (parser.match('%')) {
+                const mask = parser.expect('%');
+                op.operands.push({ value: mask, name: 'mask' });
+            }
+            if (parser.accept(',')) {
+                const passthru = parser.expect('%');
+                op.operands.push({ value: passthru, name: 'passthru' });
+            }
+            if (parser.match('{')) {
+                const region = {};
+                parser.parseRegion(region);
+                op.regions.push(region);
+            }
+            parser.parseOptionalAttrDict(op.attributes);
+            if (parser.accept(':')) {
+                parser.parseArgumentTypes(op.operands);
+                if (parser.accept('->')) {
+                    parser.parseArgumentTypes(op.results);
+                }
+            }
+            return true;
+        }
         if (name === 'vector.transfer_read' || name === 'vector.transfer_write') {
             return this._parseTransferOp(parser, op);
         }
-        // Handle old vector.extract syntax (pre-2023) without 'from' keyword
-        // Old: %r = vector.extract %v[0] : vector<4xf32>
-        // New: %r = vector.extract %v[0] : f32 from vector<4xf32>
         if (name === 'vector.extract' && !op.isGeneric) {
             return this._parseExtractOp(parser, op);
         }
@@ -8316,6 +8481,43 @@ mlir.SdfgDialect = class extends mlir.Dialect {
     }
 };
 
+mlir.TFLDialect = class extends mlir.Dialect {
+
+    constructor(operations) {
+        super('tfl', operations);
+        // Operations that use parseOneResultSameOperandTypeOp in tfl_ops.cc
+        // Format: operands attr-dict : single-type
+        this._binaryOps = new Set([
+            'add', 'sub', 'mul', 'div', 'floor_div', 'pow', 'squared_difference',
+            'less', 'less_equal', 'greater', 'greater_equal', 'not_equal',
+            'logical_and', 'logical_or'
+        ]);
+    }
+
+    parseOperation(parser, opName, op) {
+        const name = opName.replace(/^"|"$/g, '');
+        const opKind = name.substring('tfl.'.length);
+        if (this._binaryOps.has(opKind)) {
+            // Parse: operands attr-dict : type
+            op.operands = parser.parseArguments();
+            parser.parseOptionalAttrDict(op.attributes);
+            if (parser.accept(':')) {
+                const type = parser.parseType();
+                // All operands and result share the same type
+                for (const operand of op.operands) {
+                    operand.type = type;
+                }
+                if (op.results.length > 0) {
+                    op.results[0].type = type;
+                }
+            }
+            return true;
+        }
+
+        return super.parseOperation(parser, opName, op);
+    }
+};
+
 mlir.TFDialect = class extends mlir.Dialect {
 
     constructor(operations) {
@@ -8451,6 +8653,15 @@ mlir.TransformDialect = class extends mlir.Dialect {
         if (name === 'transform.get_result') {
             return this._parseGetResultOp(parser, op);
         }
+        if (name === 'transform.func.cast_and_call') {
+            return this._parseFuncCastAndCallOp(parser, op);
+        }
+        if (name === 'transform.func.replace_func_signature') {
+            return this._parseReplaceFuncSignatureOp(parser, op);
+        }
+        if (name === 'transform.func.deduplicate_func_args') {
+            return this._parseDeduplicateFuncArgsOp(parser, op);
+        }
         if (name.startsWith('transform.test.')) {
             return this._parseTestOp(parser, name, op);
         }
@@ -8525,6 +8736,106 @@ mlir.TransformDialect = class extends mlir.Dialect {
         return true;
     }
 
+    _parseFuncCastAndCallOp(parser, op) {
+        if (parser.match('@')) {
+            const funcName = parser.expect('@');
+            op.attributes.push({ name: 'function_name', value: funcName });
+        }
+        if (parser.match('%') && !parser.match('id', 'before') && !parser.match('id', 'after')) {
+            const funcHandle = parser.expect('%');
+            op.operands.push({ value: funcHandle, name: 'function' });
+        }
+        if (parser.accept('(')) {
+            while (!parser.match(')')) {
+                const value = parser.expect('%');
+                op.operands.push({ value, name: 'inputs' });
+                if (!parser.accept(',')) {
+                    break;
+                }
+            }
+            parser.expect(')');
+        }
+        if (parser.accept('->')) {
+            while (!parser.match('id') && parser.match('%')) {
+                const value = parser.expect('%');
+                op.operands.push({ value, name: 'outputs' });
+                if (!parser.accept(',')) {
+                    break;
+                }
+            }
+        }
+        if (parser.accept('id', 'after')) {
+            op.attributes.push({ name: 'insert_after', value: true });
+        } else {
+            parser.accept('id', 'before');
+        }
+        if (parser.match('%')) {
+            const insertionPoint = parser.expect('%');
+            op.operands.push({ value: insertionPoint, name: 'insertion_point' });
+        }
+        if (parser.match('{')) {
+            const region = {};
+            parser.parseRegion(region);
+            op.regions.push(region);
+        }
+        parser.parseOptionalAttrDict(op.attributes);
+        if (parser.accept(':')) {
+            parser.parseArgumentTypes(op.operands);
+            if (parser.accept('->')) {
+                parser.parseArgumentTypes(op.results);
+            }
+        }
+        return true;
+    }
+
+    _parseReplaceFuncSignatureOp(parser, op) {
+        if (parser.match('@')) {
+            const funcName = parser.expect('@');
+            op.attributes.push({ name: 'function_name', value: funcName });
+        }
+        if (parser.accept('id', 'args_interchange')) {
+            parser.expect('=');
+            const argsInterchange = parser.parseValue();
+            op.attributes.push({ name: 'args_interchange', value: argsInterchange.value });
+        }
+        if (parser.accept('id', 'results_interchange')) {
+            parser.expect('=');
+            const resultsInterchange = parser.parseValue();
+            op.attributes.push({ name: 'results_interchange', value: resultsInterchange.value });
+        }
+        if (parser.accept('id', 'at')) {
+            const module = parser.expect('%');
+            op.operands.push({ value: module, name: 'module' });
+        }
+        parser.parseOptionalAttrDict(op.attributes);
+        if (parser.accept(':')) {
+            parser.parseArgumentTypes(op.operands);
+            if (parser.accept('->')) {
+                parser.parseArgumentTypes(op.results);
+            }
+        }
+        return true;
+    }
+
+    _parseDeduplicateFuncArgsOp(parser, op) {
+        if (parser.match('@')) {
+            const funcName = parser.expect('@');
+            op.attributes.push({ name: 'function_name', value: funcName });
+        }
+        if (parser.accept('id', 'at')) {
+            const module = parser.expect('%');
+            op.operands.push({ value: module, name: 'module' });
+        }
+        parser.parseOptionalAttrDict(op.attributes);
+        if (parser.accept(':')) {
+            parser.parseArgumentTypes(op.operands);
+            if (parser.accept('->')) {
+                parser.parseArgumentTypes(op.results);
+            }
+        }
+        return true;
+    }
+
     _parseGetResultOp(parser, op) {
         // Parse: transform.get_result %op[0] : (!transform.any_op) -> !transform.any_value
         if (parser.match('%')) {

+ 25 - 2
tools/mlir_script.js

@@ -251,7 +251,18 @@ const main = async () => {
                 metadata.description = evaluatedDesc;
             }
         }
-        const args = def.resolveField('arguments');
+        let args = def.resolveField('arguments');
+        if (!args || !args.value || args.value.type !== 'dag' || (args.value.value && args.value.value.operands && args.value.value.operands.length === 0)) {
+            for (const parent of def.parents) {
+                if (parent.name === 'Arguments' && parent.args && parent.args.length > 0) {
+                    const [dagValue] = parent.args;
+                    if (dagValue && dagValue.type === 'dag') {
+                        args = { value: dagValue };
+                    }
+                    break;
+                }
+            }
+        }
         if (args && args.value && args.value.type === 'dag') {
             const dag = args.value.value;
             if (dag.operator === 'ins') {
@@ -282,7 +293,19 @@ const main = async () => {
                 }
             }
         }
-        const results = def.resolveField('results');
+        let results = def.resolveField('results');
+        if (!results || !results.value || results.value.type !== 'dag' || (results.value.value && results.value.value.operands && results.value.value.operands.length === 0)) {
+            for (const parent of def.parents) {
+                if (parent.name === 'Results' && parent.args && parent.args.length > 0) {
+                    const [dagValue] = parent.args;
+                    if (dagValue && dagValue.type === 'dag') {
+                        results = { value: dagValue };
+                    }
+                    break;
+                }
+            }
+        }
+
         if (results && results.value && results.value.type === 'dag') {
             const dag = results.value.value;
             if (dag.operator === 'outs') {

+ 50 - 31
tools/tablegen.js

@@ -815,45 +815,64 @@ tablegen.Reader = class {
         }
         const type = new tablegen.Type(typeName);
         if (this._eat('<')) {
-            let depth = 1;
-            const argsTokens = [];
-            while (depth > 0 && !this._match('eof')) {
-                if (this._eat('<')) {
-                    argsTokens.push(this._tokenizer.current());
-                    depth++;
-                } else if (this._eat('>')) {
-                    depth--;
-                    if (depth === 0) {
-                        break;
-                    }
-                    argsTokens.push(this._tokenizer.current());
-                } else {
-                    argsTokens.push(this._tokenizer.current());
-                    this._read();
-                }
-            }
-            type.args = this._parseTemplateArgList(argsTokens);
+            type.args = this._parseTemplateArgList();
         }
         return type;
     }
 
-    _parseTemplateArgList(tokens) {
+    _parseTemplateArgList() {
+        // Parse template arguments directly from token stream
+        // Supports both positional (arg1, arg2) and named (name=value) arguments
         const args = [];
-        let current = '';
-        for (const token of tokens) {
-            if (token.type === ',') {
-                if (current.trim()) {
-                    args.push(current.trim());
+        while (!this._match('>') && !this._match('eof')) {
+            // Check if this is a named argument: id = value
+            if (this._match('id')) {
+                const name = this._read();
+                if (this._match('=')) {
+                    // Named argument
+                    this._read(); // Consume '='
+                    const value = this._parseValue();
+                    args.push({ name, value });
+                } else {
+                    // Positional argument that starts with an id
+                    // Reconstruct the value - id might be part of concat, field access, etc.
+                    let value = new tablegen.Value('def', name);
+                    // Handle < > for template instantiation
+                    if (this._match('<')) {
+                        this._skip('<', '>');
+                    }
+                    // Handle field access
+                    if (this._eat('.')) {
+                        const field = this._expect('id');
+                        value = new tablegen.Value('def', `${name}.${field}`);
+                    }
+                    // Handle :: suffix
+                    if (this._eat('::')) {
+                        const suffix = this._expect('id');
+                        value = new tablegen.Value('def', `${name}::${suffix}`);
+                    }
+                    // Handle # concatenation
+                    if (this._match('#')) {
+                        const values = [value];
+                        while (this._match('#')) {
+                            this._read();
+                            values.push(this._parsePrimaryValue());
+                        }
+                        value = new tablegen.Value('concat', values);
+                    }
+                    args.push(value);
                 }
-                current = '';
             } else {
-                current += token.value === null ? token.type : token.value;
+                // Positional argument that doesn't start with an id
+                const value = this._parseValue();
+                args.push(value);
+            }
+
+            if (!this._eat(',')) {
+                break;
             }
         }
-        current = current.trim();
-        if (current) {
-            args.push(current);
-        }
+        this._expect('>');
         return args;
     }
 
@@ -983,7 +1002,7 @@ tablegen.Reader = class {
             }
             return new tablegen.Value('bang', { op, args, field });
         }
-        if (this._match('id')) {
+        if (this._match('id') || this._isKeyword(this._tokenizer.current().type)) {
             let value = this._read();
             if (this._match('<')) {
                 this._skip('<', '>');

Неке датотеке нису приказане због велике количине промена