nnabla-script.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import json
  2. import sys
  3. import yaml
  4. import os
  5. import mako
  6. import mako.template
  7. def render_with_template(text=None, filename=None, preprocessor=None, template_kwargs={}):
  8. tmpl = mako.template.Template(text=text, filename=filename, preprocessor=preprocessor)
  9. try:
  10. return tmpl.render(**template_kwargs)
  11. except Exception as e:
  12. import sys
  13. print('-' * 78, file=sys.stderr)
  14. print('Template exceptions', file=sys.stderr)
  15. print('-' * 78, file=sys.stderr)
  16. print(mako.exceptions.text_error_template().render(), file=sys.stderr)
  17. print('-' * 78, file=sys.stderr)
  18. raise e
  19. def generate_from_template(path_template, **kwargs):
  20. path_out = path_template.replace('.tmpl', '')
  21. generated = render_with_template(filename=path_template, template_kwargs=kwargs)
  22. with open(path_out, 'wb') as file:
  23. write_content = generated.encode('utf_8')
  24. write_content = write_content.replace(b'\r\n', b'\n')
  25. write_content = write_content.replace(b'\r', b'\n')
  26. file.write(write_content)
  27. def metadata():
  28. json_file = os.path.join(os.path.dirname(__file__), '../source/nnabla-metadata.json')
  29. yaml_functions = os.path.join(os.path.dirname(__file__), '../third_party/source/nnabla/build-tools/code_generator/functions.yaml')
  30. with open(yaml_functions, 'r') as file:
  31. function_info = yaml.safe_load(file)
  32. functions = []
  33. # Parse functions
  34. for category_name, category in function_info.items():
  35. for function_name, function_value in category.items():
  36. function = {
  37. 'name': function_name,
  38. 'description': function_value['doc'].strip()
  39. }
  40. for input_name, input_value in function_value.get('inputs', {}).items():
  41. function.setdefault('inputs', []).append({
  42. 'name': input_name,
  43. 'type': 'nnabla.Variable',
  44. 'option': 'optional' if input_value.get('optional', False) else None,
  45. 'list': input_value.get('variadic', False),
  46. 'description': input_value['doc'].strip()
  47. })
  48. for arg_name, arg_value in function_value.get('arguments', {}).items():
  49. function.setdefault('attributes', []).append({
  50. 'name': arg_name,
  51. 'type': arg_value['type'],
  52. 'required': 'default' not in arg_value,
  53. 'default': try_eval_default(arg_value.get('default', None)),
  54. 'description': arg_value['doc'].strip()
  55. })
  56. for output_name, output_value in function_value.get('outputs', {}).items():
  57. function.setdefault('outputs', []).append({
  58. 'name': output_name,
  59. 'type': 'nnabla.Variable',
  60. 'list': output_value.get('variadic', False),
  61. 'description': output_value['doc'].strip()
  62. })
  63. if 'Pooling' in function_name:
  64. function['category'] = 'Pool'
  65. elif category_name == 'Neural Network Layer':
  66. function['category'] = 'Layer'
  67. elif category_name == 'Neural Network Activation Functions':
  68. function['category'] = 'Activation'
  69. elif category_name == 'Normalization':
  70. function['category'] = 'Normalization'
  71. elif category_name == 'Logical':
  72. function['category'] = 'Logic'
  73. elif category_name == 'Array Manipulation':
  74. function['category'] = 'Shape'
  75. functions.append(function)
  76. # Clean-up redundant fields
  77. for function in functions:
  78. for inp in function.get('inputs', []):
  79. if inp['option'] is None:
  80. inp.pop('option', None)
  81. if not inp['list']:
  82. inp.pop('list', None)
  83. for attribute in function.get('attributes', []):
  84. if attribute['required']:
  85. attribute.pop('default', None)
  86. for output in function.get('outputs', []):
  87. if not output['list']:
  88. output.pop('list', None)
  89. with open(json_file, 'w') as file:
  90. file.write(json.dumps(functions, indent=2))
  91. def proto():
  92. base = os.path.abspath(os.path.join(os.path.dirname(__file__), '../third_party/source/nnabla'))
  93. tmpl_file = os.path.join(base, 'src/nbla/proto/nnabla.proto.tmpl')
  94. yaml_functions = os.path.join(base, 'build-tools/code_generator/functions.yaml')
  95. yaml_solvers = os.path.join(base, 'build-tools/code_generator/solvers.yaml')
  96. with open(yaml_functions, 'r') as file:
  97. functions = yaml.safe_load(file)
  98. function_info = {k: v for _, category in functions.items() for k, v in category.items()}
  99. with open(yaml_solvers, 'r') as file:
  100. solver_info = yaml.safe_load(file)
  101. generate_from_template(tmpl_file, function_info=function_info, solver_info=solver_info)
  102. def try_eval_default(default):
  103. if default and isinstance(default, str):
  104. if not default.startswith(('(', '[')):
  105. try:
  106. return eval(default, {'__builtin__': None})
  107. except NameError:
  108. pass
  109. return default
  110. if __name__ == '__main__':
  111. command_table = {'metadata': metadata, 'proto': proto}
  112. command = sys.argv[1]
  113. command_table[command]()