| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- ''' NNabla metadata script '''
- import json
- import sys
- import os
- import yaml # pylint: disable=import-error
- import mako.template # pylint: disable=import-error
def _write(path, content):
    """Write the string *content* to *path* as UTF-8 text, replacing any existing file."""
    with open(path, mode='w', encoding='utf-8') as handle:
        handle.write(content)
def _read_yaml(path):
    """Parse the UTF-8 YAML document at *path* with PyYAML's safe loader and return it."""
    with open(path, encoding='utf-8') as stream:
        document = yaml.safe_load(stream)
    return document
def _metadata():
    """Regenerate source/nnabla-metadata.json from nnabla's functions.yaml.

    functions.yaml is read as {category name: {function name: spec}}. Each
    function becomes one JSON object holding its name and stripped 'doc'
    text plus, only when non-empty, its 'inputs', 'attributes' (built by
    _attribute from the spec's 'arguments') and 'outputs', and a display
    'category' when one is known. Input/output 'option' and 'list' keys are
    emitted only when they are set.
    """
    # Display category per functions.yaml category name. A function whose
    # name contains 'Pooling' is always shown as 'Pool' instead.
    display_categories = {
        'Neural Network Layer': 'Layer',
        'Neural Network Activation Functions': 'Activation',
        'Normalization': 'Normalization',
        'Logical': 'Logic',
        'Array Manipulation': 'Shape',
    }

    def variable(var_name, spec, allow_option):
        """Describe one nnabla.Variable input (allow_option=True) or output."""
        entry = {'name': var_name, 'type': 'nnabla.Variable'}
        if allow_option and spec.get('optional', False):
            entry['option'] = 'optional'
        variadic = spec.get('variadic', False)
        if variadic:
            entry['list'] = variadic
        entry['description'] = spec['doc'].strip()
        return entry

    def describe(category_name, function_name, spec):
        """Build the metadata object for one function spec."""
        entry = {'name': function_name, 'description': spec['doc'].strip()}
        inputs = [variable(key, item, True) for key, item in spec.get('inputs', {}).items()]
        if inputs:
            entry['inputs'] = inputs
        attributes = [_attribute(key, item) for key, item in spec.get('arguments', {}).items()]
        if attributes:
            entry['attributes'] = attributes
        outputs = [variable(key, item, False) for key, item in spec.get('outputs', {}).items()]
        if outputs:
            entry['outputs'] = outputs
        if 'Pooling' in function_name:
            entry['category'] = 'Pool'
        elif category_name in display_categories:
            entry['category'] = display_categories[category_name]
        return entry

    # Paths are relative to the parent of the directory containing this script.
    root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    yaml_path = os.path.join(root, 'third_party', 'source', 'nnabla',
                             'build-tools', 'code_generator', 'functions.yaml')
    catalog = _read_yaml(yaml_path)
    functions = [
        describe(category_name, function_name, spec)
        for category_name, members in catalog.items()
        for function_name, spec in members.items()
    ]
    output_path = os.path.join(root, 'source', 'nnabla-metadata.json')
    _write(output_path, json.dumps(functions, indent=2))
