2
0

nnabla_script.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. ''' NNabla metadata script '''
  2. import json
  3. import sys
  4. import os
  5. import yaml # pylint: disable=import-error
  6. import mako.template # pylint: disable=import-error
  7. def _write(path, content):
  8. with open(path, 'w', encoding='utf-8') as file:
  9. file.write(content)
  10. def _read_yaml(path):
  11. with open(path, 'r', encoding='utf-8') as file:
  12. return yaml.safe_load(file)
  13. def _metadata():
  14. def parse_functions(function_info):
  15. functions = []
  16. for category_name, category in function_info.items():
  17. for function_name, function_value in category.items():
  18. function = {
  19. 'name': function_name,
  20. 'description': function_value['doc'].strip()
  21. }
  22. for input_name, input_value in function_value.get('inputs', {}).items():
  23. function.setdefault('inputs', []).append({
  24. 'name': input_name,
  25. 'type': 'nnabla.Variable',
  26. 'option': 'optional' if input_value.get('optional', False) else None,
  27. 'list': input_value.get('variadic', False),
  28. 'description': input_value['doc'].strip()
  29. })
  30. for arg_name, arg_value in function_value.get('arguments', {}).items():
  31. attribute = _attribute(arg_name, arg_value)
  32. function.setdefault('attributes', []).append(attribute)
  33. for output_name, output_value in function_value.get('outputs', {}).items():
  34. function.setdefault('outputs', []).append({
  35. 'name': output_name,
  36. 'type': 'nnabla.Variable',
  37. 'list': output_value.get('variadic', False),
  38. 'description': output_value['doc'].strip()
  39. })
  40. if 'Pooling' in function_name:
  41. function['category'] = 'Pool'
  42. elif category_name == 'Neural Network Layer':
  43. function['category'] = 'Layer'
  44. elif category_name == 'Neural Network Activation Functions':
  45. function['category'] = 'Activation'
  46. elif category_name == 'Normalization':
  47. function['category'] = 'Normalization'
  48. elif category_name == 'Logical':
  49. function['category'] = 'Logic'
  50. elif category_name == 'Array Manipulation':
  51. function['category'] = 'Shape'
  52. functions.append(function)
  53. return functions
  54. def cleanup_functions(functions):
  55. for function in functions:
  56. for inp in function.get('inputs', []):
  57. if inp['option'] is None:
  58. inp.pop('option', None)
  59. if not inp['list']:
  60. inp.pop('list', None)
  61. for output in function.get('outputs', []):
  62. if not output['list']:
  63. output.pop('list', None)
  64. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  65. functions_yaml_path = os.path.join(root, \
  66. 'third_party', 'source', 'nnabla', 'build-tools', 'code_generator', 'functions.yaml')
  67. function_info = _read_yaml(functions_yaml_path)
  68. functions = parse_functions(function_info)
  69. cleanup_functions(functions)
  70. _write(os.path.join(root, 'source', 'nnabla-metadata.json'), json.dumps(functions, indent=2))
  71. def _schema():
  72. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  73. third_party_dir = os.path.join(root, 'third_party', 'source', 'nnabla')
  74. tmpl_file = os.path.join(third_party_dir, 'src/nbla/proto/nnabla.proto.tmpl')
  75. yaml_functions_path = os.path.join(third_party_dir, 'build-tools/code_generator/functions.yaml')
  76. yaml_solvers_path = os.path.join(third_party_dir, 'build-tools/code_generator/solvers.yaml')
  77. functions = _read_yaml(yaml_functions_path)
  78. function_info = {k: v for _, category in functions.items() for k, v in category.items()}
  79. solver_info = _read_yaml(yaml_solvers_path)
  80. path = tmpl_file.replace('.tmpl', '')
  81. template = mako.template.Template(text=None, filename=tmpl_file, preprocessor=None)
  82. content = template.render(function_info=function_info, solver_info=solver_info)
  83. content = content.replace('\r\n', '\n').replace('\r', '\n')
  84. _write(path, content)
  85. def _attribute(name, value): # pylint: disable=too-many-branches
  86. attribute = {}
  87. attribute['name'] = name
  88. default = 'default' in value
  89. if not default:
  90. attribute['required'] = True
  91. if value['type'] == 'float':
  92. attribute['type'] = 'float32'
  93. if default:
  94. attribute['default'] = float(value['default'])
  95. elif value['type'] == 'double':
  96. attribute['type'] = 'float64'
  97. if default:
  98. attribute['default'] = float(value['default'])
  99. elif value['type'] == 'bool':
  100. attribute['type'] = 'boolean'
  101. if default:
  102. _ = value['default']
  103. if isinstance(_, bool):
  104. attribute['default'] = _
  105. elif _ == 'True':
  106. attribute['default'] = True
  107. elif _ == 'False':
  108. attribute['default'] = False
  109. elif value['type'] == 'string':
  110. attribute['type'] = 'string'
  111. if default:
  112. _ = value['default']
  113. attribute['default'] = _.strip("'")
  114. elif value['type'] == 'int64':
  115. attribute['type'] = 'int64'
  116. if default:
  117. _ = value['default']
  118. if isinstance(_, str) and not _.startswith('len') and _ != 'None':
  119. attribute['default'] = int(_)
  120. else:
  121. attribute['default'] = _
  122. elif value['type'] == 'repeated int64':
  123. attribute['type'] = 'int64[]'
  124. elif value['type'] == 'repeated float':
  125. attribute['type'] = 'float32[]'
  126. elif value['type'] == 'Shape':
  127. attribute['type'] = 'shape'
  128. if default and 'default' not in attribute:
  129. attribute['default'] = value['default']
  130. attribute['description'] = value['doc'].strip()
  131. return attribute
  132. def main(): # pylint: disable=missing-function-docstring
  133. table = { 'metadata': _metadata, 'schema': _schema }
  134. for command in sys.argv[1:]:
  135. table[command]()
  136. if __name__ == '__main__':
  137. main()