Selaa lähdekoodia

Add MLIR support (#1044)

Lutz Roeder 1 kuukausi sitten
vanhempi
sitoutus
75c5bba678
4 muutettua tiedostoa jossa 485 lisäystä ja 508 poistoa
  1. 187 34
      source/mlir-metadata.json
  2. 215 470
      source/mlir.js
  3. 45 0
      tools/mlir-script.js
  4. 38 4
      tools/tablegen.js

+ 187 - 34
source/mlir-metadata.json

@@ -1473,6 +1473,9 @@
       { "name": "evict", "type": "DefaultValuedAttr<TT_EvictionPolicyAttr{evict_normal|evict_first|evict_last}, triton::EvictionPolicy::NORMAL>" },
       { "name": "contiguity", "type": "DefaultValuedAttr<I32Attr, 1>" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'dst', 'mask', 'getI1SameShape($_self)'>" }
+    ],
     "assemblyFormat": "$src `,` $dst (`mask` $mask^)?\n    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)\n    attr-dict `:` qualified(type($src)) `->` type($dst)"
   },
   {
@@ -3741,7 +3744,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3761,7 +3765,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3856,7 +3861,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3876,7 +3882,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3896,7 +3903,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3916,7 +3924,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3964,7 +3973,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -3984,7 +3994,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4005,7 +4016,8 @@
       { "name": "layout", "type": "ArmSME_TileSliceLayoutAttr{horizontal|vertical}" }
     ],
     "traits": [
-      { "type": "AttrSizedOperandSegments" }
+      { "type": "AttrSizedOperandSegments" },
+      { "type": "TypesMatchWith<'result', 'padding', '::llvm::cast<VectorType>($_self).getElementType()'>" }
     ],
     "assemblyFormat": "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?attr-dict `:` type($base) `,` type($result)"
   },
@@ -4043,7 +4055,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4063,7 +4076,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4083,7 +4097,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4103,7 +4118,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4123,7 +4139,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4143,7 +4160,8 @@
     ],
     "traits": [
       { "type": "AttrSizedOperandSegments" },
-      { "type": "AllTypesMatch<['lhs', 'rhs']>" }
+      { "type": "AllTypesMatch<['lhs', 'rhs']>" },
+      { "type": "TypesMatchWith<'result', 'acc', '::llvm::cast<Type>($_self)'>" }
     ],
     "assemblyFormat": "$lhs `,` $rhs\n    oilist(\n        `acc` `` `(` $acc `)`\n      | `masks` `` `(` $lhsMask `,` $rhsMask `)`\n    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)"
   },
