Bläddra i källkod

Add MLIR support (#1044)

Lutz Roeder 2 månader sedan
förälder
incheckning
e099daaf9b
7 ändrade filer med 1533 tillägg och 224 borttagningar
  1. 855 0
      source/mlir-metadata.json
  2. 556 209
      source/mlir.js
  3. 1 0
      source/openvino-metadata.json
  4. 14 0
      test/models.json
  5. 4 0
      tools/mlir
  6. 19 3
      tools/mlir-script.js
  7. 84 12
      tools/tablegen.js

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 855 - 0
source/mlir-metadata.json


Filskillnaden har hållts tillbaka eftersom den är för stor
+ 556 - 209
source/mlir.js


+ 1 - 0
source/openvino-metadata.json

@@ -92,6 +92,7 @@
   },
   {
     "name": "Clamp",
+    "category": "Activation",
     "description": "*Clamp* layer represents clipping activation operation.\n**Detailed description**: [Reference](https://www.tensorflow.org/versions/r1.2/api_docs/MO_DG/prepare_model/python/tf/clip_by_value)\n**Parameters**: *Clamp* layer parameters should be specified as the `data` node, which is a child of the layer node.\n**Mathematical Formulation**\n*Clamp* generally does the following with the input blobs:\n\\f[\nout_i=\\left\\{\\begin{array}{ll}\n\tmax\\_value \\quad \\mbox{if } \\quad input_i>max\\_value \\\\\n\tmin\\_value \\quad \\mbox{if } \\quad input_i\n\\end{array}\\right.\n\\f]\n**Example**\n\n```html\n<layer ... type=\"Clamp\" ... >\n    <data min=\"10\" max=\"50\" />\n    <input> ... </input>\n    <output> ... </output>\n</layer>\n```",
     "attributes": [
       {

+ 14 - 0
test/models.json

@@ -3482,6 +3482,20 @@
     "format":   "MLIR",
     "link":     "https://github.com/lutzroeder/netron/issues/1044"
   },
+  {
+    "type":     "mlir",
+    "target":   "squeezenet.mlir",
+    "source":   "https://github.com/user-attachments/files/24284768/squeezenet.mlir.zip[squeezenet.mlir]",
+    "format":   "MLIR",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
+  {
+    "type":     "mlir",
+    "target":   "squeezenet.mlirbc",
+    "source":   "https://github.com/user-attachments/files/24284768/squeezenet.mlir.zip[squeezenet.mlirbc]",
+    "format":   "MLIR Bytecode v1",
+    "link":     "https://github.com/lutzroeder/netron/issues/1044"
+  },
   {
     "type":     "mlir",
     "target":   "stablehlo_ea.mlir",

+ 4 - 0
tools/mlir

@@ -71,6 +71,10 @@ sync() {
     echo '#ifndef LINALG_NAMED_STRUCTURED_OPS_YAMLGEN_TD
 #define LINALG_NAMED_STRUCTURED_OPS_YAMLGEN_TD
 #endif // LINALG_NAMED_STRUCTURED_OPS_YAMLGEN_TD' > "${src_dir}/_/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td"
+    mkdir -p "${src_dir}/_/mlir-hlo/thlo/IR"
+    curl --silent --show-error --location --output "${src_dir}/_/mlir-hlo/thlo/IR/thlo_ops.td" "https://raw.githubusercontent.com/rengolin/mlir-hlo/08ce879a5a04ea95dd02515baadc1796901546c5/thlo/IR/thlo_ops.td"
+    mkdir -p "${src_dir}/_/mlir-hlo/gml_st/interfaces"
+    curl --silent --show-error --location --output "${src_dir}/_/mlir-hlo/gml_st/interfaces/tiling_interface.td" "https://raw.githubusercontent.com/rengolin/mlir-hlo/08ce879a5a04ea95dd02515baadc1796901546c5/gml_st/interfaces/tiling_interface.td"
 }
 
 schema() {

+ 19 - 3
tools/mlir-script.js

@@ -64,6 +64,7 @@ const schema = async () => {
         path.join(source, 'llvm-project', 'mlir', 'examples', 'toy', 'Ch7', 'include'),
         path.join(source, 'llvm-project', 'mlir', 'examples', 'transform', 'Ch2', 'include'),
         path.join(source, 'llvm-project', 'mlir', 'examples', 'transform', 'Ch3', 'include'),
+        path.join(source, 'llvm-project', 'mlir', 'examples', 'transform', 'Ch4', 'include'),
         path.join(source, 'stablehlo'),
         path.join(source, 'shardy'),
         path.join(source, 'xla', 'xla', 'mlir_hlo'),
@@ -94,11 +95,12 @@ const schema = async () => {
         path.join(source, 'triton', 'third_party', 'nvidia', 'include'),
         path.join(source, 'triton', 'third_party', 'nvidia', 'include', 'Dialect', 'NVGPU', 'IR'),
         path.join(source, 'triton', 'third_party', 'nvidia', 'include', 'Dialect', 'NVWS', 'IR'),
-        path.join(source, '_', 'llvm-project', 'mlir', 'include'),
         path.join(source, 'clangir'),
         path.join(source, 'clangir', 'clang', 'include'),
         path.join(source, 'rocMLIR'),
         path.join(source, 'rocMLIR', 'mlir', 'include'),
+        path.join(source, '_', 'llvm-project', 'mlir', 'include'),
+        path.join(source, '_', 'mlir-hlo'),
     ];
     const dialects = [
         'mlir/include/mlir/IR/BuiltinAttributeInterfaces.td',
@@ -204,11 +206,13 @@ const schema = async () => {
         'mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td',
         'mlir/include/mlir/Dialect/WasmSSA/IR/WasmSSAOps.td',
         'mlir/include/mlir/Dialect/X86Vector/X86Vector.td',
+        'mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td',
         'mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td',
         'mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td',
         'mlir/examples/toy/Ch7/include/toy/Ops.td',
         'mlir/examples/transform/Ch2/include/MyExtension.td',
         'mlir/examples/transform/Ch3/include/MyExtension.td',
+        'mlir/examples/transform/Ch4/include/MyExtension.td',
         'stablehlo/dialect/StablehloOps.td',
         'stablehlo/dialect/ChloOps.td',
         'stablehlo/dialect/VhloOps.td',
@@ -217,6 +221,7 @@ const schema = async () => {
         'shardy/dialect/sdy/ir/ops.td',
         'shardy/dialect/mpmd/ir/ops.td',
         'mhlo/IR/hlo_ops.td',
+        'thlo/IR/thlo_ops.td',
         'src/Dialect/ONNX/ONNX.td',
         'src/Dialect/ONNX/ONNXOps.td.inc',
         'src/Dialect/ONNX/AdditionalONNXOps.td',
@@ -244,10 +249,12 @@ const schema = async () => {
         'tfrt/test_kernels/opdefs/test_kernels.td',
         'tfrt/tensor/opdefs/tensor.td',
         'tfrt/tensor/opdefs/dense_host_tensor.td',
+        'tfrt/tensor/opdefs/coo_host_tensor.td',
         'tfrt/tensor/opdefs/tensor_shape.td',
         'mlir/test/lib/Dialect/Test/TestOps.td',
         'mlir/test/lib/Dialect/Test/TestOpsSyntax.td',
         'mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td',
+        'mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td',
         'mlir/test/lib/Transforms/TestTransformsOps.td',
         'iree/compiler/Dialect/HAL/IR/HALOps.td',
         'iree/compiler/Dialect/HAL/IR/HALTypes.td',
@@ -285,10 +292,12 @@ const schema = async () => {
         'pmlc/dialect/pxa/ir/ops.td',
         'pmlc/dialect/linalgx/ir/ops.td',
         'pmlc/dialect/xsmm/ir/ops.td',
+        'pmlc/dialect/layer/ir/ops.td',
         'SDFG/Dialect/Ops.td',
         'lltz/mlir/dialect/include/Michelson/MichelsonOps.td',
         'triton/Dialect/Triton/IR/TritonOps.td',
         'triton/Dialect/TritonGPU/IR/TritonGPUOps.td',
+        'triton/Dialect/TritonInstrument/IR/TritonInstrumentOps.td',
         'triton/Dialect/Gluon/IR/GluonOps.td',
         'triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td',
         'triton/third_party/nvidia/include/Dialect/NVWS/IR/NVWSOps.td',
@@ -369,12 +378,14 @@ const schema = async () => {
             operation.category = 'Shape';
         } else if (['transpose', 'reverse', 'pad', 'Transpose', 'Pad'].includes(name)) {
             operation.category = 'Transform';
-        } else if (['slice', 'split', 'dynamic_slice', 'gather', 'scatter', 'Slice', 'Gather', 'Scatter', 'concatenate'].includes(name)) {
+        } else if (['slice', 'split', 'dynamic_slice', 'gather', 'scatter', 'Slice', 'Gather', 'Scatter', 'concat', 'concatenate'].includes(name)) {
             operation.category = 'Tensor';
-        } else if (['tanh', 'Sigmoid', 'Tanh', 'Relu', 'Softmax', 'softmax', 'sigmoid', 'relu'].includes(name)) {
+        } else if (['tanh', 'Sigmoid', 'Tanh', 'Relu', 'Softmax', 'softmax', 'sigmoid', 'relu', 'clamp'].includes(name)) {
             operation.category = 'Activation';
         } else if (['convolution', 'Conv', 'conv2d', 'conv3d', 'fully_connected', 'conv_2d'].includes(name)) {
             operation.category = 'Layer';
+        } else if (['max_pool2d'].includes(name)) {
+            operation.category = 'Pool';
         } else if (['batch_norm_inference'].includes(name)) {
             operation.category = 'Normalization';
         }
@@ -709,9 +720,14 @@ const test = async (pattern) => {
         'third_party/source/mlir/stablehlo/stablehlo/tests/vhlo/invalid_vhlo_future.mlir',
         'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir',
         'third_party/source/mlir/llvm-project/mlir/test/Dialect/Quant/parse-uniform-invalid.mlir',
+        'third_party/source/mlir/llvm-project/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir',
         'third_party/source/mlir/mlir-dace/design/mlir/map.mlir',
         'third_party/source/mlir/mlir-dace/design/mlir/simple_sdfg.mlir',
         'third_party/source/mlir/mlir-dace/design/mlir/symbol.mlir',
+        'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir',
+        'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir',
+        'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir',
+        'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_xla_weight_only.mlir',
     ]);
     return new Promise((resolve, reject) => {
         const cmd = 'node';

+ 84 - 12
tools/tablegen.js

@@ -1736,7 +1736,7 @@ tablegen.Reader = class {
         this._expect('=');
         const listValue = this._parseForeachListValue();
         this._expect('keyword', 'in');
-        const loop = { location, iterVarName, listValue, entries: [], hasDefvar: false };
+        const loop = { location, iterVarName, listValue, entries: [] };
         if (this._match('{')) {
             this._read();
             this._parseForeachBody(loop);
@@ -1873,11 +1873,6 @@ tablegen.Reader = class {
     _parseForeachBody(loop) {
         while (!this._match('}') && !this._match('eof')) {
             this._parseForeachBodyStatement(loop);
-            // If we found defvar, skip the rest of the body since we can't properly expand this loop
-            if (loop.hasDefvar) {
-                this._skipUntilClosingBrace();
-                return;
-            }
         }
     }
 
@@ -1915,8 +1910,7 @@ tablegen.Reader = class {
                     this._parseLet();
                     break;
                 case 'defvar':
-                    loop.hasDefvar = true;
-                    this._parseDefvar();
+                    loop.entries.push({ type: 'defvar', data: this._parseDefvarTemplate() });
                     break;
                 case 'foreach':
                     loop.entries.push({ type: 'foreach', data: this._parseForeachTemplate() });
@@ -1991,6 +1985,11 @@ tablegen.Reader = class {
                 nameTemplate.push({ type: 'number', value: this._read() });
             } else if (this._eat('#')) {
                 nameTemplate.push({ type: 'concat' });
+            } else if (this._match('!')) {
+                const bangValue = this._parseValue();
+                if (bangValue && bangValue.type === 'bang') {
+                    nameTemplate.push({ type: 'bang', value: bangValue.value });
+                }
             } else {
                 break;
             }
@@ -2016,7 +2015,7 @@ tablegen.Reader = class {
         this._expect('=');
         const listValue = this._parseForeachListValue();
         this._expect('keyword', 'in');
-        const loop = { location, iterVarName, listValue, entries: [], hasDefvar: false };
+        const loop = { location, iterVarName, listValue, entries: [] };
         if (this._match('{')) {
             this._read();
             this._parseForeachBody(loop);
@@ -2027,6 +2026,15 @@ tablegen.Reader = class {
         return loop;
     }
 
+    _parseDefvarTemplate() {
+        this._read();
+        const name = this._expect('id');
+        this._expect('=');
+        const value = this._parseValue();
+        this._expect(';');
+        return { name, value };
+    }
+
     _parseRecordBodyFields() {
         const fields = new Map();
         while (!this._match('}') && !this._match('eof')) {
@@ -2069,9 +2077,6 @@ tablegen.Reader = class {
     }
 
     _resolveForeachLoop(loop, substitutions) {
-        if (loop.hasDefvar) {
-            return;
-        }
         if (loop.entries.length === 0) {
             return;
         }
@@ -2086,6 +2091,9 @@ tablegen.Reader = class {
                     this._instantiateDef(entry.data, substitutions);
                 } else if (entry.type === 'foreach') {
                     this._resolveForeachLoop(entry.data, substitutions);
+                } else if (entry.type === 'defvar') {
+                    const value = this._evaluateDefvar(entry.data.value, substitutions);
+                    substitutions.set(entry.data.name, value);
                 }
             }
             return;
@@ -2103,6 +2111,9 @@ tablegen.Reader = class {
                     this._instantiateDef(entry.data, currentSubs);
                 } else if (entry.type === 'foreach') {
                     this._resolveForeachLoop(entry.data, currentSubs);
+                } else if (entry.type === 'defvar') {
+                    const value = this._evaluateDefvar(entry.data.value, currentSubs);
+                    currentSubs.set(entry.data.name, value);
                 }
             }
         }
@@ -2150,6 +2161,60 @@ tablegen.Reader = class {
         return null;
     }
 
+    _evaluateDefvar(value, substitutions) {
+        if (!value) {
+            return new tablegen.Value('string', '');
+        }
+        if (value.type === 'string') {
+            return value;
+        }
+        if (value.type === 'int') {
+            return value;
+        }
+        if ((value.type === 'def' || value.type === 'id') && substitutions.has(value.value)) {
+            return substitutions.get(value.value);
+        }
+        if (value.type === 'concat') {
+            const parts = value.value.map((part) => this._evaluateDefvar(part, substitutions));
+            let result = '';
+            for (const part of parts) {
+                if (part.type === 'string') {
+                    result += String(part.value).replace(/^"|"$/g, '');
+                } else if (part.type === 'int') {
+                    result += String(part.value);
+                } else if (part.type === 'def' || part.type === 'id') {
+                    result += String(part.value);
+                }
+            }
+            return new tablegen.Value('string', result);
+        }
+        if (value.type === 'bang' && value.value) {
+            const { op, args } = value.value;
+            if (op === 'cast' && args && args.length > 0) {
+                const arg = this._evaluateDefvar(args[0], substitutions);
+                if (arg.type === 'int') {
+                    return new tablegen.Value('string', String(arg.value));
+                }
+                return arg;
+            }
+            if (op === 'toupper' && args && args.length > 0) {
+                const arg = this._evaluateDefvar(args[0], substitutions);
+                if (arg.type === 'string') {
+                    return new tablegen.Value('string', String(arg.value).toUpperCase());
+                }
+                return arg;
+            }
+            if (op === 'tolower' && args && args.length > 0) {
+                const arg = this._evaluateDefvar(args[0], substitutions);
+                if (arg.type === 'string') {
+                    return new tablegen.Value('string', String(arg.value).toLowerCase());
+                }
+                return arg;
+            }
+        }
+        return value;
+    }
+
     _instantiateDef(template, substitutions) {
         let name = '';
         for (const part of template.nameTemplate) {
@@ -2173,6 +2238,9 @@ tablegen.Reader = class {
                 name += part.value;
             } else if (part.type === 'number') {
                 name += String(part.value);
+            } else if (part.type === 'bang') {
+                const evaluated = this._evaluateDefvar(new tablegen.Value('bang', part.value), substitutions);
+                name += this._valueToString(evaluated);
             }
         }
         const def = new tablegen.Record(name, this);
@@ -2491,6 +2559,10 @@ tablegen.Reader = class {
         while (this._match('#') || (values[values.length - 1] && values[values.length - 1].type === 'string' && this._match('string'))) {
             if (this._match('#')) {
                 this._read();
+                // Handle trailing # before ; (malformed TableGen but seen in some files)
+                if (this._match(';') || this._match(',') || this._match(')') || this._match(']') || this._match('}') || this._match('eof')) {
+                    break;
+                }
             }
             values.push(this._parsePrimaryValue());
         }

Vissa filer visades inte eftersom för många filer har ändrats