| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import * as flatc from './flatc.js';
- import * as fs from 'fs/promises';
- import * as path from 'path';
- import * as url from 'url';
/**
 * Operator categories keyed by base operator name. Any trailing `V<n>`
 * version suffix is stripped from a param name before lookup.
 * A Map is used so names such as `constructor` or `toString` cannot resolve
 * to Object.prototype members the way a plain-object lookup would.
 */
const categories = new Map(Object.entries({
    Host2DeviceCopy: 'Data',
    Dimshuffle: 'Shape',
    Flip: 'Shape',
    Images2Neibs: 'Shape',
    Reshape: 'Shape',
    Concat: 'Tensor',
    GetVarShape: 'Shape',
    Subtensor: 'Tensor',
    Padding: 'Layer',
    AdaptivePooling: 'Pool',
    ConvPooling: 'Pool',
    TQT: 'Quantization',
    LSQ: 'Quantization',
    Pooling: 'Pool',
    PoolingForward: 'Pool',
    AdaptivePoolingForward: 'Pool',
    SlidingWindowTranspose: 'Transform',
    LRN: 'Normalization',
    BatchNormForward: 'Normalization',
    BN: 'Normalization',
    LayerNorm: 'Normalization',
    Convolution: 'Layer',
    ConvolutionForward: 'Layer',
    Convolution3D: 'Layer',
    SeparableConv: 'Layer',
    SeparableConv3D: 'Layer',
    ConvBiasForward: 'Layer',
    ConvBias: 'Layer',
    Conv3DBias: 'Layer',
    Dropout: 'Dropout',
    Softmax: 'Activation',
    RNN: 'Layer',
    RNNCell: 'Layer',
    LSTM: 'Layer'
}));

/**
 * Indexes the existing metadata entries.
 * @param {Array<object>} json - parsed metadata array; each entry has a `name`
 *     and optionally an `attributes` array of `{ name, ... }` objects.
 * @returns {{ operators: Map<string, object>, attributes: Map<string, object> }}
 *     operators by name, and attributes by `${operatorName}:${attributeName}`.
 *     The map values are the same objects held in `json`, so later mutation
 *     through the maps updates `json` in place.
 * @throws {Error} when two entries share the same operator name.
 */
const indexMetadata = (json) => {
    const operators = new Map();
    const attributes = new Map();
    for (const operator of json) {
        if (operators.has(operator.name)) {
            throw new Error(`Duplicate operator '${operator.name}'.`);
        }
        operators.set(operator.name, operator);
        for (const attribute of operator.attributes || []) {
            attributes.set(`${operator.name}:${attribute.name}`, attribute);
        }
    }
    return { operators, attributes };
};

/**
 * Merges the fields of one schema param table into an operator entry.
 * Missing attribute entries are created and appended; every attribute's `type`
 * and `default` are overwritten from the schema.
 * @param {object} operator - metadata entry to update (mutated in place).
 * @param {object} param - schema type whose `fields` map is read.
 * @param {Map<string, object>} attributes - attribute index from `indexMetadata`
 *     (mutated when new attributes are created).
 * @throws {Error} when an enum-typed field's default is not a member of the enum.
 */
const updateAttributes = (operator, param, attributes) => {
    operator.attributes = operator.attributes || [];
    for (const [fieldName, field] of param.fields) {
        const key = `${operator.name}:${fieldName}`;
        if (!attributes.has(key)) {
            const attribute = { name: fieldName };
            attributes.set(key, attribute);
            operator.attributes.push(attribute);
        }
        const attribute = attributes.get(key);
        const type = field.type;
        let defaultValue = field.defaultValue;
        if (type instanceof flatc.Enum) {
            // Enum defaults are stored by their symbolic key rather than the raw value.
            if (!type.keys.has(defaultValue)) {
                throw new Error(`Invalid '${type.name}' default value '${defaultValue}'.`);
            }
            defaultValue = type.keys.get(defaultValue);
        }
        attribute.type = type.name + (field.repeated ? '[]' : '');
        attribute.default = defaultValue;
    }
};

/**
 * Serializes metadata as 2-space-indented JSON, then compacts nested objects.
 * With this indentation, properties of objects inside an operator's arrays
 * (e.g. `attributes` entries) sit at 8 spaces and their closing brace at 6.
 * The replacements fold such an object onto one line:
 * `{ "name": ..., "type": ... }`.
 * @param {Array<object>} json - metadata to serialize.
 * @returns {string} formatted file contents.
 */
const formatMetadata = (json) => {
    let output = JSON.stringify(json, null, 2);
    // Any whitespace char followed by 8 spaces becomes a single space.
    output = output.replace(/\s {8}/g, ' ');
    // Also collapse comma-separated runs that remain after the first pass on deeper indents.
    output = output.replace(/,\s {8}/g, ', ');
    // Bring the closing brace (6-space indent) onto the same line.
    output = output.replace(/\s {6}}/g, ' }');
    return output;
};

/**
 * Regenerates `source/megengine-metadata.json` from the MegEngine FlatBuffers
 * schema (`schema_v2.fbs`, namespace `mgb.serialization.fbs.param`).
 * Every non-enum type with at least one field becomes (or updates) an operator
 * entry. Its category comes from `categories`, and its attributes' type/default
 * come from the schema fields.
 * @throws {Error} on duplicate operators, a missing namespace, or invalid enum defaults.
 */
const main = async () => {
    const dirname = path.dirname(url.fileURLToPath(import.meta.url));
    const schema = path.join(dirname, '..', 'third_party', 'source', 'megengine', 'src', 'serialization', 'fbs', 'schema_v2.fbs');
    const file = path.join(dirname, '..', 'source', 'megengine-metadata.json');
    const input = await fs.readFile(file, 'utf-8');
    const json = JSON.parse(input);
    const { operators, attributes } = indexMetadata(json);
    const root = new flatc.Root('megengine');
    await root.load([], [schema]);
    const namespaceName = 'mgb.serialization.fbs.param';
    const namespace = root.find(namespaceName, flatc.Namespace);
    if (!namespace) {
        throw new Error(`Namespace '${namespaceName}' not found in '${schema}'.`);
    }
    for (const [name, param] of namespace.children) {
        // Enums are only field types; any type without a non-empty `fields` map is not an operator.
        if (param instanceof flatc.Enum || !(param?.fields?.size > 0)) {
            continue;
        }
        if (!operators.has(name)) {
            const operator = { name };
            operators.set(name, operator);
            json.push(operator);
        }
        const operator = operators.get(name);
        const category = categories.get(name.replace(/V\d+$/, ''));
        if (category) {
            operator.category = category;
        }
        updateAttributes(operator, param, attributes);
    }
    // Existing operators keep their order; newly discovered ones were appended above.
    await fs.writeFile(file, formatMetadata(json), 'utf-8');
};

await main();
|