sklearn_script.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """ scikit-learn metadata script """
  2. import json
  3. import os
  4. import pydoc
  5. import re
  6. import sys
  7. def _split_docstring(value):
  8. headers = {}
  9. current_header = ""
  10. current_lines = []
  11. lines = value.split("\n")
  12. index = 0
  13. while index < len(lines):
  14. if index + 1 < len(lines) and len(lines[index + 1].strip(" ")) > 0 and \
  15. len(lines[index + 1].strip(" ").strip("-")) == 0:
  16. headers[current_header] = current_lines
  17. current_header = lines[index].strip(" ")
  18. current_lines = []
  19. index = index + 1
  20. else:
  21. current_lines.append(lines[index])
  22. index = index + 1
  23. headers[current_header] = current_lines
  24. return headers
  25. def _update_description(schema, lines):
  26. if len("".join(lines).strip(" ")) > 0:
  27. for i, value in enumerate(lines):
  28. lines[i] = value.lstrip(" ")
  29. schema["description"] = "\n".join(lines)
  30. def _attribute_value(attribute_type, attribute_value):
  31. if attribute_value in ("None", "np.finfo(float).eps"):
  32. return None
  33. if attribute_type in ("float32", "int32", "boolean", "string"):
  34. if attribute_value in ("'auto'", '"auto"') or attribute_type == "string":
  35. return attribute_value.strip("'").strip('"')
  36. if attribute_type == "float32":
  37. return float(attribute_value)
  38. if attribute_type == "int32":
  39. return int(attribute_value)
  40. if attribute_type == "boolean":
  41. if attribute_value in ("True", "False"):
  42. return attribute_value == "True"
  43. raise ValueError(f"Unknown boolean default value '{str(attribute_value)}'.")
  44. if attribute_type:
  45. raise ValueError("Unknown default type '" + attribute_type + "'.")
  46. return attribute_value.strip("'")
  47. def _find_attribute(schema, name):
  48. schema.setdefault("attributes", [])
  49. attribute = next((_ for _ in schema["attributes"] if _["name"] == name), None)
  50. if not attribute:
  51. attribute = { "name": name }
  52. schema["attributes"].append(attribute)
  53. return attribute
  54. def _update_attributes(schema, lines):
  55. doc_indent = " " if sys.version_info[:2] >= (3, 13) else " "
  56. while len(lines) > 0:
  57. line = lines.pop(0)
  58. match = re.match(r"\s*(\w*)\s*:\s*(.*)\s*", line)
  59. if not match:
  60. raise SyntaxError("Expected ':' in parameter.")
  61. name = match.group(1)
  62. line = match.group(2)
  63. attribute = _find_attribute(schema, name)
  64. match = re.match(r"(.*),\s*default=(.*)\s*", line)
  65. default_value = None
  66. if match:
  67. line = match.group(1)
  68. default_value = match.group(2)
  69. attribute_types = {
  70. "float": "float32",
  71. "boolean": "boolean",
  72. "bool": "boolean",
  73. "str": "string",
  74. "string": "string",
  75. "int": "int32",
  76. "integer": "int32"
  77. }
  78. attribute_type = attribute_types.get(line, None)
  79. if default_value:
  80. attribute["default"] = _attribute_value(attribute_type, default_value)
  81. description = []
  82. while len(lines) > 0:
  83. if lines[0].strip() != "" and not lines[0].startswith(doc_indent):
  84. break
  85. line = lines.pop(0).lstrip(" ")
  86. description.append(line)
  87. attribute["description"] = "\n".join(description)
  88. def _metadata():
  89. root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
  90. json_file = os.path.join(root_dir, "source", "sklearn-metadata.json")
  91. with open(json_file, encoding="utf-8") as file:
  92. json_root = json.loads(file.read())
  93. for schema in json_root:
  94. name = schema["name"]
  95. skip_modules = [
  96. "lightgbm.",
  97. "sklearn.svm.classes",
  98. "sklearn.ensemble.forest.",
  99. "sklearn.ensemble.weight_boosting.",
  100. "sklearn.neural_network.multilayer_perceptron.",
  101. "sklearn.tree.tree."
  102. ]
  103. if not any(name.startswith(module) for module in skip_modules):
  104. class_definition = pydoc.locate(name)
  105. if not class_definition:
  106. raise KeyError("'" + name + "' not found.")
  107. docstring = class_definition.__doc__
  108. if not docstring:
  109. raise Exception("'" + name + "' missing __doc__.")
  110. headers = _split_docstring(docstring)
  111. if "" in headers:
  112. _update_description(schema, headers[""])
  113. if "Parameters" in headers:
  114. _update_attributes(schema, headers["Parameters"])
  115. with open(json_file, "w", encoding="utf-8") as file:
  116. file.write(json.dumps(json_root, sort_keys=False, indent=2))
def main():
    """Script entry point; runs the metadata update by calling ``_metadata()``."""
    _metadata()
# Run the metadata update only when this file is executed as a script.
if __name__ == "__main__":
    main()