|
|
@@ -184,12 +184,6 @@ mlir.Graph = class {
|
|
|
this.nodes = [];
|
|
|
this.metadata = [];
|
|
|
const tensors = new Map();
|
|
|
- const tensor = (arg) => {
|
|
|
- if (!tensors.has(arg.name)) {
|
|
|
- tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, arg.value));
|
|
|
- }
|
|
|
- return tensors.get(arg.name);
|
|
|
- };
|
|
|
// Handle function inputs/outputs if function_type exists
|
|
|
if (func.attributes.has('function_type')) {
|
|
|
const function_type = func.attributes.get('function_type');
|
|
|
@@ -245,7 +239,7 @@ mlir.Graph = class {
|
|
|
const operands = op.operands || [];
|
|
|
for (let i = 0; i < operands.length; i++) {
|
|
|
const input = op.operands[i];
|
|
|
- const inputName = input.name || (opMetadata && opMetadata.operands && opMetadata.operands[i] ? opMetadata.operands[i].name : null) || i.toString();
|
|
|
+ const inputName = (opMetadata && opMetadata.operands && opMetadata.operands[i] ? opMetadata.operands[i].name : null) || input.name || i.toString();
|
|
|
if (input.value instanceof Uint8Array) {
|
|
|
operation.operands.push({
|
|
|
name: inputName,
|
|
|
@@ -270,7 +264,7 @@ mlir.Graph = class {
|
|
|
value: input.value
|
|
|
});
|
|
|
} else if (typeof input.value === 'string' && input.value) {
|
|
|
- const value = values.map(input);
|
|
|
+ const value = values.map(input.value);
|
|
|
value.to.push(operation);
|
|
|
const args = [{ name: input.value, type: input.type }];
|
|
|
operation.operands.push({
|
|
|
@@ -294,7 +288,7 @@ mlir.Graph = class {
|
|
|
const value = values.map(output.value);
|
|
|
value.type = mlir.Utility.valueType(output.type);
|
|
|
value.from.push(operation);
|
|
|
- const outputName = output.name || (opMetadata && opMetadata.results && opMetadata.results[i] ? opMetadata.results[i].name : null) || i.toString();
|
|
|
+ const outputName = (opMetadata && opMetadata.results && opMetadata.results[i] ? opMetadata.results[i].name : null) || output.name || i.toString();
|
|
|
operation.results.push({
|
|
|
name: outputName,
|
|
|
value: [value]
|
|
|
@@ -304,6 +298,39 @@ mlir.Graph = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ // Build map of single-use constant tensors to convert to initializers
|
|
|
+ const constantMap = new Map();
|
|
|
+ const constantTypes = new Set([
|
|
|
+ 'tosa.const', 'stablehlo.constant', 'arith.constant',
|
|
|
+ 'mhlo.constant', 'torch.constant.tensor'
|
|
|
+ ]);
|
|
|
+ for (const op of operations) {
|
|
|
+ if (constantTypes.has(op.type) &&
|
|
|
+ op.operands.length === 0 &&
|
|
|
+ op.results.length === 1 &&
|
|
|
+ op.results[0].value.length === 1) {
|
|
|
+ const [result] = op.results[0].value;
|
|
|
+ if (result.to && result.to.length === 1) {
|
|
|
+ const valueAttr = op.attributes.get('value');
|
|
|
+ if ((valueAttr instanceof mlir.DenseElementsAttr || valueAttr instanceof mlir.DenseResourceElementsAttr) &&
|
|
|
+ valueAttr.value !== null &&
|
|
|
+ valueAttr.type && valueAttr.type.toString().startsWith('tensor<')) {
|
|
|
+ const type = mlir.Utility.valueType(valueAttr.type);
|
|
|
+ if (type instanceof mlir.TensorType) {
|
|
|
+ constantMap.set(result.name, new mlir.Tensor(type, valueAttr.value));
|
|
|
+ op.delete = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const tensor = (arg) => {
|
|
|
+ if (!tensors.has(arg.name)) {
|
|
|
+ const initializer = constantMap.get(arg.name) || null;
|
|
|
+ tensors.set(arg.name, new mlir.Value(arg.name, arg.type, null, initializer));
|
|
|
+ }
|
|
|
+ return tensors.get(arg.name);
|
|
|
+ };
|
|
|
for (const input of this.inputs) {
|
|
|
for (const arg of input.value) {
|
|
|
if (!tensors.has(arg.name)) {
|
|
|
@@ -366,7 +393,7 @@ mlir.Argument = class {
|
|
|
case 'i32': case 'si32': this.type = 'int32'; break;
|
|
|
case 'i16': case 'si16': this.type = 'int16'; break;
|
|
|
case 'i8': case 'si8': this.type = 'int8'; break;
|
|
|
- case 'i1': this.type = 'boolean'; break;
|
|
|
+ case 'i1': this.type = 'int1'; break;
|
|
|
case 'f32': case 'float32': this.type = 'float32'; break;
|
|
|
case 'f64': case 'float64': this.type = 'float64'; break;
|
|
|
case 'f16': this.type = 'float16'; break;
|
|
|
@@ -463,9 +490,12 @@ mlir.Node = class {
|
|
|
value = graph;
|
|
|
type = 'function';
|
|
|
}
|
|
|
- } else if (attr instanceof mlir.DenseElementsAttr) {
|
|
|
+ } else if (attr instanceof mlir.DenseElementsAttr && attr.value !== null) {
|
|
|
value = new mlir.Tensor(mlir.Utility.valueType(attr.type), attr.value);
|
|
|
type = 'tensor';
|
|
|
+ } else if (attr instanceof mlir.DenseResourceElementsAttr) {
|
|
|
+ value = new mlir.Tensor(mlir.Utility.valueType(attr.type), null);
|
|
|
+ type = 'tensor';
|
|
|
} else if (attr instanceof mlir.ArrayAttr) {
|
|
|
value = attr.value;
|
|
|
} else if (attr instanceof mlir.DenseArrayAttr) {
|
|
|
@@ -1441,6 +1471,17 @@ mlir.Parser = class {
|
|
|
parseGenericOperation() {
|
|
|
const name = this.expect('string');
|
|
|
const op = new mlir.OperationState(name);
|
|
|
+ const index = name.indexOf('.');
|
|
|
+ if (index !== -1) {
|
|
|
+ const dialectName = name.substring(0, index);
|
|
|
+ const dialect = this._context.getDialect(dialectName);
|
|
|
+ if (dialect) {
|
|
|
+ const opInfo = dialect.getOperation(name);
|
|
|
+ if (opInfo) {
|
|
|
+ op.metadata = opInfo.metadata;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
return this.parseGenericOperationAfterOpName(op);
|
|
|
}
|
|
|
|
|
|
@@ -1957,6 +1998,10 @@ mlir.Parser = class {
|
|
|
this._token.text = prefix;
|
|
|
return this.parseType();
|
|
|
}
|
|
|
+ // Return as PrimitiveType for known primitive types
|
|
|
+ if (/^[su]?i[0-9]+$/.test(prefix) || /^[fb]f?[0-9]+/.test(prefix) || prefix === 'index') {
|
|
|
+ return new mlir.PrimitiveType(prefix);
|
|
|
+ }
|
|
|
return prefix;
|
|
|
}
|
|
|
|
|
|
@@ -2760,7 +2805,7 @@ mlir.Parser = class {
|
|
|
this.expect(':');
|
|
|
type = this.parseType();
|
|
|
}
|
|
|
- return { value: handle, type };
|
|
|
+ return new mlir.DenseResourceElementsAttr(handle, type);
|
|
|
}
|
|
|
|
|
|
parseDenseArrayAttr(/* attrType */) {
|
|
|
@@ -3049,31 +3094,60 @@ mlir.TensorLiteralParser = class {
|
|
|
const elementType = type.getElementType ? type.getElementType() : null;
|
|
|
const numElements = type.getNumElements ? type.getNumElements() : 0;
|
|
|
const isComplex = elementType instanceof mlir.ComplexType;
|
|
|
+ const baseElemType = isComplex && elementType.elementType ? elementType.elementType : elementType;
|
|
|
+ // Determine conversion function once based on element type
|
|
|
+ let convert = (v) => v;
|
|
|
+ if (baseElemType) {
|
|
|
+ const typeStr = baseElemType.toString();
|
|
|
+ const intMatch = typeStr.match(/^[su]?i(\d+)$/);
|
|
|
+ if (intMatch) {
|
|
|
+ const bitWidth = parseInt(intMatch[1], 10);
|
|
|
+ if (bitWidth >= 64) {
|
|
|
+ convert = (v) => typeof v === 'bigint' ? v : BigInt(v);
|
|
|
+ }
|
|
|
+ // For smaller ints, values are already numbers from tokenizer
|
|
|
+ } else if (typeStr === 'index') {
|
|
|
+ convert = (v) => typeof v === 'bigint' ? v : BigInt(v);
|
|
|
+ }
|
|
|
+ // For floats and other types, values are already correct from tokenizer
|
|
|
+ }
|
|
|
+ // Handle zero-element tensors (e.g., tensor<2x0x3xi4>)
|
|
|
+ if (numElements === 0) {
|
|
|
+ return [];
|
|
|
+ }
|
|
|
// Limit splat expansion to avoid memory issues with huge tensors
|
|
|
const maxSplatExpansion = 10000000;
|
|
|
// Handle splats - Reference: if shape.empty() and storage has elements, it's a splat
|
|
|
const isSplat = this._shape.length === 0 && this._storage.length > 0;
|
|
|
- if (isSplat && numElements > 1 && numElements <= maxSplatExpansion) {
|
|
|
+ if (isSplat && numElements > 1) {
|
|
|
+ if (numElements > maxSplatExpansion) {
|
|
|
+ // Too large to expand - return null to indicate we can't provide the data
|
|
|
+ return null;
|
|
|
+ }
|
|
|
if (isComplex && this._storage.length === 2) {
|
|
|
// Complex splat: storage has 2 elements (real, imag)
|
|
|
const result = [];
|
|
|
+ const real = convert(this._storage[0]);
|
|
|
+ const imag = convert(this._storage[1]);
|
|
|
for (let i = 0; i < numElements; i++) {
|
|
|
- result.push(new base.Complex(this._storage[0], this._storage[1]));
|
|
|
+ result.push(new base.Complex(real, imag));
|
|
|
}
|
|
|
return result;
|
|
|
}
|
|
|
// Regular splat: replicate single value
|
|
|
- return new Array(numElements).fill(this._storage[0]);
|
|
|
+ const converted = convert(this._storage[0]);
|
|
|
+ return new Array(numElements).fill(converted);
|
|
|
}
|
|
|
// Non-splat complex: convert pairs to base.Complex objects
|
|
|
if (isComplex && Array.isArray(this._storage)) {
|
|
|
const result = [];
|
|
|
for (let i = 0; i < this._storage.length; i += 2) {
|
|
|
- result.push(new base.Complex(this._storage[i], this._storage[i + 1]));
|
|
|
+ result.push(new base.Complex(convert(this._storage[i]), convert(this._storage[i + 1])));
|
|
|
}
|
|
|
return result;
|
|
|
}
|
|
|
- return this._storage;
|
|
|
+ // Convert all values
|
|
|
+ return this._storage.map(convert);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
@@ -3301,7 +3375,7 @@ mlir.AttrTypeReader = class {
|
|
|
const typeIdx = reader.varint().toNumber();
|
|
|
const type = this.readType(typeIdx);
|
|
|
const handleIdx = reader.varint().toNumber();
|
|
|
- return { name: 'dense_resource', value: `resource<${handleIdx}>`, type };
|
|
|
+ return new mlir.DenseResourceElementsAttr(`resource<${handleIdx}>`, type);
|
|
|
}
|
|
|
default: {
|
|
|
return { name: 'builtin', value: `<builtin code ${typeCode}>` };
|
|
|
@@ -3440,17 +3514,15 @@ mlir.AttrTypeReader = class {
|
|
|
const shape = this._readShape(reader);
|
|
|
const elementTypeIdx = reader.varint().toNumber();
|
|
|
const elementType = this.readType(elementTypeIdx);
|
|
|
- const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
|
|
|
- return new mlir.Type(`tensor<${shapeStr}x${elementType.name}>`);
|
|
|
+ return new mlir.RankedTensorType(shape, elementType, null);
|
|
|
}
|
|
|
case 14: { // RankedTensorTypeWithEncoding
|
|
|
const encodingAttrIdx = reader.varint().toNumber();
|
|
|
- this.readAttribute(encodingAttrIdx); // encoding
|
|
|
+ const encoding = this.readAttribute(encodingAttrIdx);
|
|
|
const shape = this._readShape(reader);
|
|
|
const elementTypeIdx = reader.varint().toNumber();
|
|
|
const elementType = this.readType(elementTypeIdx);
|
|
|
- const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
|
|
|
- return new mlir.Type(`tensor<${shapeStr}x${elementType.name}>`);
|
|
|
+ return new mlir.RankedTensorType(shape, elementType, encoding);
|
|
|
}
|
|
|
case 15: { // TupleType
|
|
|
const numTypes = reader.varint().toNumber();
|
|
|
@@ -3475,8 +3547,7 @@ mlir.AttrTypeReader = class {
|
|
|
const shape = this._readShape(reader);
|
|
|
const elementTypeIdx = reader.varint().toNumber();
|
|
|
const elementType = this.readType(elementTypeIdx);
|
|
|
- const shapeStr = shape.map((d) => d < 0 ? '?' : d.toString()).join('x');
|
|
|
- return new mlir.Type(`vector<${shapeStr}x${elementType.name}>`);
|
|
|
+ return new mlir.VectorType(shape, elementType);
|
|
|
}
|
|
|
case 19: { // VectorTypeWithScalableDims - simplified
|
|
|
return new mlir.Type('vector<?>');
|
|
|
@@ -3900,6 +3971,14 @@ mlir.BytecodeReader = class {
|
|
|
const kHasProperties = 0x40;
|
|
|
|
|
|
const op = new mlir.OperationState(fullName);
|
|
|
+ const [dialectName] = fullName.split('.');
|
|
|
+ const dialect = this._context.getDialect(dialectName);
|
|
|
+ if (dialect) {
|
|
|
+ const opInfo = dialect.getOperation(fullName);
|
|
|
+ if (opInfo) {
|
|
|
+ op.metadata = opInfo.metadata;
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
// Parse location
|
|
|
const locIdx = reader.varint().toNumber();
|
|
|
@@ -4371,6 +4450,19 @@ mlir.DenseElementsAttr = class extends mlir.Attribute {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+mlir.DenseResourceElementsAttr = class extends mlir.Attribute {
|
|
|
+
|
|
|
+ constructor(handle, type) {
|
|
|
+ super();
|
|
|
+ this.handle = handle;
|
|
|
+ this.type = type;
|
|
|
+ }
|
|
|
+
|
|
|
+ toString() {
|
|
|
+ return `dense_resource<${this.handle}>`;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
mlir.ArrayAttr = class extends mlir.Attribute {
|
|
|
|
|
|
constructor(elements) {
|
|
|
@@ -4573,7 +4665,7 @@ mlir.Utility = class {
|
|
|
value = value.toString();
|
|
|
}
|
|
|
switch (value) {
|
|
|
- case 'index': return 'index';
|
|
|
+ case 'index': return 'int64';
|
|
|
case 'f16': return 'float16';
|
|
|
case 'f32': return 'float32';
|
|
|
case 'f64': return 'float64';
|
|
|
@@ -4599,7 +4691,7 @@ mlir.Utility = class {
|
|
|
case 'f8E8M0FNU': return 'float8e8m0fnu';
|
|
|
case 'float8': return 'float8';
|
|
|
case 'tf32': return 'tensorfloat32';
|
|
|
- case 'i1': return 'boolean';
|
|
|
+ case 'i1': return 'int1';
|
|
|
case 'i2': return 'int2';
|
|
|
case 'i4': return 'int4';
|
|
|
case 'i8': return 'int8';
|
|
|
@@ -4620,6 +4712,7 @@ mlir.Utility = class {
|
|
|
case 'ui64': return 'uint64';
|
|
|
case 'b8': return 'int8';
|
|
|
case 'unk': return 'unk'; // torch dialect unknown dtype
|
|
|
+ case '!tf_type.string': return 'string';
|
|
|
default:
|
|
|
if (value && value.startsWith('!')) {
|
|
|
return value;
|
|
|
@@ -4648,7 +4741,7 @@ mlir.Utility = class {
|
|
|
} else if (signed === 's') {
|
|
|
return `int${width}`;
|
|
|
}
|
|
|
- return width === 1 ? 'boolean' : `int${width}`;
|
|
|
+ return `int${width}`;
|
|
|
}
|
|
|
}
|
|
|
throw new mlir.Error(`Unknown data type '${value}'.`);
|