sklearn_script.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. ''' scikit-learn metadata script '''
  2. import io
  3. import json
  4. import os
  5. import pydoc
  6. import re
  7. import sys
  8. def _split_docstring(value):
  9. headers = {}
  10. current_header = ''
  11. current_lines = []
  12. lines = value.split('\n')
  13. index = 0
  14. while index < len(lines):
  15. if index + 1 < len(lines) and len(lines[index + 1].strip(' ')) > 0 and len(lines[index + 1].strip(' ').strip('-')) == 0:
  16. headers[current_header] = current_lines
  17. current_header = lines[index].strip(' ')
  18. current_lines = []
  19. index = index + 1
  20. else:
  21. current_lines.append(lines[index])
  22. index = index + 1
  23. headers[current_header] = current_lines
  24. return headers
  25. def _update_description(schema, lines):
  26. if len(''.join(lines).strip(' ')) > 0:
  27. for i, value in enumerate(lines):
  28. lines[i] = value.lstrip(' ')
  29. schema['description'] = '\n'.join(lines)
  30. def _update_attribute(schema, name, description, attribute_type, optional, default_value):
  31. attribute = None
  32. if not 'attributes' in schema:
  33. schema['attributes'] = []
  34. for current_attribute in schema['attributes']:
  35. if 'name' in current_attribute and current_attribute['name'] == name:
  36. attribute = current_attribute
  37. break
  38. if not attribute:
  39. attribute = {}
  40. attribute['name'] = name
  41. schema['attributes'].append(attribute)
  42. attribute['description'] = description
  43. if attribute_type:
  44. attribute['type'] = attribute_type
  45. if optional:
  46. attribute['optional'] = True
  47. if default_value:
  48. if attribute_type == 'float32':
  49. if default_value == 'None':
  50. attribute['default'] = None
  51. elif default_value != "'auto'":
  52. attribute['default'] = float(default_value)
  53. else:
  54. attribute['default'] = default_value.strip("'").strip('"')
  55. elif attribute_type == 'int32':
  56. if default_value == 'None':
  57. attribute['default'] = None
  58. elif default_value in ("'auto'", '"auto"'):
  59. attribute['default'] = default_value.strip("'").strip('"')
  60. else:
  61. attribute['default'] = int(default_value)
  62. elif attribute_type == 'string':
  63. attribute['default'] = default_value.strip("'").strip('"')
  64. elif attribute_type == 'boolean':
  65. if default_value == 'True':
  66. attribute['default'] = True
  67. elif default_value == 'False':
  68. attribute['default'] = False
  69. elif default_value == "'auto'":
  70. attribute['default'] = default_value.strip("'").strip('"')
  71. else:
  72. raise Exception("Unknown boolean default value '" + str(default_value) + "'.")
  73. else:
  74. if attribute_type:
  75. raise Exception("Unknown default type '" + attribute_type + "'.")
  76. if default_value == 'None':
  77. attribute['default'] = None
  78. else:
  79. attribute['default'] = default_value.strip("'")
  80. def _update_attributes(schema, lines):
  81. i = 0
  82. while i < len(lines):
  83. line = lines[i]
  84. line = re.sub(r',\s+', ', ', line)
  85. if line.endswith('.'):
  86. line = line[0:-1]
  87. colon = line.find(':')
  88. if colon == -1:
  89. raise Exception("Expected ':' in parameter.")
  90. name = line[0:colon].strip(' ')
  91. line = line[colon + 1:].strip(' ')
  92. attribute_type = None
  93. type_map = {
  94. 'float': 'float32',
  95. 'boolean': 'boolean',
  96. 'bool': 'boolean',
  97. 'string': 'string',
  98. 'int': 'int32',
  99. 'integer': 'int32'
  100. }
  101. skip_map = {
  102. "'sigmoid' or 'isotonic'",
  103. 'instance BaseEstimator',
  104. 'callable or None (default)',
  105. 'str or callable',
  106. "string {'english'}, list, or None (default)",
  107. 'tuple (min_n, max_n)',
  108. "string, {'word', 'char', 'char_wb'} or callable",
  109. "{'word', 'char'} or callable",
  110. "string, {'word', 'char'} or callable",
  111. 'int, float, None or string',
  112. "int, float, None or str",
  113. "int or None, optional (default=None)",
  114. "'l1', 'l2' or None, optional",
  115. "{'strict', 'ignore', 'replace'} (default='strict')",
  116. "{'ascii', 'unicode', None} (default=None)",
  117. "string {'english'}, list, or None (default=None)",
  118. "tuple (min_n, max_n) (default=(1, 1))",
  119. "float in range [0.0, 1.0] or int (default=1.0)",
  120. "float in range [0.0, 1.0] or int (default=1)",
  121. "'l1', 'l2' or None, optional (default='l2')",
  122. "str {'auto', 'full', 'arpack', 'randomized'}",
  123. "str {'filename', 'file', 'content'}",
  124. "str, {'word', 'char', 'char_wb'} or callable",
  125. "str {'english'}, list, or None (default=None)",
  126. "{'scale', 'auto'} or float, optional (default='scale')",
  127. "{'word', 'char', 'char_wb'} or callable, default='word'",
  128. "{'scale', 'auto'} or float, default='scale'",
  129. "{'uniform', 'distance'} or callable, default='uniform'",
  130. "int, RandomState instance or None (default)",
  131. "list of (string, transformer) tuples",
  132. "list of tuples",
  133. "{'drop', 'passthrough'} or estimator, default='drop'",
  134. "'auto' or a list of array-like, default='auto'",
  135. "callable",
  136. "int or \"all\", optional, default=10",
  137. "number, string, np.nan (default) or None",
  138. "estimator object",
  139. "dict or list of dictionaries",
  140. "int, or str, default=n_jobs",
  141. "'raise' or numeric, default=np.nan",
  142. "'auto' or float, default=None",
  143. "float, default=np.finfo(float).eps",
  144. "int, float, str, np.nan or None, default=np.nan",
  145. "list of (str, transformer) tuples",
  146. "int, float, str, np.nan, None or pandas.NA, default=np.nan",
  147. "{'first', 'if_binary'} or an array-like of shape (n_features,), default=None",
  148. "{'first', 'if_binary'} or a array-like of shape (n_features,), default=None",
  149. "{'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, default='rbf'",
  150. "estimator instance",
  151. "{'ascii', 'unicode'} or callable, default=None",
  152. "{'l1', 'l2'} or None, default='l2'"
  153. }
  154. if line == 'str':
  155. line = 'string'
  156. if line in skip_map:
  157. line = ''
  158. elif line.startswith('{'):
  159. if line.endswith('}'):
  160. line = ''
  161. else:
  162. end = line.find('},')
  163. if end == -1:
  164. raise Exception("Expected '}' in parameter.")
  165. # attribute_type = line[0:end + 1]
  166. line = line[end + 2:].strip(' ')
  167. elif line.startswith("'"):
  168. while line.startswith("'"):
  169. end = line.find("',")
  170. if end == -1:
  171. raise Exception("Expected \' in parameter.")
  172. line = line[end + 2:].strip(' ')
  173. elif line in type_map:
  174. attribute_type = line
  175. line = ''
  176. elif line.startswith('int, RandomState instance or None,'):
  177. line = line[len('int, RandomState instance or None,'):]
  178. elif line.startswith('int, or str, '):
  179. line = line[len('int, or str, '):]
  180. elif line.find('|') != -1:
  181. line = ''
  182. else:
  183. space = line.find(' {')
  184. if space != -1 and line[0:space] in type_map and line[space:].find('}') != -1:
  185. attribute_type = line[0:space]
  186. end = line[space:].find('}')
  187. line = line[space+end+1:]
  188. else:
  189. comma = line.find(',')
  190. if comma == -1:
  191. comma = line.find(' (')
  192. if comma == -1:
  193. raise Exception("Expected ',' in parameter.")
  194. attribute_type = line[0:comma]
  195. line = line[comma + 1:].strip(' ')
  196. attribute_type = type_map.get(attribute_type, None)
  197. # elif type == "{dict, 'balanced'}":
  198. # v = 'map'
  199. # else:
  200. # raise Exception("Unknown attribute type '" + attribute_type + "'.")
  201. optional = False
  202. default = None
  203. while len(line.strip(' ')) > 0:
  204. line = line.strip(' ')
  205. if line.startswith('optional ') or line.startswith('optional,'):
  206. optional = True
  207. line = line[9:]
  208. elif line.startswith('optional'):
  209. optional = True
  210. line = ''
  211. elif line.startswith('('):
  212. close = line.index(')')
  213. if close == -1:
  214. raise Exception("Expected ')' in parameter.")
  215. line = line[1:close]
  216. elif line.endswith(' by default'):
  217. default = line[0:-11]
  218. line = ''
  219. elif line.startswith('default =') or line.startswith('default :'):
  220. default = line[9:].strip(' ')
  221. line = ''
  222. elif line.startswith('default ') or line.startswith('default=') or line.startswith('default:'):
  223. default = line[8:].strip(' ')
  224. line = ''
  225. else:
  226. comma = line.index(',')
  227. if comma == -1:
  228. raise Exception("Expected ',' in parameter.")
  229. line = line[comma+1:]
  230. i = i + 1
  231. attribute_lines = []
  232. while i < len(lines) and (len(lines[i].strip(' ')) == 0 or lines[i].startswith(' ')):
  233. attribute_lines.append(lines[i].lstrip(' '))
  234. i = i + 1
  235. description = '\n'.join(attribute_lines)
  236. _update_attribute(schema, name, description, attribute_type, optional, default)
  237. def _metadata():
  238. json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
  239. with open(json_file, 'r', encoding='utf-8') as file:
  240. json_root = json.loads(file.read())
  241. for schema in json_root:
  242. name = schema['name']
  243. skip_modules = [
  244. 'lightgbm.',
  245. 'sklearn.svm.classes',
  246. 'sklearn.ensemble.forest.',
  247. 'sklearn.ensemble.weight_boosting.',
  248. 'sklearn.neural_network.multilayer_perceptron.',
  249. 'sklearn.tree.tree.'
  250. ]
  251. if not any(name.startswith(module) for module in skip_modules):
  252. class_definition = pydoc.locate(name)
  253. if not class_definition:
  254. raise Exception('\'' + name + '\' not found.')
  255. docstring = class_definition.__doc__
  256. if not docstring:
  257. raise Exception('\'' + name + '\' missing __doc__.')
  258. headers = _split_docstring(docstring)
  259. if '' in headers:
  260. _update_description(schema, headers[''])
  261. if 'Parameters' in headers:
  262. _update_attributes(schema, headers['Parameters'])
  263. with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
  264. json_data = json.dumps(json_root, sort_keys=False, indent=2)
  265. for line in json_data.splitlines():
  266. fout.write(line.rstrip())
  267. fout.write('\n')
  268. def main(): # pylint: disable=missing-function-docstring
  269. command_table = { 'metadata': _metadata }
  270. command = sys.argv[1]
  271. command_table[command]()
  272. if __name__ == '__main__':
  273. main()