@@ -4166,6 +4184,9 @@
     "results": [
       { "name": "result", "type": "SVEPredicateMask" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'result', 'source', 'VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))'>" }
+    ],
     "assemblyFormat": "$source attr-dict `:` type($result)"
   },
   {
@@ -4178,6 +4199,9 @@
     "results": [
       { "name": "result", "type": "SVBoolMask" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'source', 'result', 'VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))'>" }
+    ],
     "assemblyFormat": "$source attr-dict `:` type($source)"
   },
   {
@@ -18420,7 +18444,7 @@
   {
     "name": "iree_linalg_ext.online_attention",
     "summary": "Online Attention operator.",
-    "description": "Traditional scaled dot product attention computes:\n\n    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V\n\n    Online Attention on the other hand, uses an online normalizer instead of\n    softmax:\n\n    online_attention(Q, K, V, scale, running_max, running_sum)\n      = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V\n\n    If an additional mask argument M is included, the result of the first matmul is modified according to:\n\n    Q @ K.T += M\n\n    The advantage of this online_normalizer is that it can be tiled along\n    its reduction dimension, making the online_attention operator:\n      - Tilable along softmax reduction dimension\n      - Associative along softmax reduction dimension\n      - Commutative along softmax associative dimension\n\n    Note: The results of online_attention need to be combined after computing\n    it over the entire softmax reduction dimension by:\n      x, _, sum : results\n      x = (1 / sum) * x",
+    "description": "Traditional scaled dot product attention computes:\n\n    attention(Q, K, V, scale) = softmax(Q @ K.T * scale) @ V\n\n    Online Attention on the other hand, uses an online normalizer instead of\n    softmax:\n\n    online_attention(Q, K, V, scale, running_max, running_sum)\n      = online_normalizer(Q @ K.T * scale, running_max, running_sum) @ V\n\n    If an additional mask argument M is included, the result of the first matmul is modified according to:\n\n    Q @ K.T += M\n\n    The advantage of this online_normalizer is that it can be tiled along\n    its reduction dimension, making the online_attention operator:\n      - Tilable along softmax reduction dimension\n      - Associative along softmax reduction dimension\n      - Commutative along softmax associative dimension\n\n    Note: The results of online_attention need to be combined after computing\n    it over the entire softmax reduction dimension by:\n      x, _, sum : results\n      x = (1 / sum) * x\n\n    Decomposition Configuration:\n    The `decomposition_config` attribute is a DictionaryAttr that controls how\n    this operation is decomposed into lower-level operations. It supports:\n      - \"qk_attrs\": DictionaryAttr - Attributes to attach to the Q@K matmul\n        operation after decomposition (e.g., lowering_config, attention markers)\n      - \"pv_attrs\": DictionaryAttr - Attributes to attach to the P@V matmul\n        operation after decomposition\n      - \"use_exp2\": BoolAttr - If true, uses exp2 with log2(e) scaling instead\n        of exp. (Gives better perf on some hardware, but trades off accuracy)",
     "operands": [
       { "name": "query", "type": "AnyShaped" },
       { "name": "key", "type": "AnyShaped" },
@@ -20787,7 +20811,8 @@
       { "name": "static_inner_tiles", "type": "DenseI64ArrayAttr" }
     ],
     "traits": [
-      { "type": "AttrSizedOperandSegments" }
+      { "type": "AttrSizedOperandSegments" },
+      { "type": "TypesMatchWith<'dest', 'result', '$_self'>" }
     ],
     "hasCustomAssemblyFormat": true
   },
@@ -21297,6 +21322,9 @@
       { "name": "inner_dims_pos", "type": "DenseI64ArrayAttr" },
       { "name": "static_inner_tiles", "type": "DenseI64ArrayAttr" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'dest', 'result', '$_self'>" }
+    ],
     "hasCustomAssemblyFormat": true
   },
   {
@@ -28892,19 +28920,22 @@
       { "name": "inits", "type": "Variadic<MHLO_Tensor>" }
     ],
     "results": [
-      { "name": "results", "type": "Variadic<MHLO_Tensor>" }
+      { "name": "outputs", "type": "Variadic<MHLO_Tensor>" },
+      { "name": "carries", "type": "Variadic<MHLO_Tensor>" }
     ],
     "attributes": [
       { "name": "dimension", "type": "ConfinedAttr<I64Attr, [IntNonNegative]>" },
       { "name": "is_reverse", "type": "DefaultValuedOptionalAttr<BoolAttr, false>" },
-      { "name": "is_associative", "type": "DefaultValuedOptionalAttr<MHLO_AssociativityAttr{MAYBE|TRUE|FALSE}, ::mlir::mhlo::Associativity::MAYBE>" }
+      { "name": "is_associative", "type": "OptionalAttr<BoolAttr>" }
     ],
     "regions": [
       { "name": "body", "type": "SizedRegion<1>" }
     ],
     "traits": [
-      { "type": "AttrSizedOperandSegments" }
-    ]
+      { "type": "AttrSizedOperandSegments" },
+      { "type": "IsolatedFromAbove" }
+    ],
+    "hasCustomAssemblyFormat": true
   },
   {
     "name": "mhlo.scatter",
@@ -30318,6 +30349,20 @@
     ],
     "assemblyFormat": "$cond $body_fn `(` $arguments `)` attr-dict `:` `(` type($arguments) `)` `->` `(` type(results) `)`"
   },
