// mslite-script.js (3.4 KB)

  1. import * as flatc from './flatc.js';
  2. import * as fs from 'fs/promises';
  3. import * as path from 'path';
  4. import * as url from 'url';
  /**
   * Indexes the operator entries loaded from mslite-metadata.json.
   *
   * @param {Array<object>} json - Parsed metadata array; each entry has a
   *   `name` and may carry an `attributes` array of `{ name, ... }` objects.
   * @returns {{ operators: Map<string, object>, attributes: Map<string, object> }}
   *   `operators` maps operator name -> entry; `attributes` maps
   *   `"<operator>:<attribute>"` -> attribute object. Both hold the same object
   *   references as `json`, so mutating them updates `json` in place.
   * @throws {Error} If two entries share the same operator name.
   */
  const indexMetadata = (json) => {
    const operators = new Map();
    const attributes = new Map();
    for (const operator of json) {
      if (operators.has(operator.name)) {
        throw new Error(`Duplicate operator '${operator.name}'.`);
      }
      operators.set(operator.name, operator);
      for (const attribute of operator.attributes || []) {
        attributes.set(`${operator.name}:${attribute.name}`, attribute);
      }
    }
    return { operators, attributes };
  };

  /**
   * Maps a schema field type name to the metadata type string.
   *
   * `bool` is spelled `boolean`; repeated (vector) fields get a `[]` suffix,
   * including repeated booleans (`boolean[]`).
   *
   * @param {string} typeName - Schema type name, e.g. `bool`, `int64`.
   * @param {boolean} repeated - Whether the field is a vector.
   * @returns {string} Metadata type string.
   */
  const attributeType = (typeName, repeated) =>
    (typeName === 'bool' ? 'boolean' : typeName) + (repeated ? '[]' : '');

  /**
   * Merges the fields of one schema table into its operator's attributes.
   *
   * Every field's `type` is (re)written from the schema; `default` is only
   * filled in when the attribute does not already define one. Fields with no
   * existing attribute are appended as new attributes.
   *
   * @param {object} operator - Metadata entry for the operator (mutated).
   * @param {object} table - flatc type for the operator; its `fields` are
   *   iterated via `fields.values()`.
   * @param {Map<string, object>} attributes - Index from `indexMetadata`,
   *   extended with any attribute created here.
   * @throws {Error} If an enum-typed field's default value is not found in
   *   the enum's `keys`.
   */
  const mergeTable = (operator, table, attributes) => {
    operator.attributes = operator.attributes || [];
    // Delete and re-add inputs/outputs so JSON.stringify (which follows key
    // insertion order) emits them after `attributes`.
    const { inputs, outputs } = operator;
    delete operator.inputs;
    delete operator.outputs;
    if (inputs) {
      operator.inputs = inputs;
    }
    if (outputs) {
      operator.outputs = outputs;
    }
    for (const field of table.fields.values()) {
      const key = `${table.name}:${field.name}`;
      if (!attributes.has(key)) {
        const attribute = { name: field.name };
        attributes.set(key, attribute);
        operator.attributes.push(attribute);
      }
      const attribute = attributes.get(key);
      const type = field.type;
      let defaultValue = field.defaultValue;
      if (type instanceof flatc.Enum) {
        // Enum defaults are translated through the enum's `keys` map.
        if (!type.keys.has(defaultValue)) {
          throw new Error(`Invalid '${type.name}' default value '${defaultValue}'.`);
        }
        defaultValue = type.keys.get(defaultValue);
      }
      attribute.type = attributeType(type.name, field.repeated);
      if (attribute.default === undefined) {
        attribute.default = defaultValue;
      }
    }
  };

  /**
   * Serializes the metadata sorted by operator name.
   *
   * Uses 2-space JSON indentation, then collapses objects whose properties sit
   * at 8-space indentation (e.g. entries of an operator's `attributes` array)
   * onto a single line. Sorting uses a fixed 'en' collation so the generated
   * file does not depend on the machine's default locale. `json` itself is
   * not reordered.
   *
   * @param {Array<object>} json - Metadata entries, each with a string `name`.
   * @returns {string} File contents to write.
   */
  const formatMetadata = (json) => {
    const sorted = [...json].sort((a, b) => a.name.localeCompare(b.name, 'en'));
    return JSON.stringify(sorted, null, 2)
      .replace(/\s {8}/g, ' ')
      .replace(/,\s {8}/g, ', ')
      .replace(/\s {6}}/g, ' }');
  };

  /**
   * Regenerates ../source/mslite-metadata.json from the MindSpore Lite schema
   * (../third_party/source/mindspore/mindspore/lite/schema/ops.fbs).
   *
   * Every entry of `mindspore.schema.PrimitiveType` becomes an operator entry
   * (added if missing), and its table's fields become attributes. Existing
   * entries are kept; for already-known attributes only `type` is overwritten.
   *
   * @throws {Error} On duplicate operators in the metadata file or invalid
   *   enum defaults in the schema.
   */
  const main = async () => {
    const dirname = path.dirname(url.fileURLToPath(import.meta.url));
    const schema = path.join(dirname, '..', 'third_party', 'source', 'mindspore', 'mindspore', 'lite', 'schema', 'ops.fbs');
    const file = path.join(dirname, '..', 'source', 'mslite-metadata.json');
    const json = JSON.parse(await fs.readFile(file, 'utf-8'));
    const { operators, attributes } = indexMetadata(json);
    const root = new flatc.Root('mslite');
    await root.load([], [schema]);
    const namespace = root.find('mindspore.schema', flatc.Namespace);
    const primitiveType = namespace.find('mindspore.schema.PrimitiveType', flatc.Type);
    for (const value of primitiveType.values) {
      const table = value.type;
      // Skip entries that carry no table type instead of failing on `table.name`.
      if (!table) {
        continue;
      }
      const name = table.name;
      if (!operators.has(name)) {
        const operator = { name };
        operators.set(name, operator);
        json.push(operator);
      }
      if (table.fields.size > 0) {
        mergeTable(operators.get(name), table, attributes);
      }
    }
    await fs.writeFile(file, formatMetadata(json), 'utf-8');
  };
  // Entry point. This file is an ES module (it uses `import` and
  // `import.meta`), so main() is driven with top-level await; any error it
  // throws propagates out of the module instead of being caught here.
  await main();