tflite-script.js 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import * as flatc from './flatc.js';
  2. import * as fs from 'fs/promises';
  3. import * as path from 'path';
  4. import * as url from 'url';
  /* eslint-disable no-extend-native */
  // Adds BigInt#toNumber(): returns the value as a plain Number, but only when it
  // lies inside [Number.MIN_SAFE_INTEGER, Number.MAX_SAFE_INTEGER]; anything
  // outside that safe-integer range throws instead of silently losing precision.
  BigInt.prototype.toNumber = function() {
    const withinSafeRange = this >= Number.MIN_SAFE_INTEGER && this <= Number.MAX_SAFE_INTEGER;
    if (!withinSafeRange) {
      throw new Error('64-bit value exceeds safe integer.');
    }
    return Number(this);
  };
  /* eslint-enable no-extend-native */
  /**
   * BuiltinOperator name segments that are kept verbatim when converting an enum
   * key to PascalCase (e.g. 'CONV_2D' -> 'Conv2D', 'LSH_PROJECTION' -> 'LSHProjection').
   */
  const upperCase = new Set(['2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM']);

  /**
   * Converts a `tflite.BuiltinOperator` enum key into the PascalCase operator name
   * used both as the metadata entry name and as the `<name>Options` table prefix.
   * @param {string} op - Enum key, e.g. 'CONV_2D'.
   * @returns {string} PascalCase name, e.g. 'Conv2D'.
   */
  const toOperatorKey = (op) => {
    // 'BATCH_MATMUL' is re-split so that it becomes 'BatchMatMul' (and the
    // lookup targets 'BatchMatMulOptions') rather than 'BatchMatmul'.
    const key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op;
    return key
      .split('_')
      .map((part) => ((part.length < 1 || upperCase.has(part)) ? part : part[0] + part.substring(1).toLowerCase()))
      .join('');
  };

  /**
   * Indexes the existing metadata entries.
   * @param {Array<object>} json - Parsed tflite-metadata.json (array of operator entries).
   * @returns {{operators: Map<string, object>, attributes: Map<string, object>}}
   *   operators keyed by name; attributes keyed by '<operator>:<attribute>'.
   * @throws {Error} if two entries share the same operator name.
   */
  const indexMetadata = (json) => {
    const operators = new Map();
    const attributes = new Map();
    for (const operator of json) {
      if (operators.has(operator.name)) {
        throw new Error(`Duplicate operator '${operator.name}'.`);
      }
      operators.set(operator.name, operator);
      if (operator.attributes) {
        for (const attribute of operator.attributes) {
          attributes.set(`${operator.name}:${attribute.name}`, attribute);
        }
      }
    }
    return { operators, attributes };
  };

  /**
   * Returns the schema default of a table field in metadata form: BigInt defaults
   * go through BigInt#toNumber (throws outside the safe-integer range) and enum
   * defaults are mapped through the enum's `keys` map (presumably value -> member name).
   * @throws {Error} if an enum field's default is not present in `type.keys`.
   */
  const schemaDefault = (field) => {
    const type = field.type;
    let value = field.defaultValue;
    if (typeof value === 'bigint') {
      value = value.toNumber();
    }
    if (type instanceof flatc.Enum) {
      if (!type.keys.has(value)) {
        throw new Error(`Invalid '${type.name}' default value '${value}'.`);
      }
      value = type.keys.get(value);
    }
    return value;
  };

  /**
   * Metadata type string for a schema field type: 'bool' becomes 'boolean', and
   * repeated (vector) fields get a '[]' suffix — including vectors of bool, which
   * the previous single ternary reported as plain 'boolean'.
   */
  const attributeType = (type, repeated) => (type.name === 'bool' ? 'boolean' : type.name) + (repeated ? '[]' : '');

  /**
   * Merges every field of an Options table into `operator.attributes`: missing
   * attributes are appended, and `type` then `default` are (re)written from the
   * schema. Assignment order matters because it fixes the JSON key order of new
   * attributes (name, type, default).
   */
  const mergeOptions = (operator, opKey, table, attributes) => {
    operator.attributes = operator.attributes || [];
    for (const field of table.fields.values()) {
      const attrKey = `${opKey}:${field.name}`;
      if (!attributes.has(attrKey)) {
        const attribute = { name: field.name };
        attributes.set(attrKey, attribute);
        operator.attributes.push(attribute);
      }
      const attribute = attributes.get(attrKey);
      const defaultValue = schemaDefault(field);
      attribute.type = attributeType(field.type, field.repeated);
      attribute.default = defaultValue;
    }
  };

  /**
   * Serializes the metadata with 2-space indentation, then folds whitespace runs
   * so that each attribute object (properties at indent 8, closing brace at
   * indent 6) ends up on one line. The regexes act on raw whitespace runs, so any
   * deeper-indented content is folded too.
   */
  const formatMetadata = (json) => {
    let output = JSON.stringify(json, null, 2);
    output = output.replace(/\s {8}/g, ' ');
    output = output.replace(/,\s {8}/g, ', ');
    output = output.replace(/\s {6}}/g, ' }');
    return output;
  };

  /**
   * Regenerates ../source/tflite-metadata.json: for every BuiltinOperator whose
   * `tflite.<Name>Options` table in schema.fbs has fields, the fields are merged
   * into that operator's metadata entry (created if missing); the entries are then
   * sorted by name and the file is rewritten in place.
   */
  const main = async () => {
    const dirname = path.dirname(url.fileURLToPath(import.meta.url));
    const schema = path.join(dirname, '..', 'third_party', 'source', 'tensorflow', 'tensorflow', 'compiler', 'mlir', 'lite', 'schema', 'schema.fbs');
    const file = path.join(dirname, '..', 'source', 'tflite-metadata.json');
    const json = JSON.parse(await fs.readFile(file, 'utf-8'));
    const { operators, attributes } = indexMetadata(json);

    const root = new flatc.Root('tflite');
    await root.load([], [schema]);
    const namespace = root.find('tflite', flatc.Namespace);
    const builtinOperator = namespace.find('tflite.BuiltinOperator', flatc.Type);

    for (const op of builtinOperator.values.keys()) {
      const opKey = toOperatorKey(op);
      const table = namespace.find(`tflite.${opKey}Options`, flatc.Type);
      if (!(table && table.fields.size > 0)) {
        continue;
      }
      if (!operators.has(opKey)) {
        const operator = { name: opKey };
        operators.set(opKey, operator);
        json.push(operator);
      }
      mergeOptions(operators.get(opKey), opKey, table, attributes);
    }

    // NOTE(review): localeCompare without a locale argument uses the runtime's
    // default locale; consider pinning one if output order must be identical
    // across machines.
    json.sort((a, b) => a.name.localeCompare(b.name));
    await fs.writeFile(file, formatMetadata(json), 'utf-8');
  };

  // Top-level await: this file is an ES module (it uses `import` and
  // `import.meta`), so the script simply awaits main().
  await main();