nnabla_script.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. """ NNabla metadata script """
  2. import json
  3. import os
  4. import sys
  5. import mako.template
  6. import yaml
  7. def _write(path, content):
  8. with open(path, "w", encoding="utf-8") as file:
  9. file.write(content)
  10. def _read_yaml(path):
  11. with open(path, encoding="utf-8") as file:
  12. return yaml.safe_load(file)
  13. def _metadata():
  14. def parse_functions(function_info):
  15. functions = []
  16. for category_name, category in function_info.items():
  17. for function_name, function_value in category.items():
  18. function = {
  19. "name": function_name,
  20. "description": function_value["doc"].strip()
  21. }
  22. for input_name, input_value in function_value.get("inputs", {}).items():
  23. option = "optional" if input_value.get("optional", False) else None
  24. variadic = input_value.get("variadic", False)
  25. function.setdefault("inputs", []).append({
  26. "name": input_name,
  27. "type": "nnabla.Variable",
  28. "option": option,
  29. "list": variadic,
  30. "description": input_value["doc"].strip()
  31. })
  32. for arg_name, arg_value in function_value.get("arguments", {}).items():
  33. attribute = _attribute(arg_name, arg_value)
  34. function.setdefault("attributes", []).append(attribute)
  35. outputs = function_value.get("outputs", {})
  36. for output_name, output_value in outputs.items():
  37. function.setdefault("outputs", []).append({
  38. "name": output_name,
  39. "type": "nnabla.Variable",
  40. "list": output_value.get("variadic", False),
  41. "description": output_value["doc"].strip()
  42. })
  43. if "Pooling" in function_name:
  44. function["category"] = "Pool"
  45. elif category_name == "Neural Network Layer":
  46. function["category"] = "Layer"
  47. elif category_name == "Neural Network Activation Functions":
  48. function["category"] = "Activation"
  49. elif category_name == "Normalization":
  50. function["category"] = "Normalization"
  51. elif category_name == "Logical":
  52. function["category"] = "Logic"
  53. elif category_name == "Array Manipulation":
  54. function["category"] = "Shape"
  55. functions.append(function)
  56. return functions
  57. def cleanup_functions(functions):
  58. for function in functions:
  59. for inp in function.get("inputs", []):
  60. if inp["option"] is None:
  61. inp.pop("option", None)
  62. if not inp["list"]:
  63. inp.pop("list", None)
  64. for output in function.get("outputs", []):
  65. if not output["list"]:
  66. output.pop("list", None)
  67. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  68. nnabla_dir = os.path.join(root, "third_party", "source", "nnabla")
  69. code_generator_dir = os.path.join(nnabla_dir, "build-tools", "code_generator")
  70. functions_yaml_path = os.path.join(code_generator_dir, "functions.yaml")
  71. function_info = _read_yaml(functions_yaml_path)
  72. functions = parse_functions(function_info)
  73. cleanup_functions(functions)
  74. metadata_file = os.path.join(root, "source", "nnabla-metadata.json")
  75. metadata = json.dumps(functions, indent=2)
  76. _write(metadata_file, metadata)
  77. def _schema():
  78. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  79. nnabla_dir = os.path.join(root, "third_party", "source", "nnabla")
  80. tmpl_file = os.path.join(nnabla_dir, "src/nbla/proto/nnabla.proto.tmpl")
  81. code_generator_dir = os.path.join(nnabla_dir, "build-tools", "code_generator")
  82. yaml_functions_path = os.path.join(code_generator_dir, "functions.yaml")
  83. yaml_solvers_path = os.path.join(code_generator_dir, "solvers.yaml")
  84. functions = _read_yaml(yaml_functions_path)
  85. function_info = {}
  86. for _, category in functions.items():
  87. function_info.update(category)
  88. solver_info = _read_yaml(yaml_solvers_path)
  89. path = tmpl_file.replace(".tmpl", "")
  90. template = mako.template.Template(text=None, filename=tmpl_file, preprocessor=None)
  91. content = template.render(function_info=function_info, solver_info=solver_info)
  92. content = content.replace("\r\n", "\n").replace("\r", "\n")
  93. _write(path, content)
  94. def _attribute(name, value):
  95. attribute = {}
  96. attribute["name"] = name
  97. default = "default" in value
  98. if not default:
  99. attribute["required"] = True
  100. if value["type"] == "float":
  101. attribute["type"] = "float32"
  102. if default:
  103. attribute["default"] = float(value["default"])
  104. elif value["type"] == "double":
  105. attribute["type"] = "float64"
  106. if default:
  107. attribute["default"] = float(value["default"])
  108. elif value["type"] == "bool":
  109. attribute["type"] = "boolean"
  110. if default:
  111. _ = value["default"]
  112. if isinstance(_, bool):
  113. attribute["default"] = _
  114. elif _ == "True":
  115. attribute["default"] = True
  116. elif _ == "False":
  117. attribute["default"] = False
  118. elif value["type"] == "string":
  119. attribute["type"] = "string"
  120. if default:
  121. _ = value["default"]
  122. attribute["default"] = _.strip("'")
  123. elif value["type"] == "int64":
  124. attribute["type"] = "int64"
  125. if default:
  126. _ = value["default"]
  127. if isinstance(_, str) and not _.startswith("len") and _ != "None":
  128. attribute["default"] = int(_)
  129. else:
  130. attribute["default"] = _
  131. elif value["type"] == "repeated int64":
  132. attribute["type"] = "int64[]"
  133. elif value["type"] == "repeated float":
  134. attribute["type"] = "float32[]"
  135. elif value["type"] == "Shape":
  136. attribute["type"] = "shape"
  137. if default and "default" not in attribute:
  138. attribute["default"] = value["default"]
  139. attribute["description"] = value["doc"].strip()
  140. return attribute
  141. def main():
  142. table = { "metadata": _metadata, "schema": _schema }
  143. for command in sys.argv[1:]:
  144. table[command]()
  145. if __name__ == "__main__":
  146. main()