def _schema():
    """Render nnabla.proto from its Mako template in the nnabla checkout.

    Reads functions.yaml and solvers.yaml from
    third_party/source/nnabla/build-tools/code_generator and renders
    src/nbla/proto/nnabla.proto.tmpl with them. The result is written next
    to the template, named without the '.tmpl' suffix, with line endings
    normalized to LF.
    """
    # Paths are relative to the parent of the directory containing this script.
    root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    third_party_dir = os.path.join(root, 'third_party', 'source', 'nnabla')
    generator_dir = os.path.join(third_party_dir, 'build-tools', 'code_generator')
    tmpl_file = os.path.join(third_party_dir, 'src', 'nbla', 'proto', 'nnabla.proto.tmpl')
    functions = _read_yaml(os.path.join(generator_dir, 'functions.yaml'))
    # functions.yaml groups functions by category; the template is given a
    # flat {function name: spec} mapping.
    function_info = {
        name: spec
        for category in functions.values()
        for name, spec in category.items()
    }
    solver_info = _read_yaml(os.path.join(generator_dir, 'solvers.yaml'))
    # Drop only the trailing '.tmpl'. str.replace would also rewrite any
    # '.tmpl' that happens to occur earlier in the absolute path.
    path = os.path.splitext(tmpl_file)[0]
    template = mako.template.Template(filename=tmpl_file)
    content = template.render(function_info=function_info, solver_info=solver_info)
    content = content.replace('\r\n', '\n').replace('\r', '\n')
    # newline='\n' keeps the normalized LF endings. Default text mode would
    # translate '\n' to os.linesep (CRLF on Windows) and undo the line above.
    with open(path, 'w', encoding='utf-8', newline='\n') as file:
        file.write(content)
def _convert_float(raw):
    """Return *raw* as a float, or unchanged when it is not a numeric literal."""
    try:
        return float(raw)
    except (TypeError, ValueError):
        return raw


def _convert_bool(raw):
    """Map a bool or the strings 'True'/'False' to bool; return anything else unchanged."""
    if isinstance(raw, bool):
        return raw
    if raw == 'True':
        return True
    if raw == 'False':
        return False
    return raw


def _convert_string(raw):
    """Strip surrounding single quotes from a string default; return non-strings unchanged."""
    return raw.strip("'") if isinstance(raw, str) else raw


def _convert_int64(raw):
    """Parse integer literal strings.

    Strings starting with 'len', the string 'None', non-strings and any
    string that is not an integer literal are returned unchanged.
    """
    if isinstance(raw, str) and not raw.startswith('len') and raw != 'None':
        try:
            return int(raw)
        except ValueError:
            return raw
    return raw


# functions.yaml argument type -> (metadata type name, default converter).
# A converter of None means the default is copied through unchanged.
_ATTRIBUTE_TYPES = {
    'float': ('float32', _convert_float),
    'double': ('float64', _convert_float),
    'bool': ('boolean', _convert_bool),
    'string': ('string', _convert_string),
    'int64': ('int64', _convert_int64),
    'repeated int64': ('int64[]', None),
    'repeated float': ('float32[]', None),
    'Shape': ('shape', None),
}


def _attribute(name, value):
    """Convert one functions.yaml argument spec into a metadata attribute.

    Args:
        name: argument name.
        value: spec dict with 'type' and 'doc' keys and an optional 'default'.

    Returns:
        A dict with keys, in order:
        - 'name'.
        - 'required': True, only when there is no default.
        - 'type': only for recognized types.
        - 'default': converted per type; a default that cannot be converted
          is kept as written.
        - 'description': the stripped doc.
    """
    attribute = {'name': name}
    has_default = 'default' in value
    if not has_default:
        attribute['required'] = True
    type_name, convert = _ATTRIBUTE_TYPES.get(value['type'], (None, None))
    if type_name is not None:
        attribute['type'] = type_name
    if has_default:
        raw = value['default']
        attribute['default'] = convert(raw) if convert is not None else raw
    attribute['description'] = value['doc'].strip()
    return attribute
def main():
    """Run each generator named on the command line, in order.

    Accepted commands are 'metadata' and 'schema'. If any argument is not a
    known command, the script exits with an error message before running
    any generator.
    """
    table = {'metadata': _metadata, 'schema': _schema}
    commands = sys.argv[1:]
    unknown = [command for command in commands if command not in table]
    if unknown:
        # Fail before doing any work instead of raising a bare KeyError
        # halfway through the list.
        sys.exit('Unknown command(s): {}. Expected one of: {}.'.format(
            ', '.join(unknown), ', '.join(sorted(table))))
    for command in commands:
        table[command]()


if __name__ == '__main__':
    main()
|