sklearn_metadata.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. ''' scikit-learn metadata script '''
  2. import json
  3. import os
  4. import pydoc
  5. import re
  6. def _split_docstring(value):
  7. headers = {}
  8. current_header = ''
  9. current_lines = []
  10. lines = value.split('\n')
  11. index = 0
  12. while index < len(lines):
  13. if index + 1 < len(lines) and len(lines[index + 1].strip(' ')) > 0 and \
  14. len(lines[index + 1].strip(' ').strip('-')) == 0:
  15. headers[current_header] = current_lines
  16. current_header = lines[index].strip(' ')
  17. current_lines = []
  18. index = index + 1
  19. else:
  20. current_lines.append(lines[index])
  21. index = index + 1
  22. headers[current_header] = current_lines
  23. return headers
  24. def _update_description(schema, lines):
  25. if len(''.join(lines).strip(' ')) > 0:
  26. for i, value in enumerate(lines):
  27. lines[i] = value.lstrip(' ')
  28. schema['description'] = '\n'.join(lines)
  29. def _attribute_value(attribute_type, attribute_value):
  30. if attribute_value in ('None', 'np.finfo(float).eps'):
  31. return None
  32. if attribute_type in ('float32', 'int32', 'boolean', 'string'):
  33. if attribute_value in ("'auto'", '"auto"') or attribute_type == 'string':
  34. return attribute_value.strip("'").strip('"')
  35. if attribute_type == 'float32':
  36. return float(attribute_value)
  37. if attribute_type == 'int32':
  38. return int(attribute_value)
  39. if attribute_type == 'boolean':
  40. if attribute_value in ('True', 'False'):
  41. return attribute_value == 'True'
  42. raise ValueError("Unknown boolean default value '" + str(attribute_value) + "'.")
  43. if attribute_type:
  44. raise ValueError("Unknown default type '" + attribute_type + "'.")
  45. return attribute_value.strip("'")
  46. def _find_attribute(schema, name):
  47. schema.setdefault('attributes', [])
  48. attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
  49. if not attribute:
  50. attribute = { 'name': name }
  51. schema['attributes'].append(attribute)
  52. return attribute
  53. def _update_attributes(schema, lines):
  54. while len(lines) > 0:
  55. line = lines.pop(0)
  56. match = re.match(r'\s*(\w*)\s*:\s*(.*)\s*', line)
  57. if not match:
  58. raise SyntaxError("Expected ':' in parameter.")
  59. name = match.group(1)
  60. line = match.group(2)
  61. attribute = _find_attribute(schema, name)
  62. match = re.match(r'(.*),\s*default=(.*)\s*', line)
  63. default_value = None
  64. if match:
  65. line = match.group(1)
  66. default_value = match.group(2)
  67. attribute_types = {
  68. 'float': 'float32',
  69. 'boolean': 'boolean',
  70. 'bool': 'boolean',
  71. 'str': 'string',
  72. 'string': 'string',
  73. 'int': 'int32',
  74. 'integer': 'int32'
  75. }
  76. attribute_type = attribute_types.get(line, None)
  77. if default_value:
  78. attribute['default'] = _attribute_value(attribute_type, default_value)
  79. description = []
  80. while len(lines) > 0 and (len(lines[0].strip(' ')) == 0 or lines[0].startswith(' ')):
  81. line = lines.pop(0).lstrip(' ')
  82. description.append(line)
  83. attribute['description'] = '\n'.join(description)
  84. def _metadata():
  85. json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
  86. with open(json_file, 'r', encoding='utf-8') as file:
  87. json_root = json.loads(file.read())
  88. for schema in json_root:
  89. name = schema['name']
  90. skip_modules = [
  91. 'lightgbm.',
  92. 'sklearn.svm.classes',
  93. 'sklearn.ensemble.forest.',
  94. 'sklearn.ensemble.weight_boosting.',
  95. 'sklearn.neural_network.multilayer_perceptron.',
  96. 'sklearn.tree.tree.'
  97. ]
  98. if not any(name.startswith(module) for module in skip_modules):
  99. class_definition = pydoc.locate(name)
  100. if not class_definition:
  101. raise KeyError('\'' + name + '\' not found.')
  102. docstring = class_definition.__doc__
  103. if not docstring:
  104. raise Exception('\'' + name + '\' missing __doc__.') # pylint: disable=broad-exception-raised
  105. headers = _split_docstring(docstring)
  106. if '' in headers:
  107. _update_description(schema, headers[''])
  108. if 'Parameters' in headers:
  109. _update_attributes(schema, headers['Parameters'])
  110. with open(json_file, 'w', encoding='utf-8') as file:
  111. file.write(json.dumps(json_root, sort_keys=False, indent=2))
  112. def main(): # pylint: disable=missing-function-docstring
  113. _metadata()
  114. if __name__ == '__main__':
  115. main()