2
0

nnabla_script.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. ''' NNabla metadata script '''
  2. import json
  3. import sys
  4. import os
  5. import yaml
  6. import mako
  7. import mako.template
  8. def _render_with_template(text=None, filename=None, preprocessor=None, template_kwargs={}):
  9. tmpl = mako.template.Template(text=text, filename=filename, preprocessor=preprocessor)
  10. return tmpl.render(**template_kwargs)
  11. def _generate_from_template(path_template, **kwargs):
  12. path_out = path_template.replace('.tmpl', '')
  13. generated = _render_with_template(filename=path_template, template_kwargs=kwargs)
  14. with open(path_out, 'wb') as file:
  15. write_content = generated.encode('utf_8')
  16. write_content = write_content.replace(b'\r\n', b'\n')
  17. write_content = write_content.replace(b'\r', b'\n')
  18. file.write(write_content)
  19. def _metadata():
  20. json_file = os.path.join(os.path.dirname(__file__), '../source/nnabla-metadata.json')
  21. base = os.path.abspath(os.path.join(os.path.dirname(__file__), '../third_party/source/nnabla'))
  22. yaml_functions = os.path.join(base, 'build-tools/code_generator/functions.yaml')
  23. with open(yaml_functions, 'r', encoding='utf-8') as file:
  24. function_info = yaml.safe_load(file)
  25. functions = []
  26. # parse functions
  27. for category_name, category in function_info.items():
  28. for function_name, function_value in category.items():
  29. function = {
  30. 'name': function_name,
  31. 'description': function_value['doc'].strip()
  32. }
  33. for input_name, input_value in function_value.get('inputs', {}).items():
  34. function.setdefault('inputs', []).append({
  35. 'name': input_name,
  36. 'type': 'nnabla.Variable',
  37. 'option': 'optional' if input_value.get('optional', False) else None,
  38. 'list': input_value.get('variadic', False),
  39. 'description': input_value['doc'].strip()
  40. })
  41. for arg_name, arg_value in function_value.get('arguments', {}).items():
  42. function.setdefault('attributes', []).append({
  43. 'name': arg_name,
  44. 'type': arg_value['type'],
  45. 'required': 'default' not in arg_value,
  46. 'default': _try_eval_default(arg_value.get('default', None)),
  47. 'description': arg_value['doc'].strip()
  48. })
  49. for output_name, output_value in function_value.get('outputs', {}).items():
  50. function.setdefault('outputs', []).append({
  51. 'name': output_name,
  52. 'type': 'nnabla.Variable',
  53. 'list': output_value.get('variadic', False),
  54. 'description': output_value['doc'].strip()
  55. })
  56. if 'Pooling' in function_name:
  57. function['category'] = 'Pool'
  58. elif category_name == 'Neural Network Layer':
  59. function['category'] = 'Layer'
  60. elif category_name == 'Neural Network Activation Functions':
  61. function['category'] = 'Activation'
  62. elif category_name == 'Normalization':
  63. function['category'] = 'Normalization'
  64. elif category_name == 'Logical':
  65. function['category'] = 'Logic'
  66. elif category_name == 'Array Manipulation':
  67. function['category'] = 'Shape'
  68. functions.append(function)
  69. # clean-up redundant fields
  70. for function in functions:
  71. for inp in function.get('inputs', []):
  72. if inp['option'] is None:
  73. inp.pop('option', None)
  74. if not inp['list']:
  75. inp.pop('list', None)
  76. for attribute in function.get('attributes', []):
  77. if attribute['required']:
  78. attribute.pop('default', None)
  79. for output in function.get('outputs', []):
  80. if not output['list']:
  81. output.pop('list', None)
  82. with open(json_file, 'w', encoding='utf-8') as file:
  83. file.write(json.dumps(functions, indent=2))
  84. def _schema():
  85. base = os.path.abspath(os.path.join(os.path.dirname(__file__), '../third_party/source/nnabla'))
  86. tmpl_file = os.path.join(base, 'src/nbla/proto/nnabla.proto.tmpl')
  87. yaml_functions = os.path.join(base, 'build-tools/code_generator/functions.yaml')
  88. yaml_solvers = os.path.join(base, 'build-tools/code_generator/solvers.yaml')
  89. with open(yaml_functions, 'r', encoding='utf-8') as file:
  90. functions = yaml.safe_load(file)
  91. function_info = {k: v for _, category in functions.items() for k, v in category.items()}
  92. with open(yaml_solvers, 'r', encoding='utf-8') as file:
  93. solver_info = yaml.safe_load(file)
  94. _generate_from_template(tmpl_file, function_info=function_info, solver_info=solver_info)
  95. def _try_eval_default(default):
  96. if default and isinstance(default, str):
  97. if not default.startswith(('(', '[')):
  98. try:
  99. default = eval(default, {'__builtin__': None})
  100. except NameError:
  101. pass
  102. return default
  103. def main(): # pylint: disable=missing-function-docstring
  104. command_table = {'metadata': _metadata, 'schema': _schema}
  105. command = sys.argv[1]
  106. command_table[command]()
  107. if __name__ == '__main__':
  108. main()