keras_script.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. """ Keras metadata script """
  2. import json
  3. import os
  4. import pydoc
  5. import re
  6. import warnings
  7. warnings.filterwarnings("ignore",
  8. category=FutureWarning,
  9. module="keras.src.export.tf2onnx_lib")
  10. os.environ["KERAS_BACKEND"] = "jax"
  11. def _parse_docstring(docstring):
  12. headers = []
  13. lines = docstring.splitlines()
  14. indents = filter(lambda s: len(s) > 0, lines[1:])
  15. indentation = min(map(lambda s: len(s) - len(s.lstrip()), indents))
  16. lines = list((s[indentation:] if len(s) > len(s.lstrip()) else s) for s in lines)
  17. docstring = "\n".join(lines)
  18. labels = [
  19. "Args", "Arguments", "Variables", "Fields", "Yields", "Call arguments",
  20. "Raises", "Examples", "Example", "Usage", "Input shape", "Output shape",
  21. "Returns", "Reference", "References"
  22. ]
  23. tag_re = re.compile("(?<=\n)(" + "|".join(labels) + "):\n", re.MULTILINE)
  24. parts = tag_re.split(docstring)
  25. headers.append(("", parts.pop(0)))
  26. while len(parts) > 0:
  27. headers.append((parts.pop(0), parts.pop(0)))
  28. return headers
  29. def _parse_arguments(arguments):
  30. result = []
  31. item_re = re.compile(r"^\s{0,4}(\*?\*?\w[\w.]*?\s*):\s", re.MULTILINE)
  32. content = item_re.split(arguments)
  33. if content.pop(0) != "":
  34. raise Exception("")
  35. while len(content) > 0:
  36. result.append((content.pop(0), content.pop(0)))
  37. return result
  38. def _convert_code_blocks(description):
  39. lines = description.splitlines()
  40. output = []
  41. while len(lines) > 0:
  42. line = lines.pop(0)
  43. if line.startswith(">>>") and len(lines) > 0 and \
  44. (lines[0].startswith(">>>") or lines[0].startswith("...")):
  45. output.append("```")
  46. output.append(line)
  47. while len(lines) > 0 and lines[0] != "":
  48. output.append(lines.pop(0))
  49. output.append("```")
  50. else:
  51. output.append(line)
  52. return "\n".join(output)
  53. def _remove_indentation(value):
  54. lines = value.splitlines()
  55. indentation = min(map(lambda s: len(s) - len(s.lstrip()), \
  56. filter(lambda s: len(s) > 0, lines)))
  57. lines = list((s[indentation:] if len(s) > 0 else s) for s in lines)
  58. return "\n".join(lines).strip()
  59. def _update_argument(schema, name, description):
  60. if "attributes" not in schema:
  61. schema["attributes"] = []
  62. attribute = next((_ for _ in schema["attributes"] if _["name"] == name), None)
  63. if not attribute:
  64. attribute = {}
  65. attribute["name"] = name
  66. schema["attributes"].append(attribute)
  67. attribute["description"] = _remove_indentation(description)
  68. def _update_input(schema, description):
  69. if "inputs" not in schema:
  70. schema["inputs"] = [ { "name": "input" } ]
  71. parameter = next((_ for _ in schema["inputs"] \
  72. if (_["name"] == "input" or _["name"] == "inputs")), None)
  73. if parameter:
  74. parameter["description"] = _remove_indentation(description)
  75. else:
  76. raise Exception("")
  77. def _update_output(schema, description):
  78. if "outputs" not in schema:
  79. schema["outputs"] = [ { "name": "output" } ]
  80. outputs = schema["outputs"]
  81. parameter = next((param for param in outputs if param["name"] == "output"), None)
  82. if parameter:
  83. parameter["description"] = _remove_indentation(description)
  84. else:
  85. raise Exception("")
  86. def _update_examples(schema, value):
  87. if "examples" in schema:
  88. del schema["examples"]
  89. value = _convert_code_blocks(value)
  90. lines = value.splitlines()
  91. code = []
  92. summary = []
  93. while len(lines) > 0:
  94. line = lines.pop(0)
  95. if len(line) > 0:
  96. if line.startswith("```"):
  97. while len(lines) > 0:
  98. line = lines.pop(0)
  99. if line == "```":
  100. break
  101. code.append(line)
  102. else:
  103. summary.append(line)
  104. if len(code) > 0:
  105. example = {}
  106. if len(summary) > 0:
  107. example["summary"] = "\n".join(summary)
  108. example["code"] = "\n".join(code)
  109. if "examples" not in schema:
  110. schema["examples"] = []
  111. schema["examples"].append(example)
  112. code = []
  113. summary = []
  114. def _update_references(schema, value):
  115. if "references" in schema:
  116. del schema["references"]
  117. references = []
  118. reference = ""
  119. lines = value.splitlines()
  120. for line in lines:
  121. if line.lstrip().startswith("- "):
  122. if len(reference) > 0:
  123. references.append(reference)
  124. reference = line.lstrip().lstrip("- ")
  125. else:
  126. if line.startswith(" "):
  127. line = line[2:]
  128. reference = " ".join([ reference, line.strip() ])
  129. if len(reference) > 0:
  130. references.append(reference)
  131. for reference in references:
  132. if "references" not in schema:
  133. schema["references"] = []
  134. if len(reference.strip()) > 0:
  135. schema["references"].append({ "description": reference })
  136. def _update_headers(schema, docstring):
  137. headers = _parse_docstring(docstring)
  138. for header in headers:
  139. key, value = header
  140. if key == "":
  141. description = _convert_code_blocks(value)
  142. schema["description"] = _remove_indentation(description)
  143. elif key in ("Args", "Arguments"):
  144. arguments = _parse_arguments(value)
  145. for argument in arguments:
  146. _update_argument(schema, argument[0], argument[1])
  147. elif key == "Input shape":
  148. _update_input(schema, value)
  149. elif key == "Output shape":
  150. _update_output(schema, value)
  151. elif key in ("Example", "Examples", "Usage"):
  152. _update_examples(schema, value)
  153. elif key in ("Reference", "References"):
  154. _update_references(schema, value)
  155. elif key in ("Call arguments", "Returns", "Variables", "Raises"):
  156. pass
  157. else:
  158. raise Exception("")
  159. def _metadata():
  160. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  161. json_path = os.path.join(root, "source", "keras-metadata.json")
  162. with open(json_path, encoding="utf-8") as file:
  163. json_content = file.read()
  164. json_root = json.loads(json_content)
  165. skip_names = set([
  166. "keras.layers.InputLayer",
  167. "keras.layers.ThresholdedReLU",
  168. "keras.layers.LocallyConnected1D",
  169. "keras.layers.LocallyConnected2D"
  170. ])
  171. for metadata in json_root:
  172. if "module" in metadata:
  173. name = metadata["module"] + "." + metadata["name"]
  174. if name not in skip_names:
  175. cls = pydoc.locate(name)
  176. if not cls:
  177. raise KeyError(f"'{name}' not found.")
  178. if not cls.__doc__:
  179. raise AttributeError(f"'{name}' missing __doc__.")
  180. if cls.__doc__ == "DEPRECATED.":
  181. raise DeprecationWarning(f"'{name}.__doc__' is deprecated.'")
  182. _update_headers(metadata, cls.__doc__)
  183. with open(json_path, "w", encoding="utf-8") as file:
  184. content = json.dumps(json_root, sort_keys=False, indent=2)
  185. for line in content.splitlines():
  186. file.write(line.rstrip() + "\n")
  187. def main():
  188. _metadata()
  189. if __name__ == "__main__":
  190. main()