mlir_script.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import * as fs from 'fs/promises';
  2. import * as path from 'path';
  3. import * as tablegen from './tablegen.js';
  4. import * as url from 'url';
  /**
   * Resolves the MLIR operation name (`<dialect>.<mnemonic>`) of a TableGen
   * definition by walking its class hierarchy up to the `Op` base class and
   * substituting template arguments along the way.
   */
  class Operator {
    /**
     * @param {object} def - TableGen definition. Read here: `def.parents`
     *   (`{ name, args }` entries) and `def.parser.classes` / `def.parser.defs`
     *   (Maps from a name to a class / definition record).
     */
    constructor(def) {
      this.def = def;
      let opInfo = null;
      for (const parent of def.parents) {
        const parentClass = def.parser.classes.get(parent.name);
        if (!parentClass) {
          continue;
        }
        opInfo = this._findOpParent(parentClass, parent.args, {});
        if (opInfo) {
          break;
        }
      }
      this.dialectName = opInfo?.dialect || null;
      this.opName = opInfo?.mnemonic || null;
    }

    /**
     * @returns {string} The resolved dialect name, or '' when it is unknown.
     */
    getDialectName() {
      return this.dialectName || '';
    }

    /**
     * @returns {string|null} `<dialect>.<mnemonic>`, or null when either part is unknown.
     */
    getOperationName() {
      return this.dialectName && this.opName ? `${this.dialectName}.${this.opName}` : null;
    }

    /**
     * Returns the binding of `value` when it names a bound template parameter in
     * `subs`; otherwise returns `value` unchanged. Only own properties count, so
     * names such as `toString` or `constructor` never resolve to
     * `Object.prototype` members. Falsy bindings are ignored.
     * @param {*} value - Candidate parameter name (non-strings pass through).
     * @param {object} subs - Template parameter bindings.
     * @returns {*} The bound value or `value` itself.
     */
    _substitute(value, subs) {
      if (typeof value === 'string' && Object.hasOwn(subs, value) && subs[value]) {
        return subs[value];
      }
      return value;
    }

    /**
     * Recursively searches `parentClass` and its ancestors for an
     * `Op<dialect, mnemonic, ...>` instantiation.
     * @param {object} parentClass - Class record with `name`, optional
     *   `templateArgs` (`{ name }` entries) and optional `parents`
     *   (`{ name, args }` entries).
     * @param {Array|undefined} parentArgs - Template arguments applied to `parentClass`.
     * @param {object} substitutions - Template parameter bindings from the caller.
     * @returns {{dialect: string, mnemonic: string}|null} The first match, or null.
     */
    _findOpParent(parentClass, parentArgs, substitutions) {
      // Treat a missing argument list as empty instead of dereferencing it.
      const args = parentArgs ?? [];
      // Null-prototype copy so parameter names cannot collide with Object.prototype.
      const subs = Object.assign(Object.create(null), substitutions);
      const params = parentClass.templateArgs ?? [];
      const bound = Math.min(params.length, args.length);
      for (let i = 0; i < bound; i++) {
        // Resolve against the caller's bindings only, not ones added in this loop.
        subs[params[i].name] = this._substitute(args[i], substitutions);
      }
      if (parentClass.name === 'Op' && args.length >= 2) {
        const dialectArg = this._substitute(args[0], subs);
        const mnemonicArg = this._substitute(args[1], subs);
        let dialectName = null;
        if (typeof dialectArg === 'string') {
          const { defs, classes } = this.def.parser;
          const dialectDef = defs.get(dialectArg) || classes.get(dialectArg);
          if (dialectDef) {
            dialectName = dialectDef.getValueAsString('name');
          }
        }
        // Strip surrounding double quotes from the mnemonic argument.
        const mnemonic = typeof mnemonicArg === 'string' ? mnemonicArg.replace(/^"|"$/g, '') : null;
        if (dialectName && mnemonic) {
          return { dialect: dialectName, mnemonic };
        }
      }
      for (const grandparent of parentClass.parents ?? []) {
        const grandparentClass = this.def.parser.classes.get(grandparent.name);
        if (!grandparentClass) {
          continue;
        }
        const resolvedArgs = (grandparent.args ?? []).map((arg) => this._substitute(arg, subs));
        const result = this._findOpParent(grandparentClass, resolvedArgs, subs);
        if (result) {
          return result;
        }
      }
      return null;
    }
  }
  /**
   * Reports whether `fs.access` succeeds for `target` with its default mode.
   * @param {string} target - Path to probe.
   * @returns {Promise<boolean>} true when the check succeeds; false on any failure.
   */
  const access = async (target) => {
    let reachable = true;
    try {
      await fs.access(target);
    } catch {
      reachable = false;
    }
    return reachable;
  };
  // Dialect prefixes removed from an operation name before the category lookup.
  const categoryPrefix = /^(stablehlo|chlo|affine|linalg|memref|quant|vector|tosa|tfl|tf|onnx|torch)\./;

  // Heuristic categories for well-known operation names, looked up after
  // `categoryPrefix` is removed. The lists are disjoint, so one flat Map gives
  // the same answer as checking them in order.
  const categoryByName = new Map(
    Object.entries({
      Shape: ['reshape', 'broadcast_in_dim', 'dynamic_reshape', 'Reshape', 'Shape', 'Size', 'ConstantOfShape'],
      Transform: ['transpose', 'reverse', 'pad', 'Transpose', 'Pad'],
      Tensor: ['slice', 'dynamic_slice', 'gather', 'scatter', 'Slice', 'Gather', 'Scatter'],
      Activation: ['tanh', 'Sigmoid', 'Tanh', 'Relu', 'Softmax', 'softmax', 'sigmoid', 'relu'],
      Layer: ['convolution', 'Conv', 'matmul', 'batch_matmul', 'conv2d', 'conv3d', 'fully_connected', 'conv_2d'],
    }).flatMap(([category, names]) => names.map((name) => [name, category]))
  );

  /**
   * Infers a display category for an operation from its name.
   * @param {string} name - Fully qualified operation name, e.g. `onnx.Relu`.
   * @returns {string|undefined} The category, or undefined when not recognized.
   */
  const inferCategory = (name) => categoryByName.get(name.replace(categoryPrefix, ''));

  /**
   * Normalizes a raw TableGen text value: trims it and removes one level of
   * `[{ ... }]` code-block delimiters or, failing that, surrounding double quotes.
   * @param {*} text - Raw value from the reader.
   * @returns {string} The unwrapped, trimmed text; '' for non-string input.
   */
  const unwrapText = (text) => {
    if (typeof text !== 'string') {
      return '';
    }
    let result = text.trim();
    if (result.startsWith('[{') && result.endsWith('}]')) {
      result = result.slice(2, -2);
    } else if (result.length >= 2 && result.startsWith('"') && result.endsWith('"')) {
      result = result.slice(1, -1);
    }
    return result.trim();
  };

  /**
   * Lists the named operands of a DAG field such as `arguments` or `results`.
   * @param {*} field - Result of `def.resolveField(...)`; may be null.
   * @param {string} operator - Required DAG operator (`ins` or `outs`).
   * @returns {Array<{name: string, type: *}>|null} Named operands in declaration
   *   order, or null when the field is missing, not a DAG, or uses another operator.
   */
  const namedOperands = (field, operator) => {
    if (!field || !field.value || field.value.type !== 'dag') {
      return null;
    }
    const dag = field.value.value;
    if (dag.operator !== operator) {
      return null;
    }
    const operands = [];
    for (const operand of dag.operands) {
      // Operands without a name or value cannot be described.
      if (!operand.value || !operand.name) {
        continue;
      }
      const type = operand.value.type === 'def' ? operand.value.value : String(operand.value.value);
      operands.push({ name: operand.name, type });
    }
    return operands;
  };

  /**
   * Builds the metadata entry for one TableGen operation definition.
   * Key order (which is the order written to the JSON file) is: name, summary,
   * description, inputs, outputs, attributes, assemblyFormat, category.
   * Arguments whose type name contains `Attr` are reported as attributes.
   * @param {*} def - Definition from the TableGen reader (uses `resolveField`).
   * @param {string} operationName - Fully qualified name, e.g. `stablehlo.add`.
   * @returns {object|null} The entry, or null when nothing beyond the name is
   *   known (so that a previously stored entry is kept).
   */
  const describeOperation = (def, operationName) => {
    const operation = { name: operationName };
    const summary = unwrapText(def.resolveField('summary')?.value?.value);
    if (summary) {
      operation.summary = summary;
    }
    const description = unwrapText(def.resolveField('description')?.value?.value);
    if (description) {
      operation.description = description;
    }
    const args = namedOperands(def.resolveField('arguments'), 'ins') || [];
    const inputs = args.filter((arg) => !arg.type.includes('Attr'));
    const attributes = args.filter((arg) => arg.type.includes('Attr'));
    const outputs = namedOperands(def.resolveField('results'), 'outs') || [];
    if (inputs.length > 0) {
      operation.inputs = inputs;
    }
    if (outputs.length > 0) {
      operation.outputs = outputs;
    }
    if (attributes.length > 0) {
      operation.attributes = attributes;
    }
    const assemblyFormat = unwrapText(def.resolveField('assemblyFormat')?.value?.value);
    if (assemblyFormat) {
      operation.assemblyFormat = assemblyFormat;
    }
    // The category alone does not make an entry worth emitting.
    if (Object.keys(operation).length === 1) {
      return null;
    }
    const category = inferCategory(operationName);
    if (category) {
      operation.category = category;
    }
    return operation;
  };

  /**
   * Collapses pretty-printed `{ "name": ..., "type": ... }` objects onto one line.
   * Only whitespace between JSON tokens is removed, so the result stays valid
   * JSON; escaped quotes inside the strings are handled.
   * @param {string} json - Output of `JSON.stringify(value, null, 2)`.
   * @returns {string} The compacted JSON text.
   */
  const compactNameTypePairs = (json) => json.replace(
    /\{\s+"name":\s+"((?:[^"\\]|\\.)*)",\s+"type":\s+"((?:[^"\\]|\\.)*)"\s+\}/g,
    '{ "name": "$1", "type": "$2" }'
  );

  /**
   * Regenerates `source/mlir-metadata.json` from the TableGen files below, read
   * from `third_party/source`. Entries already in the file are kept unless an
   * operation with the same name is regenerated with more than just its name.
   * @returns {Promise<void>}
   */
  const main = async () => {
    const dirname = path.dirname(url.fileURLToPath(import.meta.url));
    const source = path.join(dirname, '..', 'third_party', 'source');
    // Include roots handed to the TableGen reader together with `dialects`.
    const paths = [
      path.join(source, 'llvm-project', 'mlir', 'include'),
      path.join(source, 'stablehlo'),
      path.join(source, 'onnx-mlir'),
      path.join(source, 'torch-mlir', 'include'),
      path.join(source, 'tensorflow'),
      path.join(source, 'mlir-hlo', 'include'),
      path.join(source, 'iree', 'compiler', 'src'),
    ];
    // Each file is listed once; it is parsed by `parser.parse` below.
    const dialects = [
      'mlir/IR/BuiltinAttributeInterfaces.td',
      'mlir/IR/BuiltinTypeInterfaces.td',
      'mlir/IR/BuiltinLocationAttributes.td',
      'mlir/IR/BuiltinDialect.td',
      'mlir/IR/BuiltinOps.td',
      'mlir/IR/BuiltinDialectBytecode.td',
      'mlir/IR/BuiltinAttributes.td',
      'mlir/IR/BuiltinTypes.td',
      'mlir/Dialect/Affine/IR/AffineOps.td',
      'mlir/Dialect/Func/IR/FuncOps.td',
      'mlir/Dialect/Linalg/IR/LinalgOps.td',
      'mlir/Dialect/MemRef/IR/MemRefOps.td',
      'mlir/Dialect/Quant/IR/QuantOps.td',
      'mlir/Dialect/Tensor/IR/TensorOps.td',
      'mlir/Dialect/Tosa/IR/TosaOps.td',
      'mlir/Dialect/Vector/IR/VectorOps.td',
      'mlir/Dialect/IRDL/IR/IRDLOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVBitOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVCastOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td',
      'mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td',
      'stablehlo/dialect/StablehloOps.td',
      'stablehlo/dialect/ChloOps.td',
      'src/Dialect/ONNX/ONNX.td',
      'src/Dialect/ONNX/ONNXOps.td.inc',
      'src/Dialect/ONNX/AdditionalONNXOps.td',
      'torch-mlir/Dialect/Torch/IR/TorchOps.td',
      'tensorflow/compiler/mlir/lite/ir/tfl_ops.td',
      'tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td',
      'mlir-hlo/Dialect/mhlo/IR/hlo_ops.td',
      'iree/compiler/Dialect/HAL/IR/HALOps.td',
      'iree/compiler/Dialect/Flow/IR/FlowOps.td',
    ];
    const file = path.join(dirname, '..', 'source', 'mlir-metadata.json');
    // Seed with the previously written entries so operations that are not
    // regenerated (or yield nothing beyond a name) survive in the output.
    const operations = new Map();
    if (await access(file)) {
      const content = await fs.readFile(file, 'utf-8');
      for (const op of JSON.parse(content)) {
        if (op.name) {
          operations.set(op.name, op);
        }
      }
    }
    const parser = new tablegen.Reader();
    await parser.parse(dialects, paths);
    for (const [, def] of parser.defs) {
      const operationName = new Operator(def).getOperationName();
      if (!operationName) {
        continue;
      }
      const operation = describeOperation(def, operationName);
      if (operation) {
        operations.set(operationName, operation);
      }
    }
    const sorted = Array.from(operations.values()).sort((a, b) => a.name.localeCompare(b.name));
    const output = compactNameTypePairs(JSON.stringify(sorted, null, 2));
    await fs.writeFile(file, output, 'utf-8');
  };
  // Entry point: run the generator. It is awaited at module top level (this file
  // is an ES module), so a rejection from main() propagates instead of being dropped.
  287. await main();