// tflite_metadata.js (3.2 KB)

  1. const path = require('path');
  2. const flatc = require('./flatc');
  3. const fs = require('fs').promises;
  4. const main = async () => {
  5. const schema = path.join(__dirname, '..', 'third_party', 'source', 'tensorflow', 'tensorflow', 'lite', 'schema', 'schema.fbs');
  6. const file = path.join(__dirname, '..', 'source', 'tflite-metadata.json');
  7. const input = await fs.readFile(file, 'utf-8');
  8. const json = JSON.parse(input);
  9. const operators = new Map();
  10. const attributes = new Map();
  11. for (const operator of json) {
  12. if (operators.has(operator.name)) {
  13. throw new Error("Duplicate operator '" + operator.name + "'.");
  14. }
  15. operators.set(operator.name, operator);
  16. if (operator && operator.attributes) {
  17. for (const attribute of operator.attributes) {
  18. const name = operator.name + ':' + attribute.name;
  19. attributes.set(name, attribute);
  20. }
  21. }
  22. }
  23. const root = new flatc.Root('tflite');
  24. await root.load([], [ schema ]);
  25. const namespace = root.find('tflite', flatc.Namespace);
  26. const builtOperator = namespace.find('tflite.BuiltinOperator', flatc.Type);
  27. const upperCase = new Set([ '2D', 'LSH', 'SVDF', 'RNN', 'L2', 'LSTM' ]);
  28. for (const op of builtOperator.values.keys()) {
  29. let op_key = op === 'BATCH_MATMUL' ? 'BATCH_MAT_MUL' : op;
  30. op_key = op_key.split('_').map((s) => (s.length < 1 || upperCase.has(s)) ? s : s[0] + s.substring(1).toLowerCase()).join('');
  31. const table = namespace.find('tflite.' + op_key + 'Options', flatc.Type);
  32. if (table && table.fields.size > 0) {
  33. if (!operators.has(op_key)) {
  34. const operator = { name: op_key };
  35. operators.set(op_key, operator);
  36. json.push(operator);
  37. }
  38. const operator = operators.get(op_key);
  39. operator.attributes = operator.attributes || [];
  40. for (const field of table.fields.values()) {
  41. const attr_key = op_key + ':' + field.name;
  42. if (!attributes.has(attr_key)) {
  43. const attribute = { name: field.name };
  44. attributes.set(attr_key, attribute);
  45. operator.attributes.push(attribute);
  46. }
  47. const attribute = attributes.get(attr_key);
  48. const type = field.type;
  49. let defaultValue = field.defaultValue;
  50. if (type instanceof flatc.Enum) {
  51. if (!type.keys.has(defaultValue)) {
  52. throw new Error("Invalid '" + type.name + "' default value '" + defaultValue + "'.");
  53. }
  54. defaultValue = type.keys.get(defaultValue);
  55. }
  56. attribute.type = type.name === 'bool' ? 'boolean' : type.name + (field.repeated ? '[]' : '');
  57. attribute.default = defaultValue;
  58. }
  59. }
  60. }
  61. json.sort((a, b) => a.name.localeCompare(b.name));
  62. let output = JSON.stringify(json, null, 2);
  63. output = output.replace(/\s {8}/g, ' ');
  64. output = output.replace(/,\s {8}/g, ', ');
  65. output = output.replace(/\s {6}}/g, ' }');
  66. await fs.writeFile(file, output, 'utf-8');
  67. };
  68. main();