megengine-script.js 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import * as flatc from './flatc.js';
  2. import * as fs from 'fs/promises';
  3. import * as path from 'path';
  4. import * as url from 'url';
  /**
   * Operator base name (trailing `V<n>` version suffix stripped) -> category
   * written into the metadata. A Map is used instead of plain-object indexing so
   * schema names such as `constructor` or `toString` cannot resolve to inherited
   * Object.prototype functions and clobber an operator's category.
   */
  const categories = new Map(Object.entries({
    Host2DeviceCopy: 'Data',
    Dimshuffle: 'Shape',
    Flip: 'Shape',
    Images2Neibs: 'Shape',
    Reshape: 'Shape',
    Concat: 'Tensor',
    GetVarShape: 'Shape',
    Subtensor: 'Tensor',
    Padding: 'Layer',
    AdaptivePooling: 'Activation',
    ConvPooling: 'Pool',
    TQT: 'Quantization',
    LSQ: 'Quantization',
    Pooling: 'Pool',
    PoolingForward: 'Pool',
    AdaptivePoolingForward: 'Pool',
    SlidingWindowTranspose: 'Transform',
    LRN: 'Normalization',
    BatchNormForward: 'Normalization',
    BN: 'Normalization',
    LayerNorm: 'Normalization',
    Convolution: 'Layer',
    ConvolutionForward: 'Layer',
    Convolution3D: 'Layer',
    SeparableConv: 'Layer',
    SeparableConv3D: 'Layer',
    ConvBiasForward: 'Layer',
    ConvBias: 'Layer',
    Conv3DBias: 'Layer',
    Dropout: 'Dropout',
    Softmax: 'Activation',
    RNN: 'Layer',
    RNNCell: 'Layer',
    LSTM: 'Layer'
  }));

  /**
   * Indexes the parsed metadata by operator name and by `operator:attribute`.
   *
   * @param {Array<Object>} json - Parsed contents of megengine-metadata.json.
   * @returns {{operators: Map<string, Object>, attributes: Map<string, Object>}}
   *   Maps whose values are the same objects held in `json`, so mutations made
   *   through them are reflected in `json`.
   * @throws {Error} If two entries share the same operator name.
   */
  const indexMetadata = (json) => {
    const operators = new Map();
    const attributes = new Map();
    for (const operator of json) {
      if (operators.has(operator.name)) {
        throw new Error(`Duplicate operator '${operator.name}'.`);
      }
      operators.set(operator.name, operator);
      for (const attribute of operator.attributes || []) {
        attributes.set(`${operator.name}:${attribute.name}`, attribute);
      }
    }
    return { operators, attributes };
  };

  /**
   * Resolves a schema field's default into the value stored as the attribute's
   * `default`. Enum-typed defaults are translated through the enum's `keys` map;
   * any other default is returned unchanged.
   *
   * @param {Object} field - flatc field exposing `type` and `defaultValue`.
   * @returns {*} The value to record as the attribute default.
   * @throws {Error} If an enum default is not one of the enum's keys.
   */
  const resolveDefault = (field) => {
    const { type, defaultValue } = field;
    if (!(type instanceof flatc.Enum)) {
      return defaultValue;
    }
    if (!type.keys.has(defaultValue)) {
      throw new Error(`Invalid '${type.name}' default value '${defaultValue}'.`);
    }
    return type.keys.get(defaultValue);
  };

  /**
   * Merges every non-enum namespace child that has at least one field into the
   * metadata. A child without a metadata entry gets a new `{ name }` entry
   * appended to `json`; each field becomes (or updates) an attribute whose
   * `type` and `default` are overwritten from the schema.
   *
   * @param {Object} namespace - flatc Namespace whose `children` map names to types.
   * @param {Array<Object>} json - Metadata array; new operators are appended.
   * @param {Map<string, Object>} operators - Operator index, updated in place.
   * @param {Map<string, Object>} attributes - Attribute index, updated in place.
   * @throws {Error} Propagated from resolveDefault for invalid enum defaults.
   */
  const mergeParams = (namespace, json, operators, attributes) => {
    for (const [name, op] of namespace.children) {
      // Enums and field-less children contribute no operator attributes.
      if (op instanceof flatc.Enum || !op || !(op.fields.size > 0)) {
        continue;
      }
      if (!operators.has(name)) {
        const created = { name };
        operators.set(name, created);
        json.push(created);
      }
      const operator = operators.get(name);
      // A trailing `V<n>` suffix is stripped so versioned params share the
      // category of their base name.
      const category = categories.get(name.replace(/V\d+$/, ''));
      if (category) {
        operator.category = category;
      }
      operator.attributes = operator.attributes || [];
      for (const [fieldName, field] of op.fields) {
        const key = `${name}:${fieldName}`;
        if (!attributes.has(key)) {
          const created = { name: fieldName };
          attributes.set(key, created);
          operator.attributes.push(created);
        }
        const attribute = attributes.get(key);
        const defaultValue = resolveDefault(field);
        attribute.type = field.type.name + (field.repeated ? '[]' : '');
        attribute.default = defaultValue;
      }
    }
  };

  /**
   * Serializes the metadata with 2-space indentation, then joins the properties
   * of objects nested at the attribute level (indented 8 spaces) onto a single
   * line and pulls their closing brace (indented 6 spaces) up, so each such
   * object prints as `{ "name": ..., "type": ... }`.
   *
   * @param {Array<Object>} json - Metadata to serialize.
   * @returns {string} The formatted JSON text.
   */
  const formatMetadata = (json) => {
    let output = JSON.stringify(json, null, 2);
    output = output.replace(/\s {8}/g, ' ');
    output = output.replace(/,\s {8}/g, ', ');
    output = output.replace(/\s {6}}/g, ' }');
    return output;
  };

  /**
   * Regenerates ../source/megengine-metadata.json (relative to this script) from
   * the MegEngine FlatBuffers schema `schema_v2.fbs`: updates operator categories
   * and attribute types/defaults in place. Entries keep their existing order;
   * operators found only in the schema are appended at the end.
   *
   * @returns {Promise<void>}
   * @throws {Error} On duplicate operators, invalid enum defaults, or a missing
   *   `mgb.serialization.fbs.param` namespace.
   */
  const main = async () => {
    const dirname = path.dirname(url.fileURLToPath(import.meta.url));
    const schema = path.join(dirname, '..', 'third_party', 'source', 'megengine', 'src', 'serialization', 'fbs', 'schema_v2.fbs');
    const file = path.join(dirname, '..', 'source', 'megengine-metadata.json');
    const json = JSON.parse(await fs.readFile(file, 'utf-8'));
    const { operators, attributes } = indexMetadata(json);
    const root = new flatc.Root('megengine');
    await root.load([], [schema]);
    const namespaceName = 'mgb.serialization.fbs.param';
    const namespace = root.find(namespaceName, flatc.Namespace);
    // NOTE(review): assumes root.find returns a falsy value when the namespace is
    // absent; fail with a clear message rather than a TypeError on `.children`.
    if (!namespace) {
      throw new Error(`Namespace '${namespaceName}' not found in '${schema}'.`);
    }
    mergeParams(namespace, json, operators, attributes);
    await fs.writeFile(file, formatMetadata(json), 'utf-8');
  };

  await main();