+  {
+    "name": "mpi.allgather",
+    "summary": "Equivalent to `MPI_Allgather(sendbuf, sendcount, sendtype,\n                                 recvbuf, recvcount, recvtype,\n                                 comm)`.",
+    "description": "MPI_Allgather collects data from all processes in a given communicator and\n    stores the gathered data in the receive buffer of each process.\n\n    Each process contributes the same amount of data defined by `sendbuf`.\n    The MPI call specifies the number of elements contributed by each process\n    via the `recvcount` parameter. However, this operation, assumes `recvbuf`\n    to be sufficiently large to hold the data contributed by all processes.\n    Therefore, `recvcount` is implicitly defined as\n    `num_elements(recvbuf) / MPI_Comm_size(comm)`.\n\n    This operation may optionally return an !mpi.retval value, which can be\n    used for error checking.",
+    "operands": [
+      { "name": "sendbuf", "type": "AnyMemRef" },
+      { "name": "recvbuf", "type": "AnyMemRef" },
+      { "name": "comm", "type": "MPI_Comm" }
+    ],
+    "results": [
+      { "name": "retval", "type": "Optional<MPI_Retval>" }
+    ],
+    "assemblyFormat": "`(` $sendbuf `,` $recvbuf `,` $comm `)` attr-dict `:` type($sendbuf) `,` type($recvbuf) (`->` type($retval)^)?"
+  },
   {
     "name": "mpi.allreduce",
     "summary": "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`",
@@ -33634,6 +33679,21 @@
     ],
     "assemblyFormat": "$kind attr-dict"
   },
