| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
"""Keras metadata script.

Refreshes the descriptions, attributes, input/output descriptions, examples
and references stored in ``source/keras-metadata.json`` from the docstrings
of the installed Keras classes.
"""
import json
import os
import pydoc
import re
import warnings

# Hide FutureWarning noise coming from Keras' bundled tf2onnx helper module.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="keras.src.export.tf2onnx_lib",
)

# Keras reads KERAS_BACKEND when it is first imported, so it must be set
# before any Keras class is resolved (presumably via pydoc.locate below).
os.environ["KERAS_BACKEND"] = "jax"
- def _parse_docstring(docstring):
- headers = []
- lines = docstring.splitlines()
- indents = filter(lambda s: len(s) > 0, lines[1:])
- indentation = min(map(lambda s: len(s) - len(s.lstrip()), indents))
- lines = list((s[indentation:] if len(s) > len(s.lstrip()) else s) for s in lines)
- docstring = "\n".join(lines)
- labels = [
- "Args", "Arguments", "Variables", "Fields", "Yields", "Call arguments",
- "Raises", "Examples", "Example", "Usage", "Input shape", "Output shape",
- "Returns", "Reference", "References"
- ]
- tag_re = re.compile("(?<=\n)(" + "|".join(labels) + "):\n", re.MULTILINE)
- parts = tag_re.split(docstring)
- headers.append(("", parts.pop(0)))
- while len(parts) > 0:
- headers.append((parts.pop(0), parts.pop(0)))
- return headers
- def _parse_arguments(arguments):
- result = []
- item_re = re.compile(r"^\s{0,4}(\*?\*?\w[\w.]*?\s*):\s", re.MULTILINE)
- content = item_re.split(arguments)
- if content.pop(0) != "":
- raise Exception("")
- while len(content) > 0:
- result.append((content.pop(0), content.pop(0)))
- return result
- def _convert_code_blocks(description):
- lines = description.splitlines()
- output = []
- while len(lines) > 0:
- line = lines.pop(0)
- if line.startswith(">>>") and len(lines) > 0 and \
- (lines[0].startswith(">>>") or lines[0].startswith("...")):
- output.append("```")
- output.append(line)
- while len(lines) > 0 and lines[0] != "":
- output.append(lines.pop(0))
- output.append("```")
- else:
- output.append(line)
- return "\n".join(output)
- def _remove_indentation(value):
- lines = value.splitlines()
- indentation = min(map(lambda s: len(s) - len(s.lstrip()), \
- filter(lambda s: len(s) > 0, lines)))
- lines = list((s[indentation:] if len(s) > 0 else s) for s in lines)
- return "\n".join(lines).strip()
def _update_argument(schema, name, description):
    """Set the dedented *description* on the schema attribute called *name*.

    ``schema["attributes"]`` is created when missing. An attribute entry
    ``{"name": name}`` is appended when no entry with that name exists yet.
    """
    attributes = schema.setdefault("attributes", [])
    for attribute in attributes:
        if attribute["name"] == name:
            break
    else:
        attribute = {"name": name}
        attributes.append(attribute)
    attribute["description"] = _remove_indentation(description)
- def _update_input(schema, description):
- if "inputs" not in schema:
- schema["inputs"] = [ { "name": "input" } ]
- parameter = next((_ for _ in schema["inputs"] \
- if (_["name"] == "input" or _["name"] == "inputs")), None)
- if parameter:
- parameter["description"] = _remove_indentation(description)
- else:
- raise Exception("")
- def _update_output(schema, description):
- if "outputs" not in schema:
- schema["outputs"] = [ { "name": "output" } ]
- outputs = schema["outputs"]
- parameter = next((param for param in outputs if param["name"] == "output"), None)
- if parameter:
- parameter["description"] = _remove_indentation(description)
- else:
- raise Exception("")
def _update_examples(schema, value):
    """Rebuild ``schema["examples"]`` from an ``Examples``/``Usage`` section.

    Doctest runs are first fenced by ``_convert_code_blocks``. Each non-empty
    fenced block becomes one ``{"code": ...}`` entry. Non-empty prose lines
    seen since the previous entry are attached to it as ``"summary"``. Prose
    after the last code block is discarded, and the key is left absent when
    no code block is found.
    """
    schema.pop("examples", None)
    code_lines = []
    prose_lines = []
    stream = iter(_convert_code_blocks(value).splitlines())
    for text in stream:
        if text:
            if text.startswith("```"):
                # Consume the fenced body from the same iterator up to the
                # closing fence (or the end of the text).
                for body_line in stream:
                    if body_line == "```":
                        break
                    code_lines.append(body_line)
            else:
                prose_lines.append(text)
        if code_lines:
            example = {}
            if prose_lines:
                example["summary"] = "\n".join(prose_lines)
            example["code"] = "\n".join(code_lines)
            schema.setdefault("examples", []).append(example)
            code_lines = []
            prose_lines = []
