mlir_script.js 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. import * as fs from 'fs/promises';
  2. import * as path from 'path';
  3. import * as tablegen from './tablegen.js';
  4. import * as url from 'url';
  5. class Operator {
  6. constructor(def) {
  7. this.def = def;
  8. let opInfo = null;
  9. for (const parent of this.def.parents) {
  10. const parentClass = this.def.parser.classes.get(parent.name);
  11. if (parentClass) {
  12. opInfo = this._findOpParent(parentClass, parent.args, {});
  13. if (opInfo) {
  14. break;
  15. }
  16. }
  17. }
  18. this.dialectName = opInfo?.dialect || null;
  19. this.opName = opInfo?.mnemonic || null;
  20. }
  21. getDialectName() {
  22. return this.dialectName || '';
  23. }
  24. getOperationName() {
  25. return this.dialectName && this.opName ? `${this.dialectName}.${this.opName}` : null;
  26. }
  27. _extractValue(arg) {
  28. // Handle both old string format and new Value object format
  29. if (typeof arg === 'string') {
  30. return arg;
  31. }
  32. if (arg && typeof arg === 'object') {
  33. if (arg.value !== undefined && arg.name !== undefined) {
  34. return this._extractValue(arg.value);
  35. }
  36. if (arg.type === 'string' && typeof arg.value === 'string') {
  37. return arg.value.replace(/^"|"$/g, '');
  38. }
  39. if (arg.type === 'def' && typeof arg.value === 'string') {
  40. return arg.value;
  41. }
  42. }
  43. return null;
  44. }
  45. _evaluateWithSubstitutions(value, subs, visited = new Set()) {
  46. if (!value) {
  47. return null;
  48. }
  49. // Detect cycles
  50. const valueKey = JSON.stringify(value);
  51. if (visited.has(valueKey)) {
  52. return null;
  53. }
  54. visited.add(valueKey);
  55. switch (value.type) {
  56. case 'string':
  57. return typeof value.value === 'string' ? value.value.replace(/^"|"$/g, '') : value.value;
  58. case 'int':
  59. return parseInt(value.value, 10);
  60. case 'concat': {
  61. const parts = value.value.map((v) => this._evaluateWithSubstitutions(v, subs, visited));
  62. const filtered = parts.filter((v) => v !== null && v !== undefined && v !== '');
  63. const result = filtered.join('');
  64. return result;
  65. }
  66. case 'id':
  67. case 'def': {
  68. const name = value.value;
  69. if (subs[name] && subs[name] !== value) {
  70. return this._evaluateWithSubstitutions(subs[name], subs, visited);
  71. }
  72. return this.def.evaluateValue(value);
  73. }
  74. default: {
  75. return this.def.evaluateValue(value);
  76. }
  77. }
  78. }
  79. _findOpParent(parentClass, parentArgs, substitutions) {
  80. const subs = { ...substitutions };
  81. if (parentClass.templateArgs && parentArgs) {
  82. for (let i = 0; i < Math.min(parentClass.templateArgs.length, parentArgs.length); i++) {
  83. const paramName = parentClass.templateArgs[i].name;
  84. const argValue = parentArgs[i];
  85. const extractedValue = this._extractValue(argValue);
  86. if (extractedValue && substitutions[extractedValue]) {
  87. subs[paramName] = substitutions[extractedValue];
  88. } else {
  89. const evaluated = this._evaluateWithSubstitutions(argValue, substitutions);
  90. subs[paramName] = evaluated === null ? argValue : subs[paramName] = { type: 'string', value: evaluated };
  91. }
  92. }
  93. }
  94. if (parentClass.name === 'Op' && parentArgs.length >= 2) {
  95. let [dialectArg, mnemonicArg] = parentArgs;
  96. if (dialectArg && dialectArg.type === 'def' && dialectArg.value && subs[dialectArg.value]) {
  97. dialectArg = subs[dialectArg.value];
  98. }
  99. if (mnemonicArg && mnemonicArg.type === 'def' && mnemonicArg.value && subs[mnemonicArg.value]) {
  100. mnemonicArg = subs[mnemonicArg.value];
  101. }
  102. let dialectName = null;
  103. const dialectStr = this._extractValue(dialectArg);
  104. if (dialectStr) {
  105. const dialectDef = this.def.parser.getDef(dialectStr) || this.def.parser.getClass(dialectStr);
  106. if (dialectDef) {
  107. dialectName = dialectDef.getValueAsString('name');
  108. }
  109. }
  110. let mnemonic = this._extractValue(mnemonicArg);
  111. if (!mnemonic && mnemonicArg) {
  112. mnemonic = this._evaluateWithSubstitutions(mnemonicArg, subs);
  113. }
  114. // Clean up mnemonic: remove leading/trailing dots, normalize multiple dots
  115. if (mnemonic && typeof mnemonic === 'string') {
  116. mnemonic = mnemonic.replace(/^\.+|\.+$/g, ''); // Remove leading/trailing dots
  117. // Skip if mnemonic is invalid after cleanup
  118. if (!mnemonic || mnemonic.includes('..')) {
  119. return null;
  120. }
  121. }
  122. if (dialectName && mnemonic) {
  123. return { dialect: dialectName, mnemonic };
  124. }
  125. }
  126. for (const grandparent of parentClass.parents) {
  127. const grandparentClass = this.def.parser.classes.get(grandparent.name);
  128. if (grandparentClass) {
  129. const resolvedArgs = grandparent.args.map((arg) => {
  130. const extracted = this._extractValue(arg);
  131. return (extracted && subs[extracted]) ? subs[extracted] : arg;
  132. });
  133. const result = this._findOpParent(grandparentClass, resolvedArgs, subs);
  134. if (result) {
  135. return result;
  136. }
  137. }
  138. }
  139. return null;
  140. }
  141. }
  142. const access = async (path) => {
  143. try {
  144. await fs.access(path);
  145. return true;
  146. } catch {
  147. return false;
  148. }
  149. };
  150. const main = async () => {
  151. const dirname = path.dirname(url.fileURLToPath(import.meta.url));
  152. const source = path.join(dirname, '..', 'third_party', 'source', 'mlir');
  153. const paths = [
  154. path.join(source, 'llvm-project', 'mlir', 'include'),
  155. path.join(source, 'llvm-project', 'mlir', 'test', 'lib', 'Dialect', 'Transform'),
  156. path.join(source, 'llvm-project', 'mlir', 'include', 'mlir', 'Dialect', 'ArmNeon'),
  157. path.join(source, 'llvm-project', 'mlir', 'include', 'mlir', 'Dialect', 'ArmSME', 'IR'),
  158. path.join(source, 'llvm-project', 'mlir', 'include', 'mlir', 'Dialect', 'ArmSVE', 'IR'),
  159. path.join(source, 'llvm-project', 'mlir', 'examples', 'toy', 'Ch7', 'include'),
  160. path.join(source, 'stablehlo'),
  161. path.join(source, 'onnx-mlir'),
  162. path.join(source, 'torch-mlir', 'include'),
  163. path.join(source, 'triton', 'include'),
  164. path.join(source, 'triton', 'third_party'),
  165. path.join(source, 'triton', 'third_party', 'amd', 'include', 'Dialect', 'TritonAMDGPU', 'IR'),
  166. path.join(source, 'mlir-hlo', 'include'),
  167. path.join(source, 'iree', 'compiler', 'src'),
  168. path.join(source, 'FlashTensor', 'include'),
  169. path.join(source, 'tpu-mlir', 'include'),
  170. path.join(source, 'tensorflow'),
  171. path.join(source, 'tensorflow', 'tensorflow', 'compiler', 'mlir', 'tfrt', 'ir'),
  172. path.join(source, 'runtime', 'include'),
  173. path.join(source, 'plaidml'),
  174. path.join(source, 'plaidml', 'pmlc', 'dialect', 'pxa', 'ir'),
  175. path.join(source, 'mlir-dace', 'include'),
  176. path.join(source, 'lltz', 'mlir', 'dialect', 'include', 'Michelson'),
  177. path.join(source, 'lagrad', 'include', 'LAGrad'),
  178. path.join(source, 'TensorRT-Incubator', 'mlir-tensorrt', 'tensorrt', 'include'),
  179. path.join(source, 'TensorRT-Incubator', 'mlir-tensorrt', 'executor', 'include'),
  180. ];
  181. const dialects = [
  182. 'mlir/IR/BuiltinAttributeInterfaces.td',
  183. 'mlir/IR/BuiltinTypeInterfaces.td',
  184. 'mlir/IR/BuiltinLocationAttributes.td',
  185. 'mlir/IR/BuiltinDialect.td',
  186. 'mlir/IR/BuiltinOps.td',
  187. 'mlir/IR/BuiltinDialectBytecode.td',
  188. 'mlir/IR/BuiltinAttributes.td',
  189. 'mlir/IR/BuiltinTypes.td',
  190. 'mlir/Dialect/Async/IR/AsyncOps.td',
  191. 'mlir/Dialect/Affine/IR/AffineOps.td',
  192. 'mlir/Dialect/Affine/IR/AffineOps.td',
  193. 'mlir/Dialect/Arith/IR/ArithOps.td',
  194. 'mlir/Dialect/ControlFlow/IR/ControlFlowOps.td',
  195. 'mlir/Dialect/Func/IR/FuncOps.td',
  196. 'mlir/Dialect/GPU/IR/GPUOps.td',
  197. 'mlir/Dialect/SCF/IR/SCFOps.td',
  198. 'mlir/Dialect/Linalg/IR/LinalgOps.td',
  199. // 'mlir/Dialect/Linalg/IR/LinalgStructuredOps.td', // File not found 'mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td'
  200. 'mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td',
  201. 'mlir/Dialect/MemRef/IR/MemRefOps.td',
  202. 'mlir/Dialect/Bufferization/IR/BufferizationOps.td',
  203. 'mlir/Dialect/Quant/IR/QuantOps.td',
  204. 'mlir/Dialect/Shape/IR/ShapeOps.td',
  205. 'mlir/Dialect/SparseTensor/IR/SparseTensorOps.td',
  206. 'mlir/Dialect/Tensor/IR/TensorOps.td',
  207. 'mlir/Dialect/Tosa/IR/TosaOps.td',
  208. 'mlir/Dialect/Vector/IR/VectorOps.td',
  209. 'mlir/Dialect/X86Vector/X86Vector.td',
  210. 'mlir/Dialect/XeGPU/IR/XeGPUOps.td',
  211. 'mlir/Dialect/Transform/IR/TransformOps.td',
  212. 'mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td',
  213. 'mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td',
  214. 'mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td',
  215. 'mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td',
  216. 'mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td',
  217. 'mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td',
  218. 'TestTransformDialectExtension.td',
  219. 'iree/compiler/Dialect/Util/TransformOps/UtilTransformOps.td',
  220. 'mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td',
  221. 'mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td',
  222. 'mlir/Dialect/SCF/TransformOps/SCFTransformOps.td',
  223. 'mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td',
  224. 'mlir/Dialect/GPU/TransformOps/GPUTransformOps.td',
  225. 'mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td',
  226. 'mlir/Dialect/Affine/TransformOps/AffineTransformOps.td',
  227. 'mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td',
  228. 'mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td',
  229. 'mlir/Dialect/Vector/TransformOps/VectorTransformOps.td',
  230. 'mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td',
  231. 'mlir/Dialect/Func/TransformOps/FuncTransformOps.td',
  232. 'mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td',
  233. 'mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td',
  234. 'mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td',
  235. 'mlir/Dialect/DLTI/TransformOps/DLTITransformOps.td',
  236. 'mlir/Dialect/WasmSSA/IR/WasmSSAOps.td',
  237. 'mlir/Dialect/IRDL/IR/IRDLOps.td',
  238. 'mlir/Dialect/LLVMIR/LLVMOps.td',
  239. 'mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td',
  240. 'mlir/Dialect/LLVMIR/NVVMOps.td',
  241. 'mlir/Dialect/LLVMIR/ROCDLOps.td',
  242. // 'mlir/Dialect/OpenMP/OpenMPOps.td', // File not found 'mlir/Dialect/OpenMP/OmpCommon.td'
  243. 'mlir/Dialect/ArmSME/IR/ArmSMEOps.td',
  244. 'mlir/Dialect/ArmNeon/ArmNeon.td',
  245. 'mlir/Dialect/ArmSVE/IR/ArmSVE.td',
  246. 'mlir/Dialect/Math/IR/MathOps.td',
  247. 'mlir/Dialect/MLProgram/IR/MLProgramOps.td',
  248. 'mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td',
  249. 'mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td',
  250. 'mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td',
  251. 'mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td',
  252. 'mlir/Dialect/SPIRV/IR/SPIRVBitOps.td',
  253. 'mlir/Dialect/SPIRV/IR/SPIRVCastOps.td',
  254. 'mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td',
  255. 'mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td',
  256. 'mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td',
  257. 'mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td',
  258. 'mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td',
  259. 'mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td',
  260. 'mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td',
  261. 'mlir/Dialect/SPIRV/IR/SPIRVImageOps.td',
  262. 'mlir/Dialect/SPIRV/IR/SPIRVGLOps.td',
  263. 'mlir/Dialect/SPIRV/IR/SPIRVCLOps.td',
  264. 'mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td',
  265. 'mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td',
  266. 'mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td',
  267. 'mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td',
  268. 'mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td',
  269. 'mlir/Dialect/SPIRV/IR/SPIRVMeshOps.td',
  270. 'mlir/Dialect/SPIRV/IR/SPIRVPrimitiveOps.td',
  271. 'mlir/Dialect/EmitC/IR/EmitC.td',
  272. 'mlir/Dialect/Complex/IR/ComplexOps.td',
  273. 'mlir/Dialect/Index/IR/IndexOps.td',
  274. 'mlir/Dialect/PDL/IR/PDLOps.td',
  275. 'mlir/Dialect/Ptr/IR/PtrOps.td',
  276. 'mlir/Dialect/UB/IR/UBOps.td',
  277. 'mlir/Dialect/AMDGPU/IR/AMDGPU.td',
  278. 'mlir/Dialect/NVGPU/IR/NVGPUOps.td',
  279. 'mlir/Dialect/Shard/IR/ShardOps.td',
  280. 'mlir/Dialect/AMX/AMX.td',
  281. 'mlir/Dialect/SMT/IR/SMTOps.td',
  282. 'mlir/Dialect/SMT/IR/SMTArrayOps.td',
  283. 'mlir/Dialect/SMT/IR/SMTBitVectorOps.td',
  284. 'mlir/Dialect/SMT/IR/SMTIntOps.td',
  285. // 'mlir/Dialect/OpenACC/OpenACCOps.td', // File not found 'mlir/Dialect/OpenACC/AccCommon.td' (requires full LLVM tree to generate)
  286. 'mlir/Dialect/LLVMIR/XeVMOps.td',
  287. 'toy/Ops.td',
  288. 'stablehlo/dialect/StablehloOps.td',
  289. 'stablehlo/dialect/ChloOps.td',
  290. 'stablehlo/dialect/VhloOps.td',
  291. 'stablehlo/reference/InterpreterOps.td',
  292. 'stablehlo/tests/CheckOps.td',
  293. 'src/Dialect/ONNX/ONNX.td',
  294. 'src/Dialect/ONNX/ONNXOps.td.inc',
  295. 'src/Dialect/ONNX/AdditionalONNXOps.td',
  296. 'src/Dialect/Krnl/Krnl.td',
  297. 'torch-mlir/Dialect/Torch/IR/TorchOps.td',
  298. 'torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td',
  299. 'torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td',
  300. 'tensorflow/compiler/mlir/lite/ir/tfl_ops.td',
  301. 'tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td',
  302. 'tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td',
  303. 'tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td',
  304. 'tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td',
  305. 'tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td',
  306. 'tensorflow/compiler/mlir/tfr/ir/tfr_ops.td',
  307. 'tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td',
  308. 'tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td',
  309. 'tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td',
  310. 'tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.td',
  311. 'tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td',
  312. 'tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.td',
  313. 'tfrt/core_runtime/opdefs/core_runtime.td',
  314. 'tfrt/basic_kernels/opdefs/basic_kernels.td',
  315. 'tfrt/test_kernels/opdefs/test_kernels.td',
  316. 'tfrt/tensor/opdefs/tensor.td',
  317. 'tfrt/tensor/opdefs/dense_host_tensor.td',
  318. 'mlir-hlo/Dialect/mhlo/IR/hlo_ops.td',
  319. 'iree/compiler/Dialect/HAL/IR/HALOps.td',
  320. 'iree/compiler/Dialect/HAL/IR/HALTypes.td',
  321. 'iree/compiler/Modules/HAL/Loader/IR/HALLoaderOps.td',
  322. 'iree/compiler/Modules/HAL/Inline/IR/HALInlineOps.td',
  323. 'iree/compiler/Dialect/Flow/IR/FlowOps.td',
  324. 'iree/compiler/Dialect/Stream/IR/StreamOps.td',
  325. 'iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td',
  326. 'iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td',
  327. 'iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.td',
  328. 'iree/compiler/Dialect/TensorExt/IR/TensorExtOps.td',
  329. 'iree/compiler/Dialect/Util/IR/UtilOps.td',
  330. 'iree/compiler/Dialect/VM/IR/VMOps.td',
  331. 'iree/compiler/Dialect/VMVX/IR/VMVXOps.td',
  332. 'iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.td',
  333. 'iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.td',
  334. 'iree/compiler/Dialect/Encoding/IR/EncodingOps.td',
  335. 'asuka/Dialect/Asuka/IR/AsukaOps.td',
  336. 'tpu_mlir/Dialect/Top/IR/TopOps.td',
  337. 'tpu_mlir/Dialect/Tpu/IR/TpuOps.td',
  338. 'pmlc/dialect/tile/ir/ops.td',
  339. 'pmlc/dialect/stdx/ir/ops.td',
  340. // 'pmlc/dialect/pxa/ir/ops.td', // File not found 'mlir/Dialect/Arithmetic/IR/ArithmeticBase.td'
  341. 'SDFG/Dialect/Ops.td',
  342. 'MichelsonOps.td',
  343. 'triton/Dialect/Triton/IR/TritonOps.td',
  344. 'triton/Dialect/TritonGPU/IR/TritonGPUOps.td',
  345. 'triton/Dialect/Gluon/IR/GluonOps.td',
  346. 'triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td',
  347. 'amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td',
  348. 'proton/Dialect/include/Dialect/Proton/IR/ProtonOps.td',
  349. 'LAGradOps.td',
  350. 'mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td',
  351. 'mlir-executor/Executor/IR/ExecutorOps.td',
  352. ];
  353. const file = path.join(dirname, '..', 'source', 'mlir-metadata.json');
  354. const operations = new Map();
  355. const exists = await access(file);
  356. if (exists) {
  357. const content = await fs.readFile(file, 'utf-8');
  358. const json = JSON.parse(content);
  359. for (const op of json) {
  360. if (op.name.endsWith('.') || op.name.includes('..') || op.name.includes('#')) {
  361. throw new Error(`Invalid operation name '${op.name}'.`);
  362. }
  363. if (op.name && !op.name.endsWith('.')) {
  364. operations.set(op.name, op);
  365. }
  366. }
  367. }
  368. const parser = new tablegen.Reader();
  369. await parser.parse(dialects, paths);
  370. for (const def of parser.defs) {
  371. const op = new Operator(def);
  372. const operationName = op.getOperationName();
  373. if (!operationName) {
  374. continue;
  375. }
  376. if (operationName.endsWith('.') || operationName.includes('..') || operationName.includes('#')) {
  377. throw new Error(`Invalid operation name '${operationName}'.`);
  378. }
  379. const operation = {
  380. name: operationName
  381. };
  382. let args = def.resolveField('arguments');
  383. // If the field value needs evaluation (e.g., it's a computed field), evaluate it
  384. if (args && args.value && (args.value.type === 'id' || args.value.type === 'bang')) {
  385. const evaluated = def.evaluateValue(args.value);
  386. if (evaluated && typeof evaluated === 'object' && evaluated.operator) {
  387. // The evaluation returned a DAG directly
  388. args = { value: new tablegen.Value('dag', evaluated) };
  389. }
  390. }
  391. if (!args || !args.value || args.value.type !== 'dag' || (args.value.value && args.value.value.operands && args.value.value.operands.length === 0)) {
  392. for (const parent of def.parents) {
  393. if (parent.name === 'Arguments' && parent.args && parent.args.length > 0) {
  394. const [dagValue] = parent.args;
  395. if (dagValue && dagValue.type === 'dag') {
  396. args = { value: dagValue };
  397. }
  398. break;
  399. }
  400. }
  401. }
  402. const name = operation.name.replace(/^(asuka|stablehlo|chlo|affine|linalg|memref|quant|vector|tosa|tfl|tf|onnx|torch\.aten|gpu)\./, '');
  403. if (['reshape', 'broadcast_in_dim', 'dynamic_reshape', 'Reshape', 'Shape', 'Size', 'ConstantOfShape'].indexOf(name) !== -1) {
  404. operation.category = 'Shape';
  405. } else if (['transpose', 'reverse', 'pad', 'Transpose', 'Pad'].indexOf(name) !== -1) {
  406. operation.category = 'Transform';
  407. } else if (['slice', 'split', 'dynamic_slice', 'gather', 'scatter', 'Slice', 'Gather', 'Scatter', 'concatenate'].indexOf(name) !== -1) {
  408. operation.category = 'Tensor';
  409. } else if (['tanh', 'Sigmoid', 'Tanh', 'Relu', 'Softmax', 'softmax', 'sigmoid', 'relu'].indexOf(name) !== -1) {
  410. operation.category = 'Activation';
  411. } else if (['convolution', 'Conv', 'matmul', 'batch_matmul', 'conv2d', 'conv3d', 'fully_connected', 'conv_2d'].indexOf(name) !== -1) {
  412. operation.category = 'Layer';
  413. } else if (['batch_norm_inference'].includes(name)) {
  414. operation.category = 'Normalization';
  415. }
  416. const summary = def.resolveField('summary');
  417. if (summary && summary.value) {
  418. const value = def.evaluateValue(summary.value);
  419. if (value) {
  420. let summary = value.trim();
  421. summary = summary.replace(/^"|"$/g, '');
  422. if (summary) {
  423. operation.summary = summary;
  424. }
  425. }
  426. }
  427. const description = def.resolveField('description');
  428. if (description && description.value) {
  429. const value = def.evaluateValue(description.value);
  430. if (value) {
  431. let desc = value.trim();
  432. desc = desc.replace(/^\[\{\s*|\s*\}\]$/g, '');
  433. desc = desc.trim();
  434. if (desc) {
  435. operation.description = desc;
  436. }
  437. }
  438. }
  439. const attributes = [];
  440. const inputs = [];
  441. const outputs = [];
  442. if (args && args.value && args.value.type === 'dag') {
  443. const dag = args.value.value;
  444. if (dag.operator === 'ins') {
  445. for (const operand of dag.operands) {
  446. if (!operand.value || !operand.name) {
  447. continue;
  448. }
  449. let typeName = '';
  450. if (operand.value.type === 'def') {
  451. typeName = operand.value.value;
  452. } else {
  453. // Try to extract from other value types
  454. typeName = String(operand.value.value);
  455. }
  456. if (typeName.includes('Attr')) {
  457. attributes.push({
  458. name: operand.name,
  459. type: typeName
  460. });
  461. } else {
  462. inputs.push({
  463. name: operand.name,
  464. type: typeName
  465. });
  466. }
  467. }
  468. }
  469. }
  470. let results = def.resolveField('results');
  471. if (!results || !results.value || results.value.type !== 'dag' || (results.value.value && results.value.value.operands && results.value.value.operands.length === 0)) {
  472. for (const parent of def.parents) {
  473. if (parent.name === 'Results' && parent.args && parent.args.length > 0) {
  474. const [dagValue] = parent.args;
  475. if (dagValue && dagValue.type === 'dag') {
  476. results = { value: dagValue };
  477. }
  478. break;
  479. }
  480. }
  481. }
  482. if (results && results.value && results.value.type === 'dag') {
  483. const dag = results.value.value;
  484. if (dag.operator === 'outs') {
  485. for (const operand of dag.operands) {
  486. if (!operand.value || !operand.name) {
  487. continue;
  488. }
  489. if (operand.value.type !== 'def') {
  490. throw new Error('Unexpected result operand value type');
  491. }
  492. const type = operand.value.value;
  493. outputs.push({ name: operand.name, type });
  494. }
  495. }
  496. }
  497. if (inputs.length > 0) {
  498. operation.inputs = inputs;
  499. }
  500. if (outputs.length > 0) {
  501. operation.outputs = outputs;
  502. }
  503. if (attributes.length > 0) {
  504. operation.attributes = attributes;
  505. }
  506. const successors = def.resolveField('successors');
  507. if (successors && successors.value && successors.value.type === 'dag') {
  508. const dag = successors.value.value;
  509. if (dag.operator === 'successor') {
  510. const successors = [];
  511. for (const operand of dag.operands) {
  512. if (operand.name) {
  513. successors.push({ name: operand.name });
  514. }
  515. }
  516. if (successors.length > 0) {
  517. operation.successors = successors;
  518. }
  519. }
  520. }
  521. const assemblyFormat = def.resolveField('assemblyFormat');
  522. if (assemblyFormat && assemblyFormat.value) {
  523. const value = def.evaluateValue(assemblyFormat.value);
  524. if (value) {
  525. const format = value.trim().replace(/^\[\{\s*|\s*\}\]$/g, '');
  526. if (format) {
  527. operation.assemblyFormat = format;
  528. }
  529. }
  530. }
  531. const hasCustomAssemblyFormat = def.resolveField('hasCustomAssemblyFormat');
  532. if (hasCustomAssemblyFormat && hasCustomAssemblyFormat.value) {
  533. operation.hasCustomAssemblyFormat = def.evaluateValue(hasCustomAssemblyFormat.value);
  534. }
  535. const parser = def.resolveField('parser');
  536. if (parser && parser.value) {
  537. operation.parser = 1;
  538. }
  539. // Extract defaultDialect from OpAsmOpInterface
  540. for (const parent of def.parents) {
  541. const possibleTraitArgs = parent.args && parent.args.length >= 2 ? [parent.args[1], parent.args[2]] : [];
  542. for (const traitsArg of possibleTraitArgs) {
  543. if (traitsArg && traitsArg.type === 'list' && traitsArg.value) {
  544. for (const trait of traitsArg.value) {
  545. const traitName = trait.type === 'def' ? trait.value : null;
  546. const traitDag = trait.type === 'dag' && trait.value?.operator ? trait.value.operator : null;
  547. if (traitName === 'OpAsmOpInterface' || traitDag === 'DeclareOpInterfaceMethods') {
  548. if (traitDag === 'DeclareOpInterfaceMethods' && trait.value?.operands) {
  549. if (trait.value.operands.some((operand) => operand.value && operand.value.type === 'list' && operand.value.value.some((method) => method.type === 'string' && method.value === 'getDefaultDialect'))) {
  550. const [dialectName] = operationName.split('.');
  551. operation.defaultDialect = dialectName;
  552. break;
  553. }
  554. }
  555. const extraClass = def.resolveField('extraClassDeclaration');
  556. if (extraClass && extraClass.value) {
  557. const code = def.evaluateValue(extraClass.value);
  558. if (code && typeof code === 'string') {
  559. const match = code.match(/getDefaultDialect\(\)\s*\{\s*return\s+"(\w+)"/);
  560. if (match) {
  561. [, operation.defaultDialect] = match;
  562. break;
  563. }
  564. }
  565. }
  566. }
  567. }
  568. if (operation.defaultDialect) {
  569. break;
  570. }
  571. }
  572. }
  573. if (operation.defaultDialect) {
  574. break;
  575. }
  576. }
  577. // Only add operation if it has meaningful data beyond just the name
  578. if (Object.keys(operation).length > 1) {
  579. operations.set(operationName, operation);
  580. }
  581. }
  582. const sorted = Array.from(operations.values()).sort((a, b) => a.name.localeCompare(b.name));
  583. const output = JSON.stringify(sorted, null, 2);
  584. const formatted = output.replace(/\{\s+"name":\s+"([^"]+)",\s+"type":\s+"([^"]+)"\s+\}/g, '{ "name": "$1", "type": "$2" }');
  585. await fs.writeFile(file, formatted, 'utf-8');
  586. };
  587. await main();