|
|
@@ -140,7 +140,7 @@ mlir.Model = class {
|
|
|
functions.set(name, { func, prefix, base, module });
|
|
|
}
|
|
|
}
|
|
|
- const context = new _.Context(metadata, functions);
|
|
|
+ const context = new mlir.Context(metadata, functions);
|
|
|
for (const [name, info] of functions) {
|
|
|
const graph = context.graph(info.func, name);
|
|
|
this.functions.push(graph);
|
|
|
@@ -194,7 +194,7 @@ mlir.Graph = class {
|
|
|
const input = inputs[i];
|
|
|
// args[i] is an _.Value with .name set by parseRegion
|
|
|
const name = args[i] && args[i].name ? args[i].name : `%arg${i}`;
|
|
|
- const type = _.Utility.valueType(input.type || input);
|
|
|
+ const type = mlir.Utility.valueType(input.type || input);
|
|
|
const value = new mlir.Value(name, type, '', null);
|
|
|
const argument = new mlir.Argument(name, [value]);
|
|
|
this.inputs.push(argument);
|
|
|
@@ -202,7 +202,7 @@ mlir.Graph = class {
|
|
|
for (let i = 0; i < results.length; i++) {
|
|
|
const output = results[i];
|
|
|
const name = output.value || i.toString();
|
|
|
- const type = _.Utility.valueType(output.type);
|
|
|
+ const type = mlir.Utility.valueType(output.type);
|
|
|
const valueName = output.value || output.name || `%result${i}`;
|
|
|
const value = new mlir.Value(valueName, type, '', null);
|
|
|
const argument = new mlir.Argument(name, [value]);
|
|
|
@@ -248,9 +248,10 @@ mlir.Graph = class {
|
|
|
const input = op.operands[i];
|
|
|
// Determine operand name: use metadata if available, or variadic name if past metadata bounds
|
|
|
let inputName = null;
|
|
|
+ const isVariadicOverflow = lastVariadicIndex >= 0 && i >= lastVariadicIndex;
|
|
|
if (opMetadata && opMetadata.operands && opMetadata.operands[i]) {
|
|
|
inputName = opMetadata.operands[i].name;
|
|
|
- } else if (lastVariadicIndex >= 0 && i >= lastVariadicIndex) {
|
|
|
+ } else if (isVariadicOverflow) {
|
|
|
// Operand index exceeds metadata, use last variadic operand name
|
|
|
inputName = lastVariadicName;
|
|
|
} else {
|
|
|
@@ -261,8 +262,13 @@ mlir.Graph = class {
|
|
|
}
|
|
|
const value = values.map(input.name);
|
|
|
value.to.push(operation);
|
|
|
- const args = [{ name: input.name, type: input.type }];
|
|
|
- operation.operands.push({ name: inputName, value: args });
|
|
|
+ const arg = { name: input.name, type: input.type };
|
|
|
+ // Group variadic operands into single argument with multiple values
|
|
|
+ if (isVariadicOverflow && operation.operands.length > 0 && operation.operands[operation.operands.length - 1].name === inputName) {
|
|
|
+ operation.operands[operation.operands.length - 1].value.push(arg);
|
|
|
+ } else {
|
|
|
+ operation.operands.push({ name: inputName, value: [arg] });
|
|
|
+ }
|
|
|
}
|
|
|
const results = op.results;
|
|
|
// Find the last variadic result in metadata (if any) for grouping
|
|
|
@@ -285,7 +291,7 @@ mlir.Graph = class {
|
|
|
continue;
|
|
|
}
|
|
|
const value = values.map(output.name);
|
|
|
- value.type = _.Utility.valueType(output.type);
|
|
|
+ value.type = mlir.Utility.valueType(output.type);
|
|
|
value.from.push(operation);
|
|
|
// Determine result name: use metadata if available, or variadic name if past metadata bounds
|
|
|
let outputName = null;
|
|
|
@@ -320,6 +326,7 @@ mlir.Graph = class {
|
|
|
for (const op of operations) {
|
|
|
if (constantTypes.has(op.type) &&
|
|
|
op.operands.length === 0 &&
|
|
|
+ op.attributes.size === 1 &&
|
|
|
op.results.length === 1 &&
|
|
|
op.results[0].value.length === 1) {
|
|
|
const [result] = op.results[0].value;
|
|
|
@@ -328,7 +335,7 @@ mlir.Graph = class {
|
|
|
if ((valueAttr instanceof _.DenseElementsAttr || valueAttr instanceof _.DenseResourceElementsAttr) &&
|
|
|
valueAttr.value !== null &&
|
|
|
valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
|
|
|
- const type = _.Utility.valueType(valueAttr.type);
|
|
|
+ const type = mlir.Utility.valueType(valueAttr.type);
|
|
|
if (type instanceof mlir.TensorType) {
|
|
|
constantMap.set(result.name, new mlir.Tensor(type, valueAttr.value));
|
|
|
op.delete = true;
|
|
|
@@ -340,7 +347,13 @@ mlir.Graph = class {
|
|
|
const tensor = (arg) => {
|
|
|
if (!tensors.has(arg.name)) {
|
|
|
const initializer = constantMap.get(arg.name) || null;
|
|
|
- tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, initializer));
|
|
|
+ let type = null;
|
|
|
+ if (arg.type instanceof mlir.TensorType) {
|
|
|
+ type = arg.type;
|
|
|
+ } else if (arg.type) {
|
|
|
+ type = mlir.Utility.valueType(arg.type);
|
|
|
+ }
|
|
|
+ tensors.set(arg.name, new mlir.Value(arg.name, type, null, initializer));
|
|
|
}
|
|
|
return tensors.get(arg.name);
|
|
|
};
|
|
|
@@ -360,7 +373,7 @@ mlir.Graph = class {
|
|
|
const [returnValue] = operand.value;
|
|
|
if (returnValue && typeof returnValue.name === 'string' && returnValue.name.startsWith('%')) {
|
|
|
const output = this.outputs[i];
|
|
|
- const returnType = _.Utility.valueType(returnValue.type);
|
|
|
+ const returnType = mlir.Utility.valueType(returnValue.type);
|
|
|
output.value[0] = new mlir.Value(returnValue.name, returnType, '', null);
|
|
|
}
|
|
|
}
|
|
|
@@ -506,7 +519,7 @@ mlir.Node = class {
|
|
|
if (input.type) {
|
|
|
const typeStr = input.type instanceof _.Type ? input.type.toString() : input.type;
|
|
|
if (typeStr.startsWith('tensor<')) {
|
|
|
- const type = _.Utility.valueType(typeStr);
|
|
|
+ const type = mlir.Utility.valueType(typeStr);
|
|
|
const value = new mlir.Tensor(type, input.value);
|
|
|
argument = new mlir.Argument(input.name, value, 'tensor');
|
|
|
} else {
|
|
|
@@ -548,10 +561,10 @@ mlir.Node = class {
|
|
|
type = 'function';
|
|
|
}
|
|
|
} else if (attr instanceof _.DenseElementsAttr && attr.value !== null) {
|
|
|
- value = new mlir.Tensor(_.Utility.valueType(attr.type), attr.value);
|
|
|
+ value = new mlir.Tensor(mlir.Utility.valueType(attr.type), attr.value);
|
|
|
type = 'tensor';
|
|
|
} else if (attr instanceof _.DenseResourceElementsAttr) {
|
|
|
- value = new mlir.Tensor(_.Utility.valueType(attr.type), null);
|
|
|
+ value = new mlir.Tensor(mlir.Utility.valueType(attr.type), null);
|
|
|
type = 'tensor';
|
|
|
} else if (attr instanceof _.ArrayAttr || attr instanceof _.DenseArrayAttr) {
|
|
|
value = attr.value;
|
|
|
@@ -568,8 +581,9 @@ mlir.Node = class {
|
|
|
const region = op.regions[i];
|
|
|
if (region.blocks && region.blocks.length > 0) {
|
|
|
const name = (opMetadata.regions && opMetadata.regions[i] ? opMetadata.regions[i].name : null) || i.toString();
|
|
|
+ const blockName = region.blocks[0].name || '';
|
|
|
const func = { name: '', attributes: new Map(), regions: [region] };
|
|
|
- const graph = new mlir.Graph(metadata, func, context, '');
|
|
|
+ const graph = new mlir.Graph(metadata, func, context, blockName);
|
|
|
const argument = new mlir.Argument(name, graph, 'graph');
|
|
|
this.blocks.push(argument);
|
|
|
}
|
|
|
@@ -590,7 +604,7 @@ mlir.Tensor = class {
|
|
|
mlir.TensorType = class {
|
|
|
|
|
|
constructor(dataType, shape) {
|
|
|
- this.dataType = _.Utility.dataType(dataType); // string
|
|
|
+ this.dataType = mlir.Utility.dataType(dataType); // string
|
|
|
this.shape = shape || new mlir.TensorShape([]); // mlir.TensorShape
|
|
|
}
|
|
|
|
|
|
@@ -613,7 +627,7 @@ mlir.TensorShape = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-_.Context = class {
|
|
|
+mlir.Context = class {
|
|
|
|
|
|
constructor(metadata, functions) {
|
|
|
this._metadata = metadata;
|
|
|
@@ -662,6 +676,179 @@ _.Context = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+mlir.Utility = class {
|
|
|
+
|
|
|
+ static dataType(value) {
|
|
|
+ if (value instanceof _.ComplexType) {
|
|
|
+ const elementType = mlir.Utility.dataType(value.elementType);
|
|
|
+ return `complex<${elementType}>`;
|
|
|
+ }
|
|
|
+ if (value instanceof _.Type) {
|
|
|
+ value = value.toString();
|
|
|
+ }
|
|
|
+ switch (value) {
|
|
|
+ case 'index': return 'int64';
|
|
|
+ case 'f16': return 'float16';
|
|
|
+ case 'f32': return 'float32';
|
|
|
+ case 'f64': return 'float64';
|
|
|
+ case 'f80': return 'float80';
|
|
|
+ case 'f128': return 'float128';
|
|
|
+ case 'bf16': return 'bfloat16';
|
|
|
+ case 'fp8': return 'float8';
|
|
|
+ case 'fp8e4m3': return 'float8e4m3';
|
|
|
+ case 'fp8_e4m3': return 'float8e4m3';
|
|
|
+ case 'fp8e4m3fn': return 'float8e4m3fn';
|
|
|
+ case 'fp8e5m2': return 'float8e5m2';
|
|
|
+ case 'fp8_e5m2': return 'float8e5m2';
|
|
|
+ case 'f4E2M1FN': return 'float4e2m1fn';
|
|
|
+ case 'f6E2M3FN': return 'float6e2m3fn';
|
|
|
+ case 'f6E3M2FN': return 'float6e3m2fn';
|
|
|
+ case 'f8E3M4': return 'float8e3m4';
|
|
|
+ case 'f8E4M3': return 'float8e4m3';
|
|
|
+ case 'f8E4M3B11FNUZ': return 'float8e4m3b11fnuz';
|
|
|
+ case 'f8E4M3FN': return 'float8e4m3fn';
|
|
|
+ case 'f8E4M3FNUZ': return 'float8e4m3fnuz';
|
|
|
+ case 'f8E5M2': return 'float8e5m2';
|
|
|
+ case 'f8E5M2FNUZ': return 'float8e5m2fnuz';
|
|
|
+ case 'f8E8M0FNU': return 'float8e8m0fnu';
|
|
|
+ case 'float8': return 'float8';
|
|
|
+ case 'tf32': return 'tensorfloat32';
|
|
|
+ case 'i1': return 'int1';
|
|
|
+ case 'i2': return 'int2';
|
|
|
+ case 'i4': return 'int4';
|
|
|
+ case 'i8': return 'int8';
|
|
|
+ case 'i16': return 'int16';
|
|
|
+ case 'i32': return 'int32';
|
|
|
+ case 'i48': return 'int48';
|
|
|
+ case 'i64': return 'int64';
|
|
|
+ case 'si8': return 'int8';
|
|
|
+ case 'si16': return 'int16';
|
|
|
+ case 'si32': return 'int32';
|
|
|
+ case 'si64': return 'int64';
|
|
|
+ case 'ui1': return 'uint1';
|
|
|
+ case 'ui2': return 'uint2';
|
|
|
+ case 'ui4': return 'uint4';
|
|
|
+ case 'ui8': return 'uint8';
|
|
|
+ case 'ui16': return 'uint16';
|
|
|
+ case 'ui32': return 'uint32';
|
|
|
+ case 'ui64': return 'uint64';
|
|
|
+ case 'b8': return 'int8';
|
|
|
+ case 'unk': return 'unk'; // torch dialect unknown dtype
|
|
|
+ case '!tf_type.string': return 'string';
|
|
|
+ default:
|
|
|
+ if (value && value.startsWith('!')) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ if (value && value.startsWith('vector<') && value.endsWith('>')) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ if (value && value.startsWith('memref<') && value.endsWith('>')) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ if (value && value.startsWith('tuple<') && value.endsWith('>')) {
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+ if (value && value.startsWith('complex<') && value.endsWith('>')) {
|
|
|
+ const elementTypeStr = value.substring(8, value.length - 1);
|
|
|
+ const convertedElementType = mlir.Utility.dataType(elementTypeStr);
|
|
|
+ return `complex<${convertedElementType}>`;
|
|
|
+ }
|
|
|
+ if (value && /^[su]?i[0-9]+$/.test(value)) {
|
|
|
+ const match = value.match(/^(s|u)?i([0-9]+)$/);
|
|
|
+ if (match) {
|
|
|
+ const [, signed, widthStr] = match;
|
|
|
+ const width = parseInt(widthStr, 10);
|
|
|
+ if (signed === 'u') {
|
|
|
+ return `uint${width}`;
|
|
|
+ } else if (signed === 's') {
|
|
|
+ return `int${width}`;
|
|
|
+ }
|
|
|
+ return `int${width}`;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ throw new mlir.Error(`Unknown data type '${value}'.`);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ static valueType(type) {
|
|
|
+ if (type === undefined) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ const typeStr = type instanceof _.Type ? type.toString() : type;
|
|
|
+ if (typeStr.startsWith('!') && !typeStr.startsWith('!torch.vtensor<')) {
|
|
|
+ return typeStr;
|
|
|
+ }
|
|
|
+ if (typeStr.startsWith('tensor<') && typeStr.endsWith('>')) {
|
|
|
+ const spec = typeStr.substring(7, typeStr.length - 1).trim();
|
|
|
+ if (spec.startsWith('!')) {
|
|
|
+ return mlir.Utility.valueType(spec);
|
|
|
+ }
|
|
|
+ let i = 0;
|
|
|
+ const shape = [];
|
|
|
+ while (i < spec.length) {
|
|
|
+ if (spec[i] === '?' || spec[i] === '*') {
|
|
|
+ shape.push('?');
|
|
|
+ i++;
|
|
|
+ } else if (/[0-9]/.test(spec[i])) {
|
|
|
+ let numStr = '';
|
|
|
+ while (i < spec.length && /[0-9]/.test(spec[i])) {
|
|
|
+ numStr += spec[i];
|
|
|
+ i++;
|
|
|
+ }
|
|
|
+ const dim = parseInt(numStr, 10);
|
|
|
+ if (isNaN(dim)) {
|
|
|
+ shape.push('?');
|
|
|
+ } else {
|
|
|
+ shape.push(dim);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (i < spec.length && spec[i] === 'x') {
|
|
|
+ i++;
|
|
|
+ } else {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ let dataType = spec.substring(i);
|
|
|
+ const encodingIndex = dataType.indexOf(',');
|
|
|
+ if (encodingIndex !== -1) {
|
|
|
+ dataType = dataType.substring(0, encodingIndex).trim();
|
|
|
+ }
|
|
|
+ return new mlir.TensorType(dataType, new mlir.TensorShape(shape));
|
|
|
+ }
|
|
|
+ if (typeStr.startsWith('!torch.vtensor<') && typeStr.endsWith('>')) {
|
|
|
+ const spec = typeStr.substring(15, typeStr.length - 1);
|
|
|
+ let shape = null;
|
|
|
+ let dataType = null;
|
|
|
+ if (spec.startsWith('[')) {
|
|
|
+ const bracketEnd = spec.indexOf(']');
|
|
|
+ const shapeStr = spec.substring(0, bracketEnd + 1);
|
|
|
+ const jsonStr = shapeStr.replace(/\?/g, '"?"');
|
|
|
+ shape = JSON.parse(jsonStr);
|
|
|
+ const rest = spec.substring(bracketEnd + 1);
|
|
|
+ if (rest.startsWith(',')) {
|
|
|
+ const parts = rest.substring(1).split(',');
|
|
|
+ dataType = parts[0].trim();
|
|
|
+ }
|
|
|
+ } else if (spec.startsWith('*')) {
|
|
|
+ if (spec.includes(',')) {
|
|
|
+ const parts = spec.split(',');
|
|
|
+ dataType = parts[1].trim();
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ const parts = spec.split(',');
|
|
|
+ dataType = parts[0].trim();
|
|
|
+ }
|
|
|
+ return new mlir.TensorType(dataType, shape ? new mlir.TensorShape(shape) : null);
|
|
|
+ }
|
|
|
+ if (typeStr.startsWith('tuple<') && typeStr.endsWith('>')) {
|
|
|
+ return typeStr;
|
|
|
+ }
|
|
|
+ return typeStr;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
_.OperationState = class {
|
|
|
|
|
|
constructor(name) {
|
|
|
@@ -5003,179 +5190,6 @@ _.BufferReader = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-_.Utility = class {
|
|
|
-
|
|
|
- static dataType(value) {
|
|
|
- if (value instanceof _.ComplexType) {
|
|
|
- const elementType = _.Utility.dataType(value.elementType);
|
|
|
- return `complex<${elementType}>`;
|
|
|
- }
|
|
|
- if (value instanceof _.Type) {
|
|
|
- value = value.toString();
|
|
|
- }
|
|
|
- switch (value) {
|
|
|
- case 'index': return 'int64';
|
|
|
- case 'f16': return 'float16';
|
|
|
- case 'f32': return 'float32';
|
|
|
- case 'f64': return 'float64';
|
|
|
- case 'f80': return 'float80';
|
|
|
- case 'f128': return 'float128';
|
|
|
- case 'bf16': return 'bfloat16';
|
|
|
- case 'fp8': return 'float8';
|
|
|
- case 'fp8e4m3': return 'float8e4m3';
|
|
|
- case 'fp8_e4m3': return 'float8e4m3';
|
|
|
- case 'fp8e4m3fn': return 'float8e4m3fn';
|
|
|
- case 'fp8e5m2': return 'float8e5m2';
|
|
|
- case 'fp8_e5m2': return 'float8e5m2';
|
|
|
- case 'f4E2M1FN': return 'float4e2m1fn';
|
|
|
- case 'f6E2M3FN': return 'float6e2m3fn';
|
|
|
- case 'f6E3M2FN': return 'float6e3m2fn';
|
|
|
- case 'f8E3M4': return 'float8e3m4';
|
|
|
- case 'f8E4M3': return 'float8e4m3';
|
|
|
- case 'f8E4M3B11FNUZ': return 'float8e4m3b11fnuz';
|
|
|
- case 'f8E4M3FN': return 'float8e4m3fn';
|
|
|
- case 'f8E4M3FNUZ': return 'float8e4m3fnuz';
|
|
|
- case 'f8E5M2': return 'float8e5m2';
|
|
|
- case 'f8E5M2FNUZ': return 'float8e5m2fnuz';
|
|
|
- case 'f8E8M0FNU': return 'float8e8m0fnu';
|
|
|
- case 'float8': return 'float8';
|
|
|
- case 'tf32': return 'tensorfloat32';
|
|
|
- case 'i1': return 'int1';
|
|
|
- case 'i2': return 'int2';
|
|
|
- case 'i4': return 'int4';
|
|
|
- case 'i8': return 'int8';
|
|
|
- case 'i16': return 'int16';
|
|
|
- case 'i32': return 'int32';
|
|
|
- case 'i48': return 'int48';
|
|
|
- case 'i64': return 'int64';
|
|
|
- case 'si8': return 'int8';
|
|
|
- case 'si16': return 'int16';
|
|
|
- case 'si32': return 'int32';
|
|
|
- case 'si64': return 'int64';
|
|
|
- case 'ui1': return 'uint1';
|
|
|
- case 'ui2': return 'uint2';
|
|
|
- case 'ui4': return 'uint4';
|
|
|
- case 'ui8': return 'uint8';
|
|
|
- case 'ui16': return 'uint16';
|
|
|
- case 'ui32': return 'uint32';
|
|
|
- case 'ui64': return 'uint64';
|
|
|
- case 'b8': return 'int8';
|
|
|
- case 'unk': return 'unk'; // torch dialect unknown dtype
|
|
|
- case '!tf_type.string': return 'string';
|
|
|
- default:
|
|
|
- if (value && value.startsWith('!')) {
|
|
|
- return value;
|
|
|
- }
|
|
|
- if (value && value.startsWith('vector<') && value.endsWith('>')) {
|
|
|
- return value;
|
|
|
- }
|
|
|
- if (value && value.startsWith('memref<') && value.endsWith('>')) {
|
|
|
- return value;
|
|
|
- }
|
|
|
- if (value && value.startsWith('tuple<') && value.endsWith('>')) {
|
|
|
- return value;
|
|
|
- }
|
|
|
- if (value && value.startsWith('complex<') && value.endsWith('>')) {
|
|
|
- const elementTypeStr = value.substring(8, value.length - 1);
|
|
|
- const convertedElementType = _.Utility.dataType(elementTypeStr);
|
|
|
- return `complex<${convertedElementType}>`;
|
|
|
- }
|
|
|
- if (value && /^[su]?i[0-9]+$/.test(value)) {
|
|
|
- const match = value.match(/^(s|u)?i([0-9]+)$/);
|
|
|
- if (match) {
|
|
|
- const [, signed, widthStr] = match;
|
|
|
- const width = parseInt(widthStr, 10);
|
|
|
- if (signed === 'u') {
|
|
|
- return `uint${width}`;
|
|
|
- } else if (signed === 's') {
|
|
|
- return `int${width}`;
|
|
|
- }
|
|
|
- return `int${width}`;
|
|
|
- }
|
|
|
- }
|
|
|
- throw new mlir.Error(`Unknown data type '${value}'.`);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- static valueType(type) {
|
|
|
- if (type === undefined) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- const typeStr = type instanceof _.Type ? type.toString() : type;
|
|
|
- if (typeStr.startsWith('!') && !typeStr.startsWith('!torch.vtensor<')) {
|
|
|
- return typeStr;
|
|
|
- }
|
|
|
- if (typeStr.startsWith('tensor<') && typeStr.endsWith('>')) {
|
|
|
- const spec = typeStr.substring(7, typeStr.length - 1).trim();
|
|
|
- if (spec.startsWith('!')) {
|
|
|
- return _.Utility.valueType(spec);
|
|
|
- }
|
|
|
- let i = 0;
|
|
|
- const shape = [];
|
|
|
- while (i < spec.length) {
|
|
|
- if (spec[i] === '?' || spec[i] === '*') {
|
|
|
- shape.push('?');
|
|
|
- i++;
|
|
|
- } else if (/[0-9]/.test(spec[i])) {
|
|
|
- let numStr = '';
|
|
|
- while (i < spec.length && /[0-9]/.test(spec[i])) {
|
|
|
- numStr += spec[i];
|
|
|
- i++;
|
|
|
- }
|
|
|
- const dim = parseInt(numStr, 10);
|
|
|
- if (isNaN(dim)) {
|
|
|
- shape.push('?');
|
|
|
- } else {
|
|
|
- shape.push(dim);
|
|
|
- }
|
|
|
- } else {
|
|
|
- break;
|
|
|
- }
|
|
|
- if (i < spec.length && spec[i] === 'x') {
|
|
|
- i++;
|
|
|
- } else {
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- let dataType = spec.substring(i);
|
|
|
- const encodingIndex = dataType.indexOf(',');
|
|
|
- if (encodingIndex !== -1) {
|
|
|
- dataType = dataType.substring(0, encodingIndex).trim();
|
|
|
- }
|
|
|
- return new mlir.TensorType(dataType, new mlir.TensorShape(shape));
|
|
|
- }
|
|
|
- if (typeStr.startsWith('!torch.vtensor<') && typeStr.endsWith('>')) {
|
|
|
- const spec = typeStr.substring(15, typeStr.length - 1);
|
|
|
- let shape = null;
|
|
|
- let dataType = null;
|
|
|
- if (spec.startsWith('[')) {
|
|
|
- const bracketEnd = spec.indexOf(']');
|
|
|
- const shapeStr = spec.substring(0, bracketEnd + 1);
|
|
|
- const jsonStr = shapeStr.replace(/\?/g, '"?"');
|
|
|
- shape = JSON.parse(jsonStr);
|
|
|
- const rest = spec.substring(bracketEnd + 1);
|
|
|
- if (rest.startsWith(',')) {
|
|
|
- const parts = rest.substring(1).split(',');
|
|
|
- dataType = parts[0].trim();
|
|
|
- }
|
|
|
- } else if (spec.startsWith('*')) {
|
|
|
- if (spec.includes(',')) {
|
|
|
- const parts = spec.split(',');
|
|
|
- dataType = parts[1].trim();
|
|
|
- }
|
|
|
- } else {
|
|
|
- const parts = spec.split(',');
|
|
|
- dataType = parts[0].trim();
|
|
|
- }
|
|
|
- return new mlir.TensorType(dataType, shape ? new mlir.TensorShape(shape) : null);
|
|
|
- }
|
|
|
- if (typeStr.startsWith('tuple<') && typeStr.endsWith('>')) {
|
|
|
- return typeStr;
|
|
|
- }
|
|
|
- return typeStr;
|
|
|
- }
|
|
|
-};
|
|
|
-
|
|
|
// Dialect Plugin System
|
|
|
|
|
|
_.AssemblyFormatParser = class {
|
|
|
@@ -7549,15 +7563,19 @@ _.HLODialect = class extends _.Dialect {
|
|
|
resultTypes.push(type);
|
|
|
}
|
|
|
|
|
|
- // custom<SelectOpType>(type($operands), type($result))
|
|
|
- _parseSelectOpType(parser, op, operandTypes, resultTypes) {
|
|
|
+ // custom<SelectOpType>(type($pred), type($on_true), type($on_false), type($result))
|
|
|
+ _parseSelectOpType(parser, op, predTypes, onTrueTypes, onFalseTypes, resultTypes) {
|
|
|
const firstType = parser.parseType();
|
|
|
if (parser.accept(',')) {
|
|
|
const secondType = parser.parseType();
|
|
|
- operandTypes.push(firstType);
|
|
|
+ predTypes.push(firstType);
|
|
|
+ onTrueTypes.push(secondType);
|
|
|
+ onFalseTypes.push(secondType);
|
|
|
resultTypes.push(secondType);
|
|
|
} else {
|
|
|
- operandTypes.push(firstType);
|
|
|
+ predTypes.push(firstType);
|
|
|
+ onTrueTypes.push(firstType);
|
|
|
+ onFalseTypes.push(firstType);
|
|
|
resultTypes.push(firstType);
|
|
|
}
|
|
|
}
|
|
|
@@ -8039,7 +8057,17 @@ _.StableHLODialect = class extends _.HLODialect {
|
|
|
if (parser.match('{')) {
|
|
|
parser.parseAttributeDict(op.attributes);
|
|
|
}
|
|
|
- parser.resolveOperands(unresolvedOperands, parser.parseOptionalColonTypeList(), op.operands);
|
|
|
+ // Handle `: (operand-types) -> result-types` functional type format
|
|
|
+ if (parser.accept(':')) {
|
|
|
+ const type = parser.parseType();
|
|
|
+ if (type instanceof _.FunctionType) {
|
|
|
+ parser.resolveOperands(unresolvedOperands, type.inputs, op.operands);
|
|
|
+ op.addTypes(type.results);
|
|
|
+ } else {
|
|
|
+ const types = Array.isArray(type) ? type : [type];
|
|
|
+ parser.resolveOperands(unresolvedOperands, types, op.operands);
|
|
|
+ }
|
|
|
+ }
|
|
|
if (parser.accept('->') || parser.accept('id', 'to')) {
|
|
|
const types = parser.parseFunctionResultTypes();
|
|
|
op.addTypes(types);
|