- def _update_references(schema, value):
- if "references" in schema:
- del schema["references"]
- references = []
- reference = ""
- lines = value.splitlines()
- for line in lines:
- if line.lstrip().startswith("- "):
- if len(reference) > 0:
- references.append(reference)
- reference = line.lstrip().lstrip("- ")
- else:
- if line.startswith(" "):
- line = line[2:]
- reference = " ".join([ reference, line.strip() ])
- if len(reference) > 0:
- references.append(reference)
- for reference in references:
- if "references" not in schema:
- schema["references"] = []
- if len(reference.strip()) > 0:
- schema["references"].append({ "description": reference })
def _update_headers(schema, docstring):
    """Merge the sections of a class docstring into a metadata *schema*.

    The leading free text becomes ``description``. ``Args`` feed
    ``attributes``, the shape sections feed the input/output descriptions,
    and the example and reference sections are rebuilt. Sections that have no
    place in the schema are ignored.

    Raises:
        ValueError: The docstring contains a section this function does not
            know how to handle.
    """
    for key, value in _parse_docstring(docstring):
        if key == "":
            schema["description"] = _remove_indentation(_convert_code_blocks(value))
        elif key in ("Args", "Arguments"):
            for name, description in _parse_arguments(value):
                _update_argument(schema, name, description)
        elif key == "Input shape":
            _update_input(schema, value)
        elif key == "Output shape":
            _update_output(schema, value)
        elif key in ("Example", "Examples", "Usage"):
            _update_examples(schema, value)
        elif key in ("Reference", "References"):
            _update_references(schema, value)
        elif key in ("Call arguments", "Returns", "Variables", "Raises", "Fields", "Yields"):
            # Recognised by _parse_docstring but not represented in metadata.
            pass
        else:
            raise ValueError(
                f"Unsupported docstring section {key!r} in schema {schema.get('name')!r}.")
def _metadata():
    """Refresh ``source/keras-metadata.json`` from the installed Keras classes.

    The JSON file is looked up under ``source/`` in the parent of the
    directory containing this script. Each entry is processed only if it has
    a ``module`` key and its name is not in ``skip_names``. For such an
    entry, the class ``<module>.<name>`` is resolved with ``pydoc.locate``
    and its docstring is merged into the entry by ``_update_headers``. The
    file is then rewritten with two-space indentation and no trailing
    whitespace.

    Raises:
        KeyError: A class cannot be located.
        AttributeError: A class has no docstring.
        DeprecationWarning: A class docstring is exactly ``"DEPRECATED."``.
    """
    root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    json_path = os.path.join(root, "source", "keras-metadata.json")
    with open(json_path, encoding="utf-8") as file:
        json_root = json.load(file)
    # These entries are written back unchanged instead of being refreshed.
    skip_names = {
        "keras.layers.InputLayer",
        "keras.layers.ThresholdedReLU",
        "keras.layers.LocallyConnected1D",
        "keras.layers.LocallyConnected2D",
    }
    for metadata in json_root:
        if "module" not in metadata:
            continue
        name = metadata["module"] + "." + metadata["name"]
        if name in skip_names:
            continue
        cls = pydoc.locate(name)
        if not cls:
            raise KeyError(f"'{name}' not found.")
        if not cls.__doc__:
            raise AttributeError(f"'{name}' missing __doc__.")
        if cls.__doc__ == "DEPRECATED.":
            raise DeprecationWarning(f"'{name}.__doc__' is deprecated.")
        _update_headers(metadata, cls.__doc__)
    # Serialize before opening for writing so a failure cannot leave the
    # metadata file truncated.
    content = json.dumps(json_root, sort_keys=False, indent=2)
    with open(json_path, "w", encoding="utf-8") as file:
        file.writelines(line.rstrip() + "\n" for line in content.splitlines())
def main():
    """Command-line entry point: regenerate the Keras metadata JSON file."""
    return _metadata()


if __name__ == "__main__":
    main()
|