|
|
@@ -96,10 +96,12 @@ mlir.Model = class {
|
|
|
if (isFunc(op.name)) {
|
|
|
funcs.push(op);
|
|
|
} else if (isModule(op.name)) {
|
|
|
- const modulePath = [...path, `$${identifier++}`];
|
|
|
+ let name = op.getAttr('sym_name');
|
|
|
+ name = name ? name.value : `$${identifier++}`;
|
|
|
+ const modulePath = [...path, name];
|
|
|
for (const region of op.regions || []) {
|
|
|
for (const blk of region.blocks || []) {
|
|
|
- collectModules(blk.operations || [], modulePath, op.attributes);
|
|
|
+ collectModules(blk.operations || [], modulePath, op.getAttrDictionary());
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
@@ -126,13 +128,13 @@ mlir.Model = class {
|
|
|
return '';
|
|
|
};
|
|
|
const functions = new Map();
|
|
|
- let funcIndex = 0;
|
|
|
+ let identifier = 0;
|
|
|
for (const module of modules) {
|
|
|
const prefix = formatPrefix(module.path, module.symName);
|
|
|
for (const func of module.funcs) {
|
|
|
- const sym_name = func.attributes.get('sym_name');
|
|
|
- const base = sym_name ? sym_name.value : `$${funcIndex}`;
|
|
|
- funcIndex++;
|
|
|
+ const sym_name = func.getAttr('sym_name');
|
|
|
+ const base = sym_name ? sym_name.value : `$${identifier}`;
|
|
|
+ identifier++;
|
|
|
const name = prefix ? `${prefix}::@${base}` : `@${base}`;
|
|
|
functions.set(name, { func, prefix, base, module });
|
|
|
}
|
|
|
@@ -145,11 +147,9 @@ mlir.Model = class {
|
|
|
for (const module of modules) {
|
|
|
if (module.ops.length > 0 || module.attributes.size > 0) {
|
|
|
const name = formatPrefix(module.path, module.symName) || '';
|
|
|
- const op = {
|
|
|
- name: 'builtin.module',
|
|
|
- attributes: module.attributes || [],
|
|
|
- regions: [{ blocks: [{ operations: module.ops, arguments: [] }] }]
|
|
|
- };
|
|
|
+ const op = new mlir.Operation('builtin.module');
|
|
|
+ op.attributes = module.attributes;
|
|
|
+ op.regions = [{ blocks: [{ operations: module.ops, arguments: [] }] }];
|
|
|
const graph = context.graph(op, name);
|
|
|
this.modules.push(graph);
|
|
|
}
|
|
|
@@ -227,7 +227,7 @@ mlir.Graph = class {
|
|
|
const operation = {
|
|
|
type: op.kind || op.name,
|
|
|
identifier: op.name,
|
|
|
- attributes: op.attributes,
|
|
|
+ attributes: op.getAttrDictionary(),
|
|
|
operands: [],
|
|
|
results: [],
|
|
|
regions: op.regions || [],
|
|
|
@@ -309,7 +309,7 @@ mlir.Graph = class {
|
|
|
op.results[0].value.length === 1) {
|
|
|
const [result] = op.results[0].value;
|
|
|
if (result.to && result.to.length === 1) {
|
|
|
- const valueAttr = op.attributes.get('value');
|
|
|
+ const valueAttr = op.attributes.get('value') || op.attributes.get('values');
|
|
|
if ((valueAttr instanceof mlir.DenseElementsAttr || valueAttr instanceof mlir.DenseResourceElementsAttr) &&
|
|
|
valueAttr.value !== null &&
|
|
|
valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
|
|
|
@@ -345,7 +345,8 @@ mlir.Graph = class {
|
|
|
const [returnValue] = operand.value;
|
|
|
if (returnValue && typeof returnValue.name === 'string' && returnValue.name.startsWith('%')) {
|
|
|
const output = this.outputs[i];
|
|
|
- output.value[0] = new mlir.Value(returnValue.name, returnValue.type || output.value[0].type, '', null);
|
|
|
+ const returnType = returnValue.type ? mlir.Utility.valueType(returnValue.type) : output.value[0].type;
|
|
|
+ output.value[0] = new mlir.Value(returnValue.name, returnType, '', null);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -611,6 +612,7 @@ mlir.Operation = class {
|
|
|
this.operands = [];
|
|
|
this.regions = [];
|
|
|
this.results = [];
|
|
|
+ this.propertiesAttr = null;
|
|
|
}
|
|
|
|
|
|
addAttribute(name, value) {
|
|
|
@@ -622,6 +624,27 @@ mlir.Operation = class {
|
|
|
this.regions.push(region);
|
|
|
return region;
|
|
|
}
|
|
|
+
|
|
|
+ getAttr(name) {
|
|
|
+ if (this.propertiesAttr instanceof mlir.DictionaryAttr) {
|
|
|
+ const value = this.propertiesAttr.get(name);
|
|
|
+ if (value !== undefined) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return this.attributes.get(name);
|
|
|
+ }
|
|
|
+
|
|
|
+ getAttrDictionary() {
|
|
|
+ if (this.propertiesAttr instanceof mlir.DictionaryAttr) {
|
|
|
+ const result = new Map(this.attributes);
|
|
|
+ for (const [name, value] of this.propertiesAttr.value) {
|
|
|
+ result.set(name, value);
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+ return this.attributes;
|
|
|
+ }
|
|
|
};
|
|
|
|
|
|
mlir.Token = class {
|
|
|
@@ -1379,7 +1402,7 @@ mlir.Parser = class {
|
|
|
this.parseSuccessors(op.successors);
|
|
|
}
|
|
|
if (this.accept('<')) {
|
|
|
- op.properties = this.parseAttribute();
|
|
|
+ op.propertiesAttr = this.parseAttribute();
|
|
|
this.expect('>');
|
|
|
}
|
|
|
if (this.accept('(')) {
|
|
|
@@ -2650,11 +2673,7 @@ mlir.Parser = class {
|
|
|
if (this.match('{')) {
|
|
|
const attributes = new Map();
|
|
|
this.parseAttributeDict(attributes);
|
|
|
- const dict = {};
|
|
|
- for (const [name, value] of attributes) {
|
|
|
- dict[name] = value;
|
|
|
- }
|
|
|
- return { value: dict };
|
|
|
+ return new mlir.DictionaryAttr(attributes);
|
|
|
}
|
|
|
if (this.match('#')) {
|
|
|
const attr = this.parseExtendedAttr();
|
|
|
@@ -4505,6 +4524,28 @@ mlir.ArrayAttr = class extends mlir.Attribute {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+mlir.DictionaryAttr = class extends mlir.Attribute {
|
|
|
+
|
|
|
+ constructor(value) {
|
|
|
+ super();
|
|
|
+ this._value = value; // Map of name -> Attribute
|
|
|
+ }
|
|
|
+
|
|
|
+ get value() {
|
|
|
+ return this._value;
|
|
|
+ }
|
|
|
+
|
|
|
+ get(name) {
|
|
|
+ return this._value.get(name);
|
|
|
+ }
|
|
|
+
|
|
|
+ toString() {
|
|
|
+ const entries = Array.from(this._value.entries())
|
|
|
+ .map(([k, v]) => `${k} = ${v && v.toString ? v.toString() : String(v)}`);
|
|
|
+ return `{${entries.join(', ')}}`;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
mlir.DenseArrayAttr = class extends mlir.Attribute {
|
|
|
|
|
|
constructor(elements, type) {
|
|
|
@@ -5557,12 +5598,12 @@ mlir.Dialect = class {
|
|
|
this.registerCustomAttribute('LevelAttr', this._parseIntegerAttr.bind(this, 'index'));
|
|
|
this.registerCustomType('Optional', this._parseOptional.bind(this));
|
|
|
for (const metadata of operations.get(name) || []) {
|
|
|
- const op = { metadata };
|
|
|
+ const opInfo = { metadata };
|
|
|
if (metadata.assemblyFormat) {
|
|
|
const parser = new mlir.AssemblyFormatParser(metadata);
|
|
|
- op.directives = parser.parse();
|
|
|
+ opInfo.directives = parser.parse();
|
|
|
}
|
|
|
- this._operations.set(metadata.name, op);
|
|
|
+ this._operations.set(metadata.name, opInfo);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -7482,7 +7523,7 @@ mlir.StableHLODialect = class extends mlir.HLODialect {
|
|
|
if (opName === 'stablehlo.constant') {
|
|
|
if (parser.accept('(') && parser.accept(')')) {
|
|
|
if (parser.accept('<')) {
|
|
|
- op.properties = parser.parseAttribute();
|
|
|
+ op.propertiesAttr = parser.parseAttribute();
|
|
|
parser.expect('>');
|
|
|
}
|
|
|
parser.parseOptionalAttrDict(op.attributes);
|
|
|
@@ -11901,26 +11942,16 @@ mlir.MhloDialect = class extends mlir.HLODialect {
|
|
|
}
|
|
|
block.arguments.push({ value: '%lhs', type: elementType ? `tensor<${elementType}>` : null });
|
|
|
block.arguments.push({ value: '%rhs', type: elementType ? `tensor<${elementType}>` : null });
|
|
|
- const innerOp = {
|
|
|
- name: innerOpName,
|
|
|
- operands: [{ value: '%lhs' }, { value: '%rhs' }],
|
|
|
- results: [{ value: '%0', type: elementType ? `tensor<${elementType}>` : null }],
|
|
|
- attributes: [],
|
|
|
- regions: []
|
|
|
- };
|
|
|
+ const innerOp = new mlir.Operation(innerOpName);
|
|
|
+ innerOp.operands.push({ value: '%lhs' });
|
|
|
+ innerOp.operands.push({ value: '%rhs' });
|
|
|
+ innerOp.results.push({ value: '%0', type: elementType ? `tensor<${elementType}>` : null });
|
|
|
block.operations.push(innerOp);
|
|
|
- const returnOp = {
|
|
|
- name: 'mhlo.return',
|
|
|
- operands: [{ value: '%0' }],
|
|
|
- results: [],
|
|
|
- attributes: [],
|
|
|
- regions: []
|
|
|
- };
|
|
|
+ const returnOp = new mlir.Operation('mhlo.return');
|
|
|
+ returnOp.operands.push({ value: '%0' });
|
|
|
block.operations.push(returnOp);
|
|
|
-
|
|
|
region.blocks.push(block);
|
|
|
op.regions.push(region);
|
|
|
-
|
|
|
return true;
|
|
|
}
|
|
|
|