|
|
@@ -717,6 +717,23 @@ const schema = async () => {
|
|
|
}
|
|
|
return;
|
|
|
}
|
|
|
+ // Handle !listconcat expressions (stored as bang type)
|
|
|
+ if (traitsArg.type === 'bang' && traitsArg.value) {
|
|
|
+ const bangOp = traitsArg.value.operator || traitsArg.value.op;
|
|
|
+ if (bangOp === 'listconcat' && traitsArg.value.args) {
|
|
|
+ for (const arg of traitsArg.value.args) {
|
|
|
+ extractTraitsFromList(arg);
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Handle arrays directly (evaluated listconcat results)
|
|
|
+ if (Array.isArray(traitsArg)) {
|
|
|
+ for (const element of traitsArg) {
|
|
|
+ extractTraitsFromList({ type: 'list', value: [element] });
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
if (traitsArg.type === 'list' && traitsArg.value) {
|
|
|
for (const trait of traitsArg.value) {
|
|
|
const traitName = trait.type === 'def' ? trait.value : null;
|
|
|
@@ -727,7 +744,10 @@ const schema = async () => {
|
|
|
if (namesOperand && namesOperand.value && namesOperand.value.type === 'list') {
|
|
|
const names = namesOperand.value.value.filter((v) => v.type === 'string').map((v) => v.value);
|
|
|
if (names.length > 0) {
|
|
|
- traits.push({ type: `AllTypesMatch<[${names.map((n) => `'${n}'`).join(', ')}]>` });
|
|
|
+ const traitType = `AllTypesMatch<[${names.map((n) => `'${n}'`).join(', ')}]>`;
|
|
|
+ if (traits.every((t) => t.type !== traitType)) {
|
|
|
+ traits.push({ type: traitType });
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -740,7 +760,10 @@ const schema = async () => {
|
|
|
const to = getStringValue(operands[2]);
|
|
|
const transformer = getStringValue(operands[3]);
|
|
|
if (from && to && transformer) {
|
|
|
- traits.push({ type: `TypesMatchWith<'${from}', '${to}', '${transformer}'>` });
|
|
|
+ const traitType = `TypesMatchWith<'${from}', '${to}', '${transformer}'>`;
|
|
|
+ if (traits.every((t) => t.type !== traitType)) {
|
|
|
+ traits.push({ type: traitType });
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -749,10 +772,29 @@ const schema = async () => {
|
|
|
traits.every((t) => t.type !== 'AttrSizedOperandSegments')) {
|
|
|
traits.push({ type: 'AttrSizedOperandSegments' });
|
|
|
}
|
|
|
+ // Extract SameOperandsAndResultType trait (for type inference)
|
|
|
+ if ((traitName === 'SameOperandsAndResultType' || traitDag === 'SameOperandsAndResultType') &&
|
|
|
+ traits.every((t) => t.type !== 'SameOperandsAndResultType')) {
|
|
|
+ traits.push({ type: 'SameOperandsAndResultType' });
|
|
|
+ }
|
|
|
// Extract IsolatedFromAbove trait
|
|
|
if (traitName === 'IsolatedFromAbove' && traits.every((trait) => trait.type !== 'IsolatedFromAbove')) {
|
|
|
traits.push({ type: 'IsolatedFromAbove' });
|
|
|
}
|
|
|
+ // Extract InferTypeOpInterface trait (for type inference)
|
|
|
+ if (traitName === 'InferTypeOpInterface' && traits.every((trait) => trait.type !== 'InferTypeOpInterface')) {
|
|
|
+ traits.push({ type: 'InferTypeOpInterface' });
|
|
|
+ }
|
|
|
+ // Check for DeclareOpInterfaceMethods<InferTypeOpInterface>
|
|
|
+ if (traitDag === 'DeclareOpInterfaceMethods' && trait.value && trait.value.operands) {
|
|
|
+ const interfaceOperand = trait.value.operands[0];
|
|
|
+ if (interfaceOperand && interfaceOperand.value && interfaceOperand.value.type === 'def') {
|
|
|
+ const interfaceName = interfaceOperand.value.value;
|
|
|
+ if (interfaceName === 'InferTypeOpInterface' && traits.every((t) => t.type !== 'InferTypeOpInterface')) {
|
|
|
+ traits.push({ type: 'InferTypeOpInterface' });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
// Extract defaultDialect from OpAsmOpInterface
|
|
|
if (traitName === 'OpAsmOpInterface' || traitDag === 'DeclareOpInterfaceMethods') {
|
|
|
if (traitDag === 'DeclareOpInterfaceMethods' && trait.value && trait.value.operands) {
|
|
|
@@ -789,6 +831,21 @@ const schema = async () => {
|
|
|
// Recursively look at parent class definition
|
|
|
const parentClass = parser.getClass(parent.name);
|
|
|
if (parentClass && parentClass.parents) {
|
|
|
+ // Also extract traits from the parent class's own parent args
|
|
|
+ // This handles cases like Linalg_RelayoutOp which defines TypesMatchWith
|
|
|
+ // in its own inheritance from Op, not in args passed by children
|
|
|
+ for (const classParent of parentClass.parents) {
|
|
|
+ // Look for Op parent which typically has traits in args[2]
|
|
|
+ if (classParent.name === 'Op' && classParent.args && classParent.args.length >= 3) {
|
|
|
+ extractTraitsFromList(classParent.args[2]);
|
|
|
+ }
|
|
|
+ // Also check args that might contain traits (e.g., listconcat results)
|
|
|
+ if (classParent.args) {
|
|
|
+ for (const arg of classParent.args) {
|
|
|
+ extractTraitsFromList(arg);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
extractTraitsFromParents(parentClass.parents, visited);
|
|
|
}
|
|
|
}
|