+  {
+    "name": "nvvm.tensormap.replace",
+    "summary": "Modifies a field of the tensor-map object",
+    "description": "The `nvvm.tensormap.replace` replaces the specified field of the tensor-map \n    object at the location specified by `addr` with a new value (specified by \n    `new_value` or `new_value_attr`).\n\n    The `field` argument specifies the field of the tensor-map object to \n    replace.\n\n    `new_value` is an `i32`/`i64` argument that specifies the new value to \n    replace the `field` with for the `global_address`, `rank`, `box_dim`, \n    `global_dim`, `global_stride`, and `element_stride` fields. It must be an \n    `i64` for the `global_address` and `global_stride` fields and `i32` for the \n    remaining fields.\n    \n    For `rank`, `new_value` must be one less than the desired tensor rank as \n    this field uses zero-based numbering.\n\n    `new_value_attr` is an attribute that specifies the new value to replace \n    the `field` with for the `elemtype`, `interleave_layout`, `swizzle_mode`, \n    `swizzle_atomicity`, and `fill_mode` fields. It takes the place of \n    `new_value` for these fields. It must be a valid attribute corresponding to \n    the `field` type.\n\n    The ordinal `ord` is an immediate integer argument that specifies the \n    ordinal of the `field` across the tensor which needs to be replaced and is \n    required only for the `box_dim`, `global_dim`, `global_stride`, and \n    `element_stride` fields.\n\n    [For more information, see PTX ISA.](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-tensormap-replace)",
+    "operands": [
+      { "name": "addr", "type": "AnyTypeOf<[LLVM_PointerGlobal, LLVM_PointerShared]>" },
+      { "name": "new_value", "type": "Optional<AnyTypeOf<[ I64, I32 ]>>" }
+    ],
+    "attributes": [
+      { "name": "field", "type": "TensormapFieldAttr{global_address|rank|box_dim|global_dim|global_stride|element_stride|elemtype|interleave_layout|swizzle_mode|swizzle_atomicity|fill_mode}" },
+      { "name": "ord", "type": "OptionalAttr<ConfinedAttr<I32Attr, [ IntMinValue < 1 >, IntMaxValue < 5 > ]>>" },
+      { "name": "new_value_attr", "type": "OptionalAttr<TensormapFieldValueAttr>" }
+    ],
+    "assemblyFormat": "`field` `=` $field (`[` $ord^ `]`)? `,` `new_value` `=` ($new_value_attr^):($new_value)? `in` $addr attr-dict `:` type(operands)"
+  },
   {
     "name": "nvvm.vote.sync",
     "summary": "Vote across thread group",
@@ -33773,6 +33833,9 @@
     "results": [
       { "name": "result", "type": "NVWS_ArefType" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'result', 'buffers', '::llvm::cast<ArefType>($_self).getBaseType()'>" }
+    ],
     "assemblyFormat": "$buffers attr-dict `:` type($result)"
   },
   {
@@ -34142,18 +34205,19 @@
   {
     "name": "omp.declare_simd",
     "summary": "declare simd directive",
-    "description": "\"omp.declare_simd\" models the OpenMP `declare simd` directive.\n\n    This is a declarative operation (no region) intended to appear inside\n    a function body. It attaches clauses of declare simd to the enclosing\n    function.\n\n    Example:\n    ```mlir\n    func.func @add(%a: memref<16xi32>) {\n      omp.declare_simd simdlen(8) aligned(%a : memref<16xi32> -> 64 : i64)\n      ...\n    }\n    ```The `alignments` attribute additionally specifies alignment of each\n    corresponding aligned operand. Note that `aligned_vars` and `alignments`\n    must contain the same number of elements.The `linear_step_vars` operand additionally specifies the step for each\n    associated linear operand. Note that the `linear_vars` and\n    `linear_step_vars` variadic lists should contain the same number of\n    elements.When a `simdlen` clause is present, the preferred number of iterations to be\n    executed concurrently is the value provided to the `simdlen` clause.",
+    "description": "\"omp.declare_simd\" models the OpenMP `declare simd` directive.\n\n    This is a declarative operation (no region) intended to appear inside\n    a function body. It attaches clauses of declare simd to the enclosing\n    function.\n\n    Example:\n    ```mlir\n    func.func @add(%a: memref<16xi32>) {\n      omp.declare_simd simdlen(8) aligned(%a : memref<16xi32> -> 64 : i64)\n      ...\n    }\n    ```The `alignments` attribute additionally specifies alignment of each\n    corresponding aligned operand. Note that `aligned_vars` and `alignments`\n    must contain the same number of elements.The `linear_step_vars` operand additionally specifies the step for each\n    associated linear operand. Note that the `linear_vars` and\n    `linear_step_vars` variadic lists should contain the same number of\n    elements.When a `simdlen` clause is present, the preferred number of iterations to be\n    executed concurrently is the value provided to the `simdlen` clause.The `uniform` clause declares one or more arguments to have an invariant\n    value for all concurrent invocations of the function in the execution of\n    a single SIMD loop.",
     "operands": [
       { "name": "aligned_vars", "type": "Variadic<OpenMP_PointerLikeType>" },
       { "name": "linear_vars", "type": "Variadic<AnyType>" },
-      { "name": "linear_step_vars", "type": "Variadic<I32>" }
+      { "name": "linear_step_vars", "type": "Variadic<I32>" },
+      { "name": "uniform_vars", "type": "Variadic<OpenMP_PointerLikeType>" }
     ],
     "attributes": [
       { "name": "alignments", "type": "OptionalAttr<TypedArrayAttrBase<I64Attr>>" },
       { "name": "linear_var_types", "type": "OptionalAttr<ArrayAttr>" },
       { "name": "simdlen", "type": "ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>" }
     ],
-    "assemblyFormat": "oilist(`aligned` `(` custom<AlignedClause>($aligned_vars, type($aligned_vars),\n                                        $alignments) `)`|`linear` `(`\n      custom<LinearClause>($linear_vars, type($linear_vars),\n                           $linear_step_vars) `)`|`simdlen` `(` $simdlen  `)`) attr-dict"
+    "assemblyFormat": "oilist(`aligned` `(` custom<AlignedClause>($aligned_vars, type($aligned_vars),\n                                        $alignments) `)`|`linear` `(`\n      custom<LinearClause>($linear_vars, type($linear_vars),\n                           $linear_step_vars) `)`|`simdlen` `(` $simdlen  `)`|`uniform` `(` custom<UniformClause>($uniform_vars, type($uniform_vars)) `)`) attr-dict"
   },
   {
     "name": "omp.distribute",
@@ -34264,7 +34328,7 @@
     ],
     "attributes": [
       { "name": "var_type", "type": "TypeAttr" },
-      { "name": "map_type", "type": "ClauseMapFlagsAttr{none|storage|to|from|always|del|return_param|priv|literal|implicit|close|present|ompx_hold|attach|attach_always|attach_none|attach_auto|ref_ptr|ref_ptee|ref_ptr_ptee|is_device_ptr}" },
+      { "name": "map_type", "type": "ClauseMapFlagsAttr{none|storage|to|from|always|del|return_param|priv|literal|implicit|close|present|ompx_hold|attach|attach_always|attach_never|attach_auto|ref_ptr|ref_ptee|ref_ptr_ptee|is_device_ptr}" },
       { "name": "map_capture_type", "type": "VariableCaptureKindAttr{This|ByRef|ByCopy|VLAType}" },
       { "name": "members_index", "type": "OptionalAttr<TypedArrayAttrBase<TypedArrayAttrBase<I64Attr>>>" },
       { "name": "mapper_id", "type": "OptionalAttr<FlatSymbolRefAttr>" },
@@ -39794,6 +39858,9 @@
     "results": [
       { "name": "results", "type": "Variadic<AnyRankedTensor>" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'operands', 'results', 'llvm::make_range($_self.begin(), $_self.end())'>" }
+    ],
     "assemblyFormat": "$operands attr-dict  `:` type($operands)"
   },
   {
@@ -40075,6 +40142,9 @@
     "results": [
       { "name": "result", "type": "PtrLikeTypeInterface" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'result', 'metadata', 'PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))'>" }
+    ],
     "assemblyFormat": "$ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)"
   },
   {
@@ -46530,7 +46600,8 @@
       { "name": "result", "type": "AnyNonFuncSMTType" }
     ],
     "traits": [
-      { "type": "TypesMatchWith<'func', 'result', 'cast<SMTFuncType>($_self).getRangeType()'>" }
+      { "type": "TypesMatchWith<'func', 'result', 'cast<SMTFuncType>($_self).getRangeType()'>" },
+      { "type": "TypesMatchWith<'func', 'args', 'cast<SMTFuncType>($_self).getDomainTypes()'>" }
     ],
     "assemblyFormat": "$func `(` $args `)` attr-dict `:` qualified(type($func))"
   },
@@ -48084,6 +48155,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48103,6 +48178,11 @@
       { "name": "equal_semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" },
       { "name": "unequal_semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'comparator', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $equal_semantics $unequal_semantics operands attr-dict `:`\n      type($pointer)"
   },
   {
@@ -48122,6 +48202,11 @@
       { "name": "equal_semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" },
       { "name": "unequal_semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'comparator', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $equal_semantics $unequal_semantics operands attr-dict `:`\n      type($pointer)"
   },
   {
@@ -48139,6 +48224,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48156,6 +48245,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48172,6 +48265,9 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48188,6 +48284,9 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48205,6 +48304,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48222,6 +48325,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48239,6 +48346,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48256,6 +48367,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48273,6 +48388,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48290,6 +48409,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -48307,6 +48430,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)"
   },
   {
@@ -49733,6 +49860,10 @@
       { "name": "memory_scope", "type": "SPIRV_ScopeAttr{CrossDevice|Device|Workgroup|Subgroup|Invocation|QueueFamily|ShaderCallKHR}" },
       { "name": "semantics", "type": "SPIRV_MemorySemanticsAttr{None|Acquire|Release|AcquireRelease|SequentiallyConsistent|UniformMemory|SubgroupMemory|WorkgroupMemory|CrossWorkgroupMemory|AtomicCounterMemory|ImageMemory|OutputMemory|MakeAvailable|MakeVisible|Volatile}" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'pointer', 'result', 'llvm::cast<PointerType>($_self).getPointeeType()'>" },
+      { "type": "TypesMatchWith<'pointer', 'value', 'llvm::cast<PointerType>($_self).getPointeeType()'>" }
+    ],
     "assemblyFormat": "$memory_scope $semantics operands attr-dict `:` type($pointer)",
     "hasCustomAssemblyFormat": true
   },
@@ -53993,7 +54124,7 @@
     ],
     "attributes": [
       { "name": "axis", "type": "ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>" },
-      { "name": "nan_mode", "type": "SPIRV_TosaExtNaNPropagationModeAttr" }
+      { "name": "nan_mode", "type": "SPIRV_TosaExtNaNPropagationModeAttr{Propagate|Ignore}" }
     ],
     "assemblyFormat": "`axis` `=` $axis `,` `nan_mode` `=` $nan_mode `,`\n    $input\n    attr-dict `:` type(operands) `->` type(results)",
     "hasCustomAssemblyFormat": true
@@ -60987,6 +61118,9 @@
     "results": [
       { "name": "result", "type": "Variadic<AnyType>" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'value', 'result', 'llvm::make_range($_self.begin(), $_self.end())'>" }
+    ],
     "assemblyFormat": "attr-dict $value `:` type($value)"
   },
   {
@@ -82185,7 +82319,7 @@
     "results": [
       { "name": "result0", "type": "I32" },
       { "name": "result1", "type": "I32" },
-      { "name": "result2", "type": "type" }
+      { "name": "result2", "type": "I32" }
     ],
     "traits": [
       { "type": "IsolatedFromAbove" }
@@ -82200,7 +82334,7 @@
     "results": [
       { "name": "result0", "type": "I64" },
       { "name": "result1", "type": "I64" },
-      { "name": "result2", "type": "type" }
+      { "name": "result2", "type": "I64" }
     ],
     "traits": [
       { "type": "IsolatedFromAbove" }
@@ -100208,6 +100342,19 @@
     ],
     "assemblyFormat": "operands attr-dict `:` functional-type(operands, results)"
   },
+  {
+    "name": "tosa.assert_equal_shape",
+    "summary": "Verify two shapes are equal.",
+    "description": "Verify input1 and input2 are equal. If allow_broadcast is set, shapes which\n      are broadcast compatible are allowed.",
+    "operands": [
+      { "name": "input1", "type": "Tosa_Shape" },
+      { "name": "input2", "type": "Tosa_Shape" }
+    ],
+    "attributes": [
+      { "name": "allow_broadcast", "type": "BoolAttr" }
+    ],
+    "assemblyFormat": "operands attr-dict `:` functional-type(operands, results)"
+  },
   {
     "name": "tosa.avg_pool2d",
     "summary": "Performs average pooling on the input.",
@@ -110577,7 +110724,9 @@
       { "name": "contiguity", "type": "DefaultValuedAttr<I32Attr, 1>" }
     ],
     "traits": [
-      { "type": "AttrSizedOperandSegments" }
+      { "type": "AttrSizedOperandSegments" },
+      { "type": "TypesMatchWith<'src', 'mask', 'getI1SameShape($_self)'>" },
+      { "type": "TypesMatchWith<'src', 'other', 'getPointeeType($_self)'>" }
     ],
     "assemblyFormat": "$src `,` $result (`mask` $mask^)? (`other` $other^)?\n    oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)\n    attr-dict `:` type($src) `->` type($result)"
   },
@@ -112165,6 +112314,7 @@
       { "name": "res2", "type": "AnyVectorOfNonZeroRank" }
     ],
     "traits": [
+      { "type": "TypesMatchWith<'source', 'res1', '[&]() -> ::mlir::VectorType {\n      auto vectorType = ::llvm::cast<mlir::VectorType>($_self);\n      ::mlir::VectorType::Builder builder(vectorType);\n      auto lastDim = vectorType.getRank() - 1;\n      auto newDimSize = vectorType.getDimSize(lastDim) / 2;;\n      if (newDimSize <= 0)\n         return vectorType; // (invalid input type)\n      return builder.setDim(lastDim, newDimSize);\n    }()'>" },
       { "type": "AllTypesMatch<['res1', 'res2']>" }
     ],
     "assemblyFormat": "$source attr-dict `:` type($source) `->` type($res1)"
@@ -112456,6 +112606,9 @@
       { "name": "kind", "type": "Vector_CombiningKindAttr{add|mul|minui|minsi|minnumf|maxui|maxsi|maxnumf|and|or|xor|maximumf|minimumf}" },
       { "name": "fastmath", "type": "DefaultValuedAttr<Arith_FastMathAttr{none|reassoc|nnan|ninf|nsz|arcp|contract|afn|fast}, ::mlir::arith::FastMathFlags::none>" }
     ],
