# sklearn_script.py (11 KB)
  1. ''' scikit-learn metadata script '''
  2. import io
  3. import json
  4. import os
  5. import pydoc
  6. import re
  7. import sys
  8. def _split_docstring(value):
  9. headers = {}
  10. current_header = ''
  11. current_lines = []
  12. lines = value.split('\n')
  13. index = 0
  14. while index < len(lines):
  15. if index + 1 < len(lines) and len(lines[index + 1].strip(' ')) > 0 and \
  16. len(lines[index + 1].strip(' ').strip('-')) == 0:
  17. headers[current_header] = current_lines
  18. current_header = lines[index].strip(' ')
  19. current_lines = []
  20. index = index + 1
  21. else:
  22. current_lines.append(lines[index])
  23. index = index + 1
  24. headers[current_header] = current_lines
  25. return headers
  26. def _update_description(schema, lines):
  27. if len(''.join(lines).strip(' ')) > 0:
  28. for i, value in enumerate(lines):
  29. lines[i] = value.lstrip(' ')
  30. schema['description'] = '\n'.join(lines)
  31. def _update_attribute(schema, name, description, attribute_type, optional, default_value):
  32. attribute = None
  33. if not 'attributes' in schema:
  34. schema['attributes'] = []
  35. for current_attribute in schema['attributes']:
  36. if 'name' in current_attribute and current_attribute['name'] == name:
  37. attribute = current_attribute
  38. break
  39. if not attribute:
  40. attribute = {}
  41. attribute['name'] = name
  42. schema['attributes'].append(attribute)
  43. attribute['description'] = description
  44. if attribute_type:
  45. attribute['type'] = attribute_type
  46. if optional:
  47. attribute['optional'] = True
  48. if default_value:
  49. if attribute_type == 'float32':
  50. if default_value == 'None':
  51. attribute['default'] = None
  52. elif default_value != "'auto'":
  53. attribute['default'] = float(default_value)
  54. else:
  55. attribute['default'] = default_value.strip("'").strip('"')
  56. elif attribute_type == 'int32':
  57. if default_value == 'None':
  58. attribute['default'] = None
  59. elif default_value in ("'auto'", '"auto"'):
  60. attribute['default'] = default_value.strip("'").strip('"')
  61. else:
  62. attribute['default'] = int(default_value)
  63. elif attribute_type == 'string':
  64. attribute['default'] = default_value.strip("'").strip('"')
  65. elif attribute_type == 'boolean':
  66. if default_value == 'True':
  67. attribute['default'] = True
  68. elif default_value == 'False':
  69. attribute['default'] = False
  70. elif default_value == "'auto'":
  71. attribute['default'] = default_value.strip("'").strip('"')
  72. else:
  73. raise Exception("Unknown boolean default value '" + str(default_value) + "'.")
  74. else:
  75. if attribute_type:
  76. raise Exception("Unknown default type '" + attribute_type + "'.")
  77. if default_value == 'None':
  78. attribute['default'] = None
  79. else:
  80. attribute['default'] = default_value.strip("'")
  81. def _update_attributes(schema, lines):
  82. i = 0
  83. while i < len(lines):
  84. line = lines[i]
  85. line = re.sub(r',\s+', ', ', line)
  86. if line.endswith('.'):
  87. line = line[0:-1]
  88. colon = line.find(':')
  89. if colon == -1:
  90. raise Exception("Expected ':' in parameter.")
  91. name = line[0:colon].strip(' ')
  92. line = line[colon + 1:].strip(' ')
  93. attribute_type = None
  94. type_map = {
  95. 'float': 'float32',
  96. 'boolean': 'boolean',
  97. 'bool': 'boolean',
  98. 'string': 'string',
  99. 'int': 'int32',
  100. 'integer': 'int32'
  101. }
  102. skip_map = {
  103. "'sigmoid' or 'isotonic'",
  104. 'instance BaseEstimator',
  105. 'callable or None (default)',
  106. 'str or callable',
  107. "string {'english'}, list, or None (default)",
  108. 'tuple (min_n, max_n)',
  109. "string, {'word', 'char', 'char_wb'} or callable",
  110. "{'word', 'char'} or callable",
  111. "string, {'word', 'char'} or callable",
  112. 'int, float, None or string',
  113. "int, float, None or str",
  114. "int or None, optional (default=None)",
  115. "'l1', 'l2' or None, optional",
  116. "{'strict', 'ignore', 'replace'} (default='strict')",
  117. "{'ascii', 'unicode', None} (default=None)",
  118. "string {'english'}, list, or None (default=None)",
  119. "tuple (min_n, max_n) (default=(1, 1))",
  120. "float in range [0.0, 1.0] or int (default=1.0)",
  121. "float in range [0.0, 1.0] or int (default=1)",
  122. "'l1', 'l2' or None, optional (default='l2')",
  123. "str {'auto', 'full', 'arpack', 'randomized'}",
  124. "str {'filename', 'file', 'content'}",
  125. "str, {'word', 'char', 'char_wb'} or callable",
  126. "str {'english'}, list, or None (default=None)",
  127. "{'scale', 'auto'} or float, optional (default='scale')",
  128. "{'word', 'char', 'char_wb'} or callable, default='word'",
  129. "{'scale', 'auto'} or float, default='scale'",
  130. "{'uniform', 'distance'} or callable, default='uniform'",
  131. "int, RandomState instance or None (default)",
  132. "list of (string, transformer) tuples",
  133. "list of tuples",
  134. "{'drop', 'passthrough'} or estimator, default='drop'",
  135. "'auto' or a list of array-like, default='auto'",
  136. "callable",
  137. "int or \"all\", optional, default=10",
  138. "number, string, np.nan (default) or None",
  139. "estimator object",
  140. "dict or list of dictionaries",
  141. "int, or str, default=n_jobs",
  142. "'raise' or numeric, default=np.nan",
  143. "'auto' or float, default=None",
  144. "float, default=np.finfo(float).eps",
  145. "int, float, str, np.nan or None, default=np.nan",
  146. "list of (str, transformer) tuples",
  147. "int, float, str, np.nan, None or pandas.NA, default=np.nan",
  148. "{'first', 'if_binary'} or an array-like of shape (n_features,), default=None",
  149. "{'first', 'if_binary'} or a array-like of shape (n_features,), default=None",
  150. "{'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, default='rbf'",
  151. "estimator instance",
  152. "{'ascii', 'unicode'} or callable, default=None",
  153. "{'l1', 'l2'} or None, default='l2'"
  154. }
  155. if line == 'str':
  156. line = 'string'
  157. if line in skip_map:
  158. line = ''
  159. elif line.startswith('{'):
  160. if line.endswith('}'):
  161. line = ''
  162. else:
  163. end = line.find('},')
  164. if end == -1:
  165. raise Exception("Expected '}' in parameter.")
  166. # attribute_type = line[0:end + 1]
  167. line = line[end + 2:].strip(' ')
  168. elif line.startswith("'"):
  169. while line.startswith("'"):
  170. end = line.find("',")
  171. if end == -1:
  172. raise Exception("Expected \' in parameter.")
  173. line = line[end + 2:].strip(' ')
  174. elif line in type_map:
  175. attribute_type = line
  176. line = ''
  177. elif line.startswith('int, RandomState instance or None,'):
  178. line = line[len('int, RandomState instance or None,'):]
  179. elif line.startswith('int, or str, '):
  180. line = line[len('int, or str, '):]
  181. elif line.find('|') != -1:
  182. line = ''
  183. else:
  184. space = line.find(' {')
  185. if space != -1 and line[0:space] in type_map and line[space:].find('}') != -1:
  186. attribute_type = line[0:space]
  187. end = line[space:].find('}')
  188. line = line[space+end+1:]
  189. else:
  190. comma = line.find(',')
  191. if comma == -1:
  192. comma = line.find(' (')
  193. if comma == -1:
  194. raise Exception("Expected ',' in parameter.")
  195. attribute_type = line[0:comma]
  196. line = line[comma + 1:].strip(' ')
  197. attribute_type = type_map.get(attribute_type, None)
  198. # elif type == "{dict, 'balanced'}":
  199. # v = 'map'
  200. # else:
  201. # raise Exception("Unknown attribute type '" + attribute_type + "'.")
  202. optional = False
  203. default = None
  204. while len(line.strip(' ')) > 0:
  205. line = line.strip(' ')
  206. if line.startswith('optional ') or line.startswith('optional,'):
  207. optional = True
  208. line = line[9:]
  209. elif line.startswith('optional'):
  210. optional = True
  211. line = ''
  212. elif line.startswith('('):
  213. close = line.index(')')
  214. if close == -1:
  215. raise Exception("Expected ')' in parameter.")
  216. line = line[1:close]
  217. elif line.endswith(' by default'):
  218. default = line[0:-11]
  219. line = ''
  220. elif line.startswith('default =') or line.startswith('default :'):
  221. default = line[9:].strip(' ')
  222. line = ''
  223. elif line.startswith('default ') or \
  224. line.startswith('default=') or line.startswith('default:'):
  225. default = line[8:].strip(' ')
  226. line = ''
  227. else:
  228. comma = line.index(',')
  229. if comma == -1:
  230. raise Exception("Expected ',' in parameter.")
  231. line = line[comma+1:]
  232. i = i + 1
  233. attribute_lines = []
  234. while i < len(lines) and (len(lines[i].strip(' ')) == 0 or lines[i].startswith(' ')):
  235. attribute_lines.append(lines[i].lstrip(' '))
  236. i = i + 1
  237. description = '\n'.join(attribute_lines)
  238. _update_attribute(schema, name, description, attribute_type, optional, default)
  239. def _metadata():
  240. json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
  241. with open(json_file, 'r', encoding='utf-8') as file:
  242. json_root = json.loads(file.read())
  243. for schema in json_root:
  244. name = schema['name']
  245. skip_modules = [
  246. 'lightgbm.',
  247. 'sklearn.svm.classes',
  248. 'sklearn.ensemble.forest.',
  249. 'sklearn.ensemble.weight_boosting.',
  250. 'sklearn.neural_network.multilayer_perceptron.',
  251. 'sklearn.tree.tree.'
  252. ]
  253. if not any(name.startswith(module) for module in skip_modules):
  254. class_definition = pydoc.locate(name)
  255. if not class_definition:
  256. raise Exception('\'' + name + '\' not found.')
  257. docstring = class_definition.__doc__
  258. if not docstring:
  259. raise Exception('\'' + name + '\' missing __doc__.')
  260. headers = _split_docstring(docstring)
  261. if '' in headers:
  262. _update_description(schema, headers[''])
  263. if 'Parameters' in headers:
  264. _update_attributes(schema, headers['Parameters'])
  265. with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
  266. json_data = json.dumps(json_root, sort_keys=False, indent=2)
  267. for line in json_data.splitlines():
  268. fout.write(line.rstrip())
  269. fout.write('\n')
  270. def main(): # pylint: disable=missing-function-docstring
  271. command_table = { 'metadata': _metadata }
  272. command = sys.argv[1]
  273. command_table[command]()
# Run the command dispatcher only when this file is executed as a script.
if __name__ == '__main__':
    main()