Bladeren bron

Add MLIR support (#1044)

Lutz Roeder 2 maanden geleden
bovenliggende
commit
694e7fc0f3
5 gewijzigde bestanden met toevoegingen van 1090 en 180 verwijderingen
  1. 26 3
      source/mlir-metadata.json
  2. 519 172
      source/mlir.js
  3. 3 0
      tools/mlir
  4. 11 1
      tools/mlir-script.js
  5. 531 4
      tools/tablegen.js

File diff suppressed because it is too large
+ 26 - 3
source/mlir-metadata.json


File diff suppressed because it is too large
+ 519 - 172
source/mlir.js


+ 3 - 0
tools/mlir

@@ -8,6 +8,7 @@ src_dir=./third_party/source/mlir
 entries=(
     "https://github.com/llvm/llvm-project.git|main|${src_dir}/llvm-project|mlir"
     "https://github.com/openxla/stablehlo.git|main|${src_dir}/stablehlo|stablehlo"
+    "https://github.com/openxla/shardy.git|main|${src_dir}/shardy|shardy"
     "https://github.com/openxla/xla.git|main|${src_dir}/xla|xla/mlir_hlo"
     "https://github.com/onnx/onnx-mlir.git|main|${src_dir}/onnx-mlir|"
     "https://github.com/llvm/torch-mlir.git|main|${src_dir}/torch-mlir|"
@@ -22,6 +23,8 @@ entries=(
     "https://github.com/spcl/mlir-dace.git|main|${src_dir}/mlir-dace|"
     "https://github.com/woxjro/lltz.git|master|${src_dir}/lltz|"
     "https://github.com/pengmai/lagrad.git|main|${src_dir}/lagrad|include"
+    "https://github.com/llvm/clangir.git|main|${src_dir}/clangir|clang/include/clang/CIR"
+    "https://github.com/ROCm/rocMLIR.git|develop|${src_dir}/rocMLIR|mlir/include/mlir/Dialect"
 )
 
 clean() {

+ 11 - 1
tools/mlir-script.js

@@ -64,6 +64,7 @@ const schema = async () => {
         path.join(source, 'llvm-project', 'mlir', 'examples', 'transform', 'Ch2', 'include'),
         path.join(source, 'llvm-project', 'mlir', 'examples', 'transform', 'Ch3', 'include'),
         path.join(source, 'stablehlo'),
+        path.join(source, 'shardy'),
         path.join(source, 'xla', 'xla', 'mlir_hlo'),
         path.join(source, 'onnx-mlir'),
         path.join(source, 'torch-mlir', 'include'),
@@ -93,6 +94,10 @@ const schema = async () => {
         path.join(source, 'triton', 'third_party', 'nvidia', 'include', 'Dialect', 'NVGPU', 'IR'),
         path.join(source, 'triton', 'third_party', 'nvidia', 'include', 'Dialect', 'NVWS', 'IR'),
         path.join(source, '_', 'llvm-project', 'mlir', 'include'),
+        path.join(source, 'clangir'),
+        path.join(source, 'clangir', 'clang', 'include'),
+        path.join(source, 'rocMLIR'),
+        path.join(source, 'rocMLIR', 'mlir', 'include'),
     ];
     const dialects = [
         'mlir/include/mlir/IR/BuiltinAttributeInterfaces.td',
@@ -210,6 +215,8 @@ const schema = async () => {
         'stablehlo/dialect/VhloOps.td',
         'stablehlo/reference/InterpreterOps.td',
         'stablehlo/tests/CheckOps.td',
+        'shardy/dialect/sdy/ir/ops.td',
+        'shardy/dialect/mpmd/ir/ops.td',
         'mhlo/IR/hlo_ops.td',
         'src/Dialect/ONNX/ONNX.td',
         'src/Dialect/ONNX/ONNXOps.td.inc',
@@ -233,6 +240,7 @@ const schema = async () => {
         'tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td',
         'tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.td',
         'tfrt/core_runtime/opdefs/core_runtime.td',
+        'tfrt/core_runtime/opdefs/sync/core_runtime.td',
         'tfrt/basic_kernels/opdefs/basic_kernels.td',
         'tfrt/test_kernels/opdefs/test_kernels.td',
         'tfrt/tensor/opdefs/tensor.td',
@@ -293,6 +301,8 @@ const schema = async () => {
         'mlir-kernel/Kernel/IR/Ops.td',
         'Dialect/NVGPU/IR/NVGPUOps.td',
         'Standalone/StandaloneOps.td',
+        'clang/include/clang/CIR/Dialect/IR/CIROps.td',
+        'mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td',
     ];
     const file = path.join(dirname, '..', 'source', 'mlir-metadata.json');
     const operations = new Map();
@@ -707,7 +717,7 @@ const test = async (pattern) => {
             writeLine('');
             writeLine('-'.repeat(75));
             if (errorTotals.size > 0) {
-                const sortedErrors = Array.from(errorTotals.entries()).sort((a, b) => b[1] - a[1]).slice(0, 25);
+                const sortedErrors = Array.from(errorTotals.entries()).sort((a, b) => b[1] - a[1]).slice(0, 100);
                 for (const [err, cnt] of sortedErrors) {
                     const fileCounts = filesByError.get(err);
                     const topFiles = Array.from(fileCounts.entries()).sort((a, b) => b[1] - a[1]).slice(0, 100);

+ 531 - 4
tools/tablegen.js

@@ -1678,20 +1678,547 @@ tablegen.Reader = class {
     }
 
     _parseForeach() {
+        const location = this._tokenizer.location();
         this._read();
-        let depth = 0;
-        while (!this._match('eof')) {
-            if (this._eat('{')) {
+        const iterVarName = this._expect('id');
+        this._expect('=');
+        const listValue = this._parseForeachListValue();
+        this._expect('keyword', 'in');
+        const loop = { location, iterVarName, listValue, entries: [], hasDefvar: false };
+        if (this._match('{')) {
+            this._read();
+            this._parseForeachBody(loop);
+            this._expect('}');
+        } else {
+            this._parseForeachBodyStatement(loop);
+        }
+        this._resolveForeachLoop(loop, new Map());
+    }
+
+    _parseForeachListValue() {
+        const values = [];
+        if (this._eat('[')) {
+            while (!this._match(']') && !this._match('eof')) {
+                const value = this._parseListItem();
+                if (value && value.type === 'dag') {
+                    const instantiated = this._instantiateClassTemplate(value.value);
+                    if (instantiated) {
+                        values.push(instantiated);
+                    } else {
+                        values.push(value);
+                    }
+                } else {
+                    values.push(value);
+                }
+                this._eat(',');
+            }
+            this._expect(']');
+        } else if (this._eat('!')) {
+            const op = this._expect('id');
+            if (op === 'range' && this._eat('(')) {
+                const args = [];
+                while (!this._match(')') && !this._match('eof')) {
+                    args.push(this._parseValue());
+                    this._eat(',');
+                }
+                this._expect(')');
+                if (args.length >= 1) {
+                    let start = 0;
+                    let end = 0;
+                    if (args.length === 1) {
+                        end = this._evaluateSimpleValue(args[0]);
+                    } else {
+                        start = this._evaluateSimpleValue(args[0]);
+                        end = this._evaluateSimpleValue(args[1]);
+                    }
+                    for (let i = start; i < end; i++) {
+                        values.push(new tablegen.Value('int', i));
+                    }
+                }
+            } else {
+                while (!this._match('keyword', 'in') && !this._match('eof')) {
+                    this._read();
+                }
+            }
+        } else if (this._eat('{')) {
+            const start = this._expect('number');
+            if (this._eat('-') || this._eat('...')) {
+                const end = this._expect('number');
+                for (let i = start; i <= end; i++) {
+                    values.push(new tablegen.Value('int', i));
+                }
+            }
+            this._expect('}');
+        } else {
+            while (!this._match('keyword', 'in') && !this._match('eof')) {
+                this._read();
+            }
+        }
+        return values;
+    }
+
+    _instantiateClassTemplate(dag) {
+        if (!dag || !dag.operator) {
+            return null;
+        }
+        const className = typeof dag.operator === 'string' ? dag.operator : dag.operator.value;
+        const classRecord = this.classes.get(className);
+        if (!classRecord) {
+            return null;
+        }
+        const fields = new Map();
+        const bindings = new Map();
+        if (classRecord.templateArgs && dag.operands) {
+            for (let i = 0; i < classRecord.templateArgs.length && i < dag.operands.length; i++) {
+                const paramName = classRecord.templateArgs[i].name;
+                const argValue = dag.operands[i].value;
+                bindings.set(paramName, argValue);
+            }
+        }
+        for (const [fieldName, field] of classRecord.fields) {
+            let resolvedValue = field.value;
+            if (resolvedValue && resolvedValue.type === 'def' && bindings.has(resolvedValue.value)) {
+                resolvedValue = bindings.get(resolvedValue.value);
+            } else if (resolvedValue && resolvedValue.type === 'bang') {
+                resolvedValue = this._evaluateBangOp(resolvedValue, bindings);
+            }
+            fields.set(fieldName, resolvedValue);
+        }
+        return new tablegen.Value('record_instance', { className, fields });
+    }
+
+    _evaluateBangOp(value, bindings) {
+        if (!value || value.type !== 'bang') {
+            return value;
+        }
+        const { op, args } = value.value;
+        if (op === 'tolower' && args && args.length === 1) {
+            let [arg] = args;
+            if (arg.type === 'def' && bindings.has(arg.value)) {
+                arg = bindings.get(arg.value);
+            }
+            if (arg.type === 'string') {
+                const str = String(arg.value).replace(/^"|"$/g, '');
+                return new tablegen.Value('string', str.toLowerCase());
+            }
+        }
+        return value;
+    }
+
+    _evaluateSimpleValue(value) {
+        if (!value) {
+            return 0;
+        }
+        if (value.type === 'int') {
+            return typeof value.value === 'number' ? value.value : parseInt(value.value, 10);
+        }
+        if (typeof value === 'number') {
+            return value;
+        }
+        return 0;
+    }
+
+    _parseForeachBody(loop) {
+        while (!this._match('}') && !this._match('eof')) {
+            this._parseForeachBodyStatement(loop);
+            // If we found defvar, skip the rest of the body since we can't properly expand this loop
+            if (loop.hasDefvar) {
+                this._skipUntilClosingBrace();
+                return;
+            }
+        }
+    }
+
+    // Skip tokens until we reach the matching closing brace (but don't consume it)
+    _skipUntilClosingBrace() {
+        let depth = 1;
+        while (depth > 0 && !this._match('eof')) {
+            if (this._match('{')) {
                 depth++;
-            } else if (this._eat('}')) {
+                this._read();
+            } else if (this._match('}')) {
                 depth--;
                 if (depth === 0) {
+                    // Don't consume the final } - let the caller handle it
+                    return;
+                }
+                this._read();
+            } else {
+                this._read();
+            }
+        }
+    }
+
+    _parseForeachBodyStatement(loop) {
+        const token = this._tokenizer.current();
+        if (token.type === 'keyword') {
+            switch (token.value) {
+                case 'def':
+                    loop.entries.push({ type: 'def', data: this._parseDefTemplate() });
+                    break;
+                case 'defm':
+                    this._parseDefm();
+                    break;
+                case 'let':
+                    this._parseLet();
+                    break;
+                case 'defvar':
+                    loop.hasDefvar = true;
+                    this._parseDefvar();
+                    break;
+                case 'foreach':
+                    loop.entries.push({ type: 'foreach', data: this._parseForeachTemplate() });
+                    break;
+                case 'if':
+                    loop.entries.push({ type: 'foreach', data: this._parseIfAsLoop() });
+                    break;
+                default:
+                    this._read();
                     break;
+            }
+        } else {
+            this._read();
+        }
+    }
+
+    _parseIfAsLoop() {
+        const location = this._tokenizer.location();
+        this._read();
+        const condition = this._parseValue();
+        this._expect('keyword', 'then');
+        const loop = {
+            location,
+            iterVarName: null,
+            listValue: [],
+            entries: [],
+            condition,
+            hasDefvar: false,
+            isConditional: true
+        };
+        if (this._match('{')) {
+            this._read();
+            this._parseForeachBody(loop);
+            this._expect('}');
+        } else {
+            this._parseForeachBodyStatement(loop);
+        }
+        if (this._match('keyword', 'else')) {
+            this._read();
+            if (this._match('{')) {
+                this._read();
+                let depth = 1;
+                while (depth > 0 && !this._match('eof')) {
+                    if (this._eat('{')) {
+                        depth++;
+                    } else if (this._eat('}')) {
+                        depth--;
+                    } else {
+                        this._read();
+                    }
+                }
+            }
+        }
+        return loop;
+    }
+
+    _parseDefTemplate() {
+        this._read();
+        const nameTemplate = [];
+        while (!this._match(':') && !this._match('{') && !this._match(';') && !this._match('eof')) {
+            if (this._match('id')) {
+                const value = this._read();
+                if (this._eat('.')) {
+                    const field = this._expect('id');
+                    nameTemplate.push({ type: 'field_access', base: value, field });
+                } else {
+                    nameTemplate.push({ type: 'id', value });
+                }
+            } else if (this._match('string')) {
+                nameTemplate.push({ type: 'string', value: this._read() });
+            } else if (this._match('number')) {
+                nameTemplate.push({ type: 'number', value: this._read() });
+            } else if (this._eat('#')) {
+                nameTemplate.push({ type: 'concat' });
+            } else {
+                break;
+            }
+        }
+        let parents = [];
+        if (this._match(':')) {
+            parents = this._parseParentClassList();
+        }
+        let bodyFields = new Map();
+        if (this._match('{')) {
+            this._read(); // consume '{'
+            bodyFields = this._parseRecordBodyFields();
+            this._expect('}');
+        }
+        this._eat(';');
+        return { nameTemplate, parents, bodyFields };
+    }
+
+    _parseForeachTemplate() {
+        const location = this._tokenizer.location();
+        this._read();
+        const iterVarName = this._expect('id');
+        this._expect('=');
+        const listValue = this._parseForeachListValue();
+        this._expect('keyword', 'in');
+        const loop = { location, iterVarName, listValue, entries: [], hasDefvar: false };
+        if (this._match('{')) {
+            this._read();
+            this._parseForeachBody(loop);
+            this._expect('}');
+        } else {
+            this._parseForeachBodyStatement(loop);
+        }
+        return loop;
+    }
+
+    _parseRecordBodyFields() {
+        const fields = new Map();
+        while (!this._match('}') && !this._match('eof')) {
+            if (this._match('keyword', 'let')) {
+                this._read();
+                const name = this._expect('id');
+                this._expect('=');
+                const value = this._parseValue();
+                this._eat(';');
+                fields.set(name, { name, type: null, value });
+            } else if (this._match('keyword', 'field')) {
+                this._read();
+                const type = this._parseType();
+                const name = this._expect('id');
+                let value = null;
+                if (this._eat('=')) {
+                    value = this._parseValue();
+                }
+                this._eat(';');
+                fields.set(name, { name, type, value });
+            } else if (this._match('keyword', 'assert') || this._match('keyword', 'dump')) {
+                // Skip assert and dump statements
+                this._skipUntil([';']);
+                this._eat(';');
+            } else if (this._match('id') || this._match('keyword')) {
+                // Type followed by field name
+                const type = this._parseType();
+                const name = this._expect('id');
+                let value = null;
+                if (this._eat('=')) {
+                    value = this._parseValue();
                 }
+                this._eat(';');
+                fields.set(name, { name, type, value });
             } else {
                 this._read();
             }
         }
+        return fields;
+    }
+
+    _resolveForeachLoop(loop, substitutions) {
+        if (loop.hasDefvar) {
+            return;
+        }
+        if (loop.entries.length === 0) {
+            return;
+        }
+
+        if (loop.isConditional) {
+            const conditionResult = this._evaluateCondition(loop.condition, substitutions);
+            if (conditionResult === false || conditionResult === null) {
+                return;
+            }
+            for (const entry of loop.entries) {
+                if (entry.type === 'def') {
+                    this._instantiateDef(entry.data, substitutions);
+                } else if (entry.type === 'foreach') {
+                    this._resolveForeachLoop(entry.data, substitutions);
+                }
+            }
+            return;
+        }
+        if (loop.listValue.length === 0) {
+            return;
+        }
+        for (const listItem of loop.listValue) {
+            const currentSubs = new Map(substitutions);
+            if (loop.iterVarName) {
+                currentSubs.set(loop.iterVarName, listItem);
+            }
+            for (const entry of loop.entries) {
+                if (entry.type === 'def') {
+                    this._instantiateDef(entry.data, currentSubs);
+                } else if (entry.type === 'foreach') {
+                    this._resolveForeachLoop(entry.data, currentSubs);
+                }
+            }
+        }
+    }
+
+    _evaluateCondition(condition, substitutions) {
+        if (!condition) {
+            return null;
+        }
+        if (condition.type === 'bang' && condition.value) {
+            const { op, args } = condition.value;
+            if (op === 'ne' && args.length === 2) {
+                const a = this._evaluateSimpleExpr(args[0], substitutions);
+                const b = this._evaluateSimpleExpr(args[1], substitutions);
+                if (a !== null && b !== null) {
+                    return a !== b;
+                }
+            }
+            if (op === 'eq' && args.length === 2) {
+                const a = this._evaluateSimpleExpr(args[0], substitutions);
+                const b = this._evaluateSimpleExpr(args[1], substitutions);
+                if (a !== null && b !== null) {
+                    return a === b;
+                }
+            }
+        }
+        // Can't evaluate complex conditions
+        return null;
+    }
+
+    // Evaluate a simple expression for condition evaluation
+    _evaluateSimpleExpr(expr, substitutions) {
+        if (!expr) {
+            return null;
+        }
+        if (expr.type === 'string') {
+            return String(expr.value).replace(/^"|"$/g, '');
+        }
+        if (expr.type === 'int') {
+            return typeof expr.value === 'number' ? expr.value : parseInt(expr.value, 10);
+        }
+        if ((expr.type === 'def' || expr.type === 'id') && substitutions.has(expr.value)) {
+            return this._evaluateSimpleExpr(substitutions.get(expr.value), substitutions);
+        }
+        return null;
+    }
+
+    _instantiateDef(template, substitutions) {
+        let name = '';
+        for (const part of template.nameTemplate) {
+            if (part.type === 'concat') {
+                continue;
+            } else if (part.type === 'field_access') {
+                if (substitutions.has(part.base)) {
+                    const subValue = substitutions.get(part.base);
+                    name += this._getFieldValue(subValue, part.field);
+                } else {
+                    name += `${part.base}.${part.field}`;
+                }
+            } else if (part.type === 'id') {
+                if (substitutions.has(part.value)) {
+                    const subValue = substitutions.get(part.value);
+                    name += this._valueToString(subValue);
+                } else {
+                    name += part.value;
+                }
+            } else if (part.type === 'string') {
+                name += part.value;
+            } else if (part.type === 'number') {
+                name += String(part.value);
+            }
+        }
+        const def = new tablegen.Record(name, this);
+        def.location = this._tokenizer.location();
+        def.parents = template.parents.map((parent) => ({
+            name: parent.name,
+            args: parent.args ? parent.args.map((arg) => this._substituteValue(arg, substitutions)) : []
+        }));
+        for (const [fieldName, field] of template.bodyFields) {
+            const resolvedValue = field.value ? this._substituteValue(field.value, substitutions) : null;
+            def.fields.set(fieldName, new tablegen.RecordVal(fieldName, field.type, resolvedValue));
+        }
+        this.addSubClass(def);
+        def.resolveReferences();
+        if (name) {
+            this._defs.set(name, def);
+            this.defs.push(def);
+        }
+    }
+
+    _valueToString(value) {
+        if (!value) {
+            return '';
+        }
+        if (typeof value === 'string') {
+            return value;
+        }
+        if (typeof value === 'number') {
+            return String(value);
+        }
+        if (value.type === 'string') {
+            return String(value.value).replace(/^"|"$/g, '');
+        }
+        if (value.type === 'int') {
+            return String(value.value);
+        }
+        if (value.type === 'id' || value.type === 'def') {
+            return String(value.value);
+        }
+        return '';
+    }
+
+    _getFieldValue(value, fieldName) {
+        if (!value) {
+            return '';
+        }
+        if (value.type === 'record_instance' && value.value && value.value.fields) {
+            const fieldValue = value.value.fields.get(fieldName);
+            return this._valueToString(fieldValue);
+        }
+        return '';
+    }
+
+    _substituteValue(value, substitutions) {
+        if (!value) {
+            return value;
+        }
+        if (value.type === 'def' || value.type === 'id') {
+            const varName = value.value;
+            if (substitutions.has(varName)) {
+                return substitutions.get(varName);
+            }
+            if (typeof varName === 'string' && varName.includes('.')) {
+                const [base, field] = varName.split('.', 2);
+                if (substitutions.has(base)) {
+                    const baseValue = substitutions.get(base);
+                    const fieldValue = this._getFieldValue(baseValue, field);
+                    if (fieldValue) {
+                        return new tablegen.Value('string', fieldValue);
+                    }
+                }
+            }
+        }
+        if (value.type === 'list' && Array.isArray(value.value)) {
+            return {
+                type: 'list',
+                value: value.value.map((v) => this._substituteValue(v, substitutions))
+            };
+        }
+        if (value.type === 'dag' && value.value) {
+            return {
+                type: 'dag',
+                value: {
+                    operator: value.value.operator,
+                    operands: value.value.operands.map((op) => ({
+                        value: this._substituteValue(op.value, substitutions),
+                        name: op.name
+                    }))
+                }
+            };
+        }
+        if (value.type === 'concat' && Array.isArray(value.value)) {
+            return {
+                type: 'concat',
+                value: value.value.map((v) => this._substituteValue(v, substitutions))
+            };
+        }
+        return value;
     }
 
     _parseTemplateParams() {

Some files were not shown because too many files changed in this diff