+    "traits": [
+      { "type": "TypesMatchWith<'dest', 'acc', '::llvm::cast<Type>($_self)'>" }
+    ],
     "assemblyFormat": "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)? attr-dict `:` type($vector) `into` type($dest)"
   },
   {
@@ -117198,7 +117351,7 @@
     "operands": [
       { "name": "a", "type": "F32" },
       { "name": "b", "type": "F32" },
-      { "name": "c", "type": "type" }
+      { "name": "c", "type": "F32" }
     ],
     "results": [
       { "name": "result", "type": "F32" }
@@ -117215,7 +117368,7 @@
     "operands": [
       { "name": "a", "type": "F64" },
       { "name": "b", "type": "F64" },
-      { "name": "c", "type": "type" }
+      { "name": "c", "type": "F64" }
     ],
     "results": [
       { "name": "result", "type": "F64" }
@@ -117232,7 +117385,7 @@
     "operands": [
       { "name": "a", "type": "I32" },
       { "name": "b", "type": "I32" },
-      { "name": "c", "type": "type" }
+      { "name": "c", "type": "I32" }
     ],
     "results": [
       { "name": "result", "type": "I32" }
@@ -117249,7 +117402,7 @@
     "operands": [
       { "name": "a", "type": "I64" },
       { "name": "b", "type": "I64" },
-      { "name": "c", "type": "type" }
+      { "name": "c", "type": "I64" }
     ],
     "results": [
       { "name": "result", "type": "I64" }

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 215 - 470
source/mlir.js


+ 45 - 0
tools/mlir-script.js

@@ -768,6 +768,49 @@ const schema = async () => {
                             }
                         }
                     }
+                    // Handle classes that inherit from TypesMatchWith (e.g., PointeeTypeMatchTrait)
+                    if (traitDag && traitDag !== 'TypesMatchWith' && trait.value && trait.value.operands) {
+                        const traitClass = parser.getClass(traitDag);
+                        if (traitClass) {
+                            // Check if this class inherits from TypesMatchWith
+                            for (const classParent of traitClass.parents || []) {
+                                if (classParent.name === 'TypesMatchWith' && classParent.args && classParent.args.length >= 4) {
+                                    // Build template bindings from trait operands to class template args
+                                    const bindings = new Map();
+                                    for (let i = 0; i < traitClass.templateArgs.length && i < trait.value.operands.length; i++) {
+                                        const paramName = traitClass.templateArgs[i].name;
+                                        const argValue = trait.value.operands[i];
+                                        if (argValue && argValue.value) {
+                                            bindings.set(paramName, argValue.value.type === 'string' ? argValue.value.value : argValue.value);
+                                        }
+                                    }
+                                    // Extract from, to, transformer from TypesMatchWith parent args (indices 1, 2, 3)
+                                    const resolveArg = (arg) => {
+                                        if (!arg) {
+                                            return null;
+                                        }
+                                        if (arg.type === 'string' || arg.type === 'code') {
+                                            return arg.value;
+                                        }
+                                        if (arg.type === 'def' && typeof arg.value === 'string') {
+                                            // Check if it's a template parameter reference
+                                            return bindings.has(arg.value) ? bindings.get(arg.value) : arg.value;
+                                        }
+                                        return null;
+                                    };
+                                    const from = resolveArg(classParent.args[1]);
+                                    const to = resolveArg(classParent.args[2]);
+                                    const transformer = resolveArg(classParent.args[3]);
+                                    if (from && to && transformer) {
+                                        const traitType = `TypesMatchWith<'${from}', '${to}', '${transformer}'>`;
+                                        if (traits.every((t) => t.type !== traitType)) {
+                                            traits.push({ type: traitType });
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
                     if ((traitName === 'AttrSizedOperandSegments' || traitDag === 'AttrSizedOperandSegments') && traits.every((t) => t.type !== 'AttrSizedOperandSegments')) {
                         traits.push({ type: 'AttrSizedOperandSegments' });
                     }
@@ -973,6 +1016,7 @@ const test = async (pattern) => {
         'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir',
         'third_party/source/mlir/tensorflow/tensorflow/compiler/mlir/tfr/tests/ops.mlir',
         'third_party/source/mlir/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir',
+        'third_party/source/mlir/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir',
         'third_party/source/mlir/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir',
         'third_party/source/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir',
         'third_party/source/tensorflow/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir',
@@ -982,6 +1026,7 @@ const test = async (pattern) => {
         'third_party/source/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir',
         'third_party/source/tensorflow/tensorflow/compiler/mlir/tfr/tests/ops.mlir',
         'third_party/source/tensorflow/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_bounded_dynamism_stablehlo.mlir',
+        'third_party/source/tensorflow/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir',
         'third_party/source/tensorflow/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir',
         'third_party/test/mlir/sample.mlir',
     ]);

+ 38 - 4
tools/tablegen.js

@@ -939,8 +939,21 @@ tablegen.Record = class {
                 if (casesArg && casesArg.type === 'list' && Array.isArray(casesArg.value)) {
                     const cases = [];
                     for (const caseValue of casesArg.value) {
-                        // Each case is a def reference
-                        if (caseValue.type === 'def' && typeof caseValue.value === 'string') {
+                        // Each case can be either a DAG or a def reference
+                        if (caseValue.type === 'dag' && caseValue.value) {
+                            // DAG format: I32EnumAttrCase<"symbol", value>
+                            // The first operand is the symbol name
+                            const operands = caseValue.value.operands;
+                            if (operands && operands.length > 0) {
+                                const strOperand = operands[0];
+                                if (strOperand && strOperand.value) {
+                                    const str = this.evaluateValue(strOperand.value);
+                                    if (str && typeof str === 'string') {
+                                        cases.push(str);
+                                    }
+                                }
+                            }
+                        } else if (caseValue.type === 'def' && typeof caseValue.value === 'string') {
                             const caseDef = this.parser.getDef(caseValue.value) || this.parser.getClass(caseValue.value);
                             if (caseDef) {
                                 const str = caseDef.getValueAsString('str');
@@ -1517,8 +1530,29 @@ tablegen.Record = class {
                         return null;
                 }
             }
-            case 'dag':
-                return value.value;
+            case 'dag': {
+                const dag = value.value;
+                const evaluatedOperands = dag.operands.map((operand) => {
+                    if (!operand.value) {
+                        return operand;
+                    }
+                    const valType = operand.value.type;
+                    if (valType === 'def' || valType === 'id') {
+                        const refName = operand.value.value;
+                        if (this.templateBindings.has(refName)) {
+                            const evaluated = this.evaluateValue(operand.value, visited);
+                            if (typeof evaluated === 'string') {
+                                return {
+                                    value: { type: 'def', value: evaluated },
+                                    name: operand.name
+                                };
+                            }
+                        }
+                    }
+                    return operand;
+                });
+                return new tablegen.DAG(dag.operator, evaluatedOperands);
+            }
             case 'uninitialized':
                 return null;
             default:

Kaikkia tiedostoja ei voida näyttää, sillä liian monta tiedostoa muuttui tässä diffissä