|
|
@@ -1,6 +1,5 @@
|
|
|
|
|
|
// Experimental
|
|
|
-// contributor @tucan9389
|
|
|
|
|
|
const mlir = {};
|
|
|
|
|
|
@@ -258,6 +257,7 @@ mlir.Argument = class {
|
|
|
if (this.type) {
|
|
|
switch (this.type) {
|
|
|
case 'i64': case 'si64': this.type = 'int64'; break;
|
|
|
+ case 'i48': case 'si48': this.type = 'int48'; break;
|
|
|
case 'i32': case 'si32': this.type = 'int32'; break;
|
|
|
case 'i16': case 'si16': this.type = 'int16'; break;
|
|
|
case 'i8': case 'si8': this.type = 'int8'; break;
|
|
|
@@ -545,6 +545,21 @@ mlir.Tokenizer = class {
|
|
|
}
|
|
|
return;
|
|
|
}
|
|
|
+ if (this._current === '*') {
|
|
|
+ this._read();
|
|
|
+ while (this._current) {
|
|
|
+ if (this._current === '*') {
|
|
|
+ this._read();
|
|
|
+ if (this._current === '/') {
|
|
|
+ this._read();
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ this._read();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
throw new mlir.Error('Invalid comment.');
|
|
|
}
|
|
|
|
|
|
@@ -844,7 +859,7 @@ mlir.Parser = class {
|
|
|
this._dialects.set('tfr', new mlir.TFRDialect(operations));
|
|
|
this._dialects.set('tfrt', new mlir.TFRTDialect(operations));
|
|
|
this._dialects.set('tfrt_fallback', new mlir.Dialect('tfrt_fallback', operations));
|
|
|
- this._dialects.set('tfl', new mlir.Dialect('tfl', operations));
|
|
|
+ this._dialects.set('tfl', new mlir.TFLDialect(operations));
|
|
|
this._dialects.set('stdx', new mlir.StdxDialect(operations));
|
|
|
this._dialects.set('vm', new mlir.VMDialect(operations));
|
|
|
this._dialects.set('math', new mlir.MathDialect(operations));
|
|
|
@@ -1498,8 +1513,38 @@ mlir.Parser = class {
|
|
|
}
|
|
|
const open = this.accept('(');
|
|
|
// eslint-disable-next-line no-unmodified-loop-condition
|
|
|
- while (!this.match(')') && !this.match('->') && !this.match('{') && !this.match('}') && !this.match('[') && !this.match('=') && !this.match('^') && !(this.match(':') && !open)) {
|
|
|
+ while (!this.match(')') && !this.match('->') && !this.match('{') && !this.match('}') && !this.match('=') && !this.match('^') && !(this.match(':') && !open)) {
|
|
|
const input = {};
|
|
|
+ if (this.match('[')) {
|
|
|
+ this.expect('[');
|
|
|
+ const array = [];
|
|
|
+ while (!this.match(']')) {
|
|
|
+ if (this.match('%')) {
|
|
|
+ array.push(this.expect());
|
|
|
+ } else if (this.match('int')) {
|
|
|
+ array.push(parseInt(this.expect('int'), 10));
|
|
|
+ } else if (this.match('-')) {
|
|
|
+ this.expect('-');
|
|
|
+ if (this.match('int')) {
|
|
|
+ array.push(-parseInt(this.expect('int'), 10));
|
|
|
+ } else {
|
|
|
+ throw new mlir.Error(`Expected integer after '-' in array literal ${this.location()}`);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (!this.accept(',')) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ this.expect(']');
|
|
|
+ input.value = array;
|
|
|
+ inputs.push(input);
|
|
|
+ if (!this.accept(',')) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ continue;
|
|
|
+ }
|
|
|
if (this._token.kind === 'id' && this._token.value !== 'dense' && this._token.value !== 'dense_resource') {
|
|
|
const identifier = this.expect('id');
|
|
|
if (this.accept('(')) {
|
|
|
@@ -2607,6 +2652,7 @@ mlir.Utility = class {
|
|
|
case 'i8': return 'int8';
|
|
|
case 'i16': return 'int16';
|
|
|
case 'i32': return 'int32';
|
|
|
+ case 'i48': return 'int48';
|
|
|
case 'i64': return 'int64';
|
|
|
case 'si8': return 'int8';
|
|
|
case 'si16': return 'int16';
|
|
|
@@ -2872,9 +2918,27 @@ mlir.AssemblyFormatParser = class {
|
|
|
return { type: 'custom', parser, args };
|
|
|
}
|
|
|
if (remaining.startsWith('oilist(')) {
|
|
|
- this._pos += 'oilist'.length;
|
|
|
- this._parseParenList();
|
|
|
- return null;
|
|
|
+ this._pos += 'oilist('.length;
|
|
|
+ let content = '';
|
|
|
+ let depth = 1;
|
|
|
+ while (this._pos < this._format.length && depth > 0) {
|
|
|
+ const ch = this._format[this._pos];
|
|
|
+ if (ch === '(') {
|
|
|
+ depth++;
|
|
|
+ content += ch;
|
|
|
+ this._pos++;
|
|
|
+ } else if (ch === ')') {
|
|
|
+ depth--;
|
|
|
+ if (depth > 0) {
|
|
|
+ content += ch;
|
|
|
+ }
|
|
|
+ this._pos++;
|
|
|
+ } else {
|
|
|
+ content += ch;
|
|
|
+ this._pos++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return { type: 'oilist', content };
|
|
|
}
|
|
|
if (/^[:()[\]{}<>,=|]/.test(ch)) {
|
|
|
this._pos++;
|
|
|
@@ -2984,7 +3048,10 @@ mlir.Dialect = class {
|
|
|
this.registerCustomParser('CustomCallTarget', this._parseCustomCallTarget.bind(this));
|
|
|
this.registerCustomParser('VariadicOperandWithAttribute', this._parseVariadicOperandWithAttribute.bind(this));
|
|
|
this.registerCustomParser('DynamicIndexList', this._parseDynamicIndexList.bind(this));
|
|
|
+ this.registerCustomParser('Offsets', this._parseOffsets.bind(this));
|
|
|
this.registerCustomParser('SymbolVisibility', this._parseSymbolVisibility.bind(this));
|
|
|
+ this.registerCustomParser('SymbolAlias', this._parseSymbolAlias.bind(this));
|
|
|
+ this.registerCustomParser('WorkgroupCountRegion', this._parseWorkgroupCountRegion.bind(this));
|
|
|
this.registerCustomParser('OptionalUnitAttr', this._parseOptionalUnitAttr.bind(this));
|
|
|
for (const op of operations.get(name) || []) {
|
|
|
if (op.assemblyFormat) {
|
|
|
@@ -3077,9 +3144,11 @@ mlir.Dialect = class {
|
|
|
} else if (isAttribute) {
|
|
|
const attrValue = parser.parseValue();
|
|
|
if (attrValue) {
|
|
|
+ // Check if there's a type annotation (`: type`) after the attribute value
|
|
|
+ // This handles typed attributes like AnyAttr, TypedAttrInterface, etc.
|
|
|
if ((attrValue.type === 'int64' || attrValue.type === 'float32' || attrValue.type === 'boolean' || attrValue.type === 'dense') &&
|
|
|
parser.accept(':')) {
|
|
|
- parser.parseType();
|
|
|
+ attrValue.attrType = parser.parseType();
|
|
|
}
|
|
|
// For attributes, we only store the value, not the internal "type" field
|
|
|
// The type field here is just metadata about how the value was parsed (e.g., 'dense')
|
|
|
@@ -3232,7 +3301,19 @@ mlir.Dialect = class {
|
|
|
}
|
|
|
const result = fn(parser, directive.args);
|
|
|
if (result) {
|
|
|
- if (result.kind === 'SameOperandsAndResultType' && result.type) {
|
|
|
+ if (result.kind === 'SymbolVisibility' && result.visibility && directive.args && directive.args.length > 0) {
|
|
|
+ const attrName = directive.args[0].replace(/^\$/, '');
|
|
|
+ op.attributes.push({ name: attrName, value: result.visibility });
|
|
|
+ } else if (result.kind === 'SymbolAlias' && result.symbolName) {
|
|
|
+ if (result.symNameArg) {
|
|
|
+ op.attributes.push({ name: result.symNameArg, value: result.symbolName });
|
|
|
+ }
|
|
|
+ if (result.functionRefArg) {
|
|
|
+ op.attributes.push({ name: result.functionRefArg, value: result.symbolName });
|
|
|
+ }
|
|
|
+ } else if (result.kind === 'WorkgroupCountRegion' && result.region) {
|
|
|
+ op.regions.push(result.region);
|
|
|
+ } else if (result.kind === 'SameOperandsAndResultType' && result.type) {
|
|
|
for (const operand of op.operands) {
|
|
|
if (!operand.type) {
|
|
|
operand.type = result.type;
|
|
|
@@ -3271,6 +3352,9 @@ mlir.Dialect = class {
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
+ case 'oilist': {
|
|
|
+ break;
|
|
|
+ }
|
|
|
case 'optional_group': {
|
|
|
let shouldParse = false;
|
|
|
|
|
|
@@ -3286,7 +3370,21 @@ mlir.Dialect = class {
|
|
|
if (firstElem.name === 'overflowFlags') {
|
|
|
shouldParse = parser.match('id', 'overflow');
|
|
|
} else {
|
|
|
- shouldParse = parser.match('%');
|
|
|
+ // Check if this is an attribute or an operand
|
|
|
+ let isFirstAttribute = false;
|
|
|
+ if (opInfo.metadata && opInfo.metadata.attributes) {
|
|
|
+ const attrInfo = opInfo.metadata.attributes.find((attr) => attr.name === firstElem.name);
|
|
|
+ if (attrInfo) {
|
|
|
+ isFirstAttribute = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (isFirstAttribute) {
|
|
|
+ // For attributes, check if there's a value present (id, int, float, etc.)
|
|
|
+ shouldParse = parser.match('id') || parser.match('int') || parser.match('float') || parser.match('[') || parser.match('{');
|
|
|
+ } else {
|
|
|
+ // For operands, check for %
|
|
|
+ shouldParse = parser.match('%');
|
|
|
+ }
|
|
|
}
|
|
|
} else if (firstElem.type === 'operands') {
|
|
|
shouldParse = parser.match('(') || parser.match('%');
|
|
|
@@ -3778,6 +3876,29 @@ mlir.Dialect = class {
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
+ _parseOffsets(parser, args) {
|
|
|
+ const values = [];
|
|
|
+ while (parser.match('int') || parser.match('-')) {
|
|
|
+ if (parser.accept('-')) {
|
|
|
+ if (parser.match('int')) {
|
|
|
+ values.push(-parseInt(parser.expect('int'), 10));
|
|
|
+ } else {
|
|
|
+ throw new mlir.Error(`Expected integer after '-' in offsets ${parser.location()}`);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ values.push(parseInt(parser.expect('int'), 10));
|
|
|
+ }
|
|
|
+ if (!parser.accept(',')) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (args && args.length > 0) {
|
|
|
+ const [attrName] = args;
|
|
|
+ return { name: attrName, value: values };
|
|
|
+ }
|
|
|
+ return values;
|
|
|
+ }
|
|
|
+
|
|
|
_parseVariadicOperandWithAttribute(parser, /*, args */) {
|
|
|
const result = {
|
|
|
kind: 'VariadicOperandWithAttribute',
|
|
|
@@ -3819,6 +3940,30 @@ mlir.Dialect = class {
|
|
|
return { kind: 'SymbolVisibility', visibility: null };
|
|
|
}
|
|
|
|
|
|
+ _parseSymbolAlias(parser, args) {
|
|
|
+ if (!parser.match('@')) {
|
|
|
+ return { kind: 'SymbolAlias' };
|
|
|
+ }
|
|
|
+ const symbolName = parser.expect('@');
|
|
|
+ const result = { kind: 'SymbolAlias', symbolName };
|
|
|
+ if (args && args.length >= 1) {
|
|
|
+ result.symNameArg = args[0].replace(/^\$/, '');
|
|
|
+ }
|
|
|
+ if (args && args.length >= 2) {
|
|
|
+ result.functionRefArg = args[1].replace(/^\$/, '');
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ _parseWorkgroupCountRegion(parser /*, args */) {
|
|
|
+ if (parser.match('{')) {
|
|
|
+ const region = {};
|
|
|
+ parser.parseRegion(region);
|
|
|
+ return { kind: 'WorkgroupCountRegion', region };
|
|
|
+ }
|
|
|
+ return { kind: 'WorkgroupCountRegion', region: null };
|
|
|
+ }
|
|
|
+
|
|
|
_parseOptionalUnitAttr(parser /*, args */) {
|
|
|
if (parser.match('id') || parser.match('%') || parser.match('(')) {
|
|
|
return { kind: 'OptionalUnitAttr', present: true };
|
|
|
@@ -4481,12 +4626,32 @@ mlir.VectorDialect = class extends mlir.Dialect {
|
|
|
}
|
|
|
return true;
|
|
|
}
|
|
|
+ if (name === 'vector.mask') {
|
|
|
+ if (parser.match('%')) {
|
|
|
+ const mask = parser.expect('%');
|
|
|
+ op.operands.push({ value: mask, name: 'mask' });
|
|
|
+ }
|
|
|
+ if (parser.accept(',')) {
|
|
|
+ const passthru = parser.expect('%');
|
|
|
+ op.operands.push({ value: passthru, name: 'passthru' });
|
|
|
+ }
|
|
|
+ if (parser.match('{')) {
|
|
|
+ const region = {};
|
|
|
+ parser.parseRegion(region);
|
|
|
+ op.regions.push(region);
|
|
|
+ }
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ parser.parseArgumentTypes(op.operands);
|
|
|
+ if (parser.accept('->')) {
|
|
|
+ parser.parseArgumentTypes(op.results);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
if (name === 'vector.transfer_read' || name === 'vector.transfer_write') {
|
|
|
return this._parseTransferOp(parser, op);
|
|
|
}
|
|
|
- // Handle old vector.extract syntax (pre-2023) without 'from' keyword
|
|
|
- // Old: %r = vector.extract %v[0] : vector<4xf32>
|
|
|
- // New: %r = vector.extract %v[0] : f32 from vector<4xf32>
|
|
|
if (name === 'vector.extract' && !op.isGeneric) {
|
|
|
return this._parseExtractOp(parser, op);
|
|
|
}
|
|
|
@@ -8316,6 +8481,43 @@ mlir.SdfgDialect = class extends mlir.Dialect {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+mlir.TFLDialect = class extends mlir.Dialect {
|
|
|
+
|
|
|
+ constructor(operations) {
|
|
|
+ super('tfl', operations);
|
|
|
+ // Operations that use parseOneResultSameOperandTypeOp in tfl_ops.cc
|
|
|
+ // Format: operands attr-dict : single-type
|
|
|
+ this._binaryOps = new Set([
|
|
|
+ 'add', 'sub', 'mul', 'div', 'floor_div', 'pow', 'squared_difference',
|
|
|
+ 'less', 'less_equal', 'greater', 'greater_equal', 'not_equal',
|
|
|
+ 'logical_and', 'logical_or'
|
|
|
+ ]);
|
|
|
+ }
|
|
|
+
|
|
|
+ parseOperation(parser, opName, op) {
|
|
|
+ const name = opName.replace(/^"|"$/g, '');
|
|
|
+ const opKind = name.substring('tfl.'.length);
|
|
|
+ if (this._binaryOps.has(opKind)) {
|
|
|
+ // Parse: operands attr-dict : type
|
|
|
+ op.operands = parser.parseArguments();
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ const type = parser.parseType();
|
|
|
+ // All operands and result share the same type
|
|
|
+ for (const operand of op.operands) {
|
|
|
+ operand.type = type;
|
|
|
+ }
|
|
|
+ if (op.results.length > 0) {
|
|
|
+ op.results[0].type = type;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ return super.parseOperation(parser, opName, op);
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
mlir.TFDialect = class extends mlir.Dialect {
|
|
|
|
|
|
constructor(operations) {
|
|
|
@@ -8451,6 +8653,15 @@ mlir.TransformDialect = class extends mlir.Dialect {
|
|
|
if (name === 'transform.get_result') {
|
|
|
return this._parseGetResultOp(parser, op);
|
|
|
}
|
|
|
+ if (name === 'transform.func.cast_and_call') {
|
|
|
+ return this._parseFuncCastAndCallOp(parser, op);
|
|
|
+ }
|
|
|
+ if (name === 'transform.func.replace_func_signature') {
|
|
|
+ return this._parseReplaceFuncSignatureOp(parser, op);
|
|
|
+ }
|
|
|
+ if (name === 'transform.func.deduplicate_func_args') {
|
|
|
+ return this._parseDeduplicateFuncArgsOp(parser, op);
|
|
|
+ }
|
|
|
if (name.startsWith('transform.test.')) {
|
|
|
return this._parseTestOp(parser, name, op);
|
|
|
}
|
|
|
@@ -8525,6 +8736,106 @@ mlir.TransformDialect = class extends mlir.Dialect {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+ _parseFuncCastAndCallOp(parser, op) {
|
|
|
+ if (parser.match('@')) {
|
|
|
+ const funcName = parser.expect('@');
|
|
|
+ op.attributes.push({ name: 'function_name', value: funcName });
|
|
|
+ }
|
|
|
+ if (parser.match('%') && !parser.match('id', 'before') && !parser.match('id', 'after')) {
|
|
|
+ const funcHandle = parser.expect('%');
|
|
|
+ op.operands.push({ value: funcHandle, name: 'function' });
|
|
|
+ }
|
|
|
+ if (parser.accept('(')) {
|
|
|
+ while (!parser.match(')')) {
|
|
|
+ const value = parser.expect('%');
|
|
|
+ op.operands.push({ value, name: 'inputs' });
|
|
|
+ if (!parser.accept(',')) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ parser.expect(')');
|
|
|
+ }
|
|
|
+ if (parser.accept('->')) {
|
|
|
+ while (!parser.match('id') && parser.match('%')) {
|
|
|
+ const value = parser.expect('%');
|
|
|
+ op.operands.push({ value, name: 'outputs' });
|
|
|
+ if (!parser.accept(',')) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (parser.accept('id', 'after')) {
|
|
|
+ op.attributes.push({ name: 'insert_after', value: true });
|
|
|
+ } else {
|
|
|
+ parser.accept('id', 'before');
|
|
|
+ }
|
|
|
+ if (parser.match('%')) {
|
|
|
+ const insertionPoint = parser.expect('%');
|
|
|
+ op.operands.push({ value: insertionPoint, name: 'insertion_point' });
|
|
|
+ }
|
|
|
+ if (parser.match('{')) {
|
|
|
+ const region = {};
|
|
|
+ parser.parseRegion(region);
|
|
|
+ op.regions.push(region);
|
|
|
+ }
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ parser.parseArgumentTypes(op.operands);
|
|
|
+ if (parser.accept('->')) {
|
|
|
+ parser.parseArgumentTypes(op.results);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ _parseReplaceFuncSignatureOp(parser, op) {
|
|
|
+ if (parser.match('@')) {
|
|
|
+ const funcName = parser.expect('@');
|
|
|
+ op.attributes.push({ name: 'function_name', value: funcName });
|
|
|
+ }
|
|
|
+ if (parser.accept('id', 'args_interchange')) {
|
|
|
+ parser.expect('=');
|
|
|
+ const argsInterchange = parser.parseValue();
|
|
|
+ op.attributes.push({ name: 'args_interchange', value: argsInterchange.value });
|
|
|
+ }
|
|
|
+ if (parser.accept('id', 'results_interchange')) {
|
|
|
+ parser.expect('=');
|
|
|
+ const resultsInterchange = parser.parseValue();
|
|
|
+ op.attributes.push({ name: 'results_interchange', value: resultsInterchange.value });
|
|
|
+ }
|
|
|
+ if (parser.accept('id', 'at')) {
|
|
|
+ const module = parser.expect('%');
|
|
|
+ op.operands.push({ value: module, name: 'module' });
|
|
|
+ }
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ parser.parseArgumentTypes(op.operands);
|
|
|
+ if (parser.accept('->')) {
|
|
|
+ parser.parseArgumentTypes(op.results);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ _parseDeduplicateFuncArgsOp(parser, op) {
|
|
|
+ if (parser.match('@')) {
|
|
|
+ const funcName = parser.expect('@');
|
|
|
+ op.attributes.push({ name: 'function_name', value: funcName });
|
|
|
+ }
|
|
|
+ if (parser.accept('id', 'at')) {
|
|
|
+ const module = parser.expect('%');
|
|
|
+ op.operands.push({ value: module, name: 'module' });
|
|
|
+ }
|
|
|
+ parser.parseOptionalAttrDict(op.attributes);
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ parser.parseArgumentTypes(op.operands);
|
|
|
+ if (parser.accept('->')) {
|
|
|
+ parser.parseArgumentTypes(op.results);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
_parseGetResultOp(parser, op) {
|
|
|
// Parse: transform.get_result %op[0] : (!transform.any_op) -> !transform.any_value
|
|
|
if (parser.match('%')) {
|