sklearn-script.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from __future__ import unicode_literals
  2. from __future__ import print_function
  3. import io
  4. import json
  5. import os
  6. import pydoc
  7. import re
  8. import sys
  9. json_file = os.path.join(os.path.dirname(__file__), '../source/sklearn-metadata.json')
  10. json_data = open(json_file).read()
  11. json_root = json.loads(json_data)
  12. def split_docstring(docstring):
  13. headers = {}
  14. current_header = ''
  15. current_lines = []
  16. lines = docstring.split('\n')
  17. index = 0
  18. while index < len(lines):
  19. if index + 1 < len(lines) and len(lines[index + 1].strip(' ')) > 0 and len(lines[index + 1].strip(' ').strip('-')) == 0:
  20. headers[current_header] = current_lines
  21. current_header = lines[index].strip(' ')
  22. current_lines = []
  23. index = index + 1
  24. else:
  25. current_lines.append(lines[index])
  26. index = index + 1
  27. headers[current_header] = current_lines
  28. return headers
  29. def update_description(schema, lines):
  30. if len(''.join(lines).strip(' ')) > 0:
  31. for i in range(0, len(lines)):
  32. lines[i] = lines[i].lstrip(' ')
  33. schema['description'] = '\n'.join(lines)
  34. def update_attribute(schema, name, description, attribute_type, option, default_value):
  35. attribute = None
  36. if not 'attributes' in schema:
  37. schema['attributes'] = []
  38. for current_attribute in schema['attributes']:
  39. if 'name' in current_attribute and current_attribute['name'] == name:
  40. attribute = current_attribute
  41. break
  42. if not attribute:
  43. attribute = {}
  44. attribute['name'] = name
  45. schema['attributes'].append(attribute)
  46. attribute['description'] = description
  47. if attribute_type:
  48. attribute['type'] = attribute_type
  49. if option:
  50. attribute['option'] = option
  51. if default_value:
  52. if attribute_type == 'float32':
  53. if default_value == 'None':
  54. attribute['default'] = None
  55. elif default_value != "'auto'":
  56. attribute['default'] = float(default_value)
  57. else:
  58. attribute['default'] = default_value.strip("'").strip('"')
  59. elif attribute_type == 'int32':
  60. if default_value == 'None':
  61. attribute['default'] = None
  62. elif default_value == "'auto'" or default_value == '"auto"':
  63. attribute['default'] = default_value.strip("'").strip('"')
  64. else:
  65. attribute['default'] = int(default_value)
  66. elif attribute_type == 'string':
  67. attribute['default'] = default_value.strip("'").strip('"')
  68. elif attribute_type == 'boolean':
  69. if default_value == 'True':
  70. attribute['default'] = True
  71. elif default_value == 'False':
  72. attribute['default'] = False
  73. elif default_value == "'auto'":
  74. attribute['default'] = default_value.strip("'").strip('"')
  75. else:
  76. raise Exception("Unknown boolean default value '" + str(default_value) + "'.")
  77. else:
  78. if attribute_type:
  79. raise Exception("Unknown default type '" + attribute_type + "'.")
  80. else:
  81. if default_value == 'None':
  82. attribute['default'] = None
  83. else:
  84. attribute['default'] = default_value.strip("'")
  85. def update_attributes(schema, lines):
  86. index = 0
  87. while index < len(lines):
  88. line = lines[index]
  89. if line.endswith('.'):
  90. line = line[0:-1]
  91. colon = line.find(':')
  92. if colon == -1:
  93. raise Exception("Expected ':' in parameter.")
  94. name = line[0:colon].strip(' ')
  95. line = line[colon + 1:].strip(' ')
  96. attribute_type = None
  97. type_map = { 'float': 'float32', 'boolean': 'boolean', 'bool': 'boolean', 'string': 'string', 'int': 'int32', 'integer': 'int32' }
  98. skip_map = {
  99. "'sigmoid' or 'isotonic'",
  100. 'instance BaseEstimator',
  101. 'callable or None (default)',
  102. 'str or callable',
  103. "string {'english'}, list, or None (default)",
  104. 'tuple (min_n, max_n)',
  105. "string, {'word', 'char', 'char_wb'} or callable",
  106. "{'word', 'char'} or callable",
  107. "string, {'word', 'char'} or callable",
  108. 'int, float, None or string',
  109. "int, float, None or str",
  110. "int or None, optional (default=None)",
  111. "'l1', 'l2' or None, optional",
  112. "{'strict', 'ignore', 'replace'} (default='strict')",
  113. "{'ascii', 'unicode', None} (default=None)",
  114. "string {'english'}, list, or None (default=None)",
  115. "tuple (min_n, max_n) (default=(1, 1))",
  116. "float in range [0.0, 1.0] or int (default=1.0)",
  117. "float in range [0.0, 1.0] or int (default=1)",
  118. "'l1', 'l2' or None, optional (default='l2')",
  119. "{'scale', 'auto'} or float, optional (default='scale')",
  120. "str {'auto', 'full', 'arpack', 'randomized'}",
  121. "str {'filename', 'file', 'content'}",
  122. "str, {'word', 'char', 'char_wb'} or callable",
  123. "str {'english'}, list, or None (default=None)",
  124. "{'scale', 'auto'} or float, optional (default='scale')",
  125. "{'word', 'char', 'char_wb'} or callable, default='word'",
  126. "{'scale', 'auto'} or float, default='scale'",
  127. "{'uniform', 'distance'} or callable, default='uniform'",
  128. "int, RandomState instance or None (default)",
  129. "list of (string, transformer) tuples",
  130. "list of tuples",
  131. "{'drop', 'passthrough'} or estimator, default='drop'",
  132. "'auto' or a list of array-like, default='auto'",
  133. "{'first', 'if_binary'} or a array-like of shape (n_features,), default=None",
  134. "callable",
  135. "int or \"all\", optional, default=10",
  136. "number, string, np.nan (default) or None",
  137. "estimator object",
  138. "dict or list of dictionaries",
  139. "int, or str, default=n_jobs",
  140. "'raise' or numeric, default=np.nan",
  141. "'auto' or float, default=None",
  142. "float, default=np.finfo(float).eps",
  143. "int, float, str, np.nan or None, default=np.nan"
  144. }
  145. if line == 'str':
  146. line = 'string'
  147. if line in skip_map:
  148. line = ''
  149. elif line.startswith('{'):
  150. if line.endswith('}'):
  151. line = ''
  152. else:
  153. end = line.find('},')
  154. if end == -1:
  155. raise Exception("Expected '}' in parameter.")
  156. # attribute_type = line[0:end + 1]
  157. line = line[end + 2:].strip(' ')
  158. elif line.startswith("'"):
  159. while line.startswith("'"):
  160. end = line.find("',")
  161. if end == -1:
  162. raise Exception("Expected \' in parameter.")
  163. line = line[end + 2:].strip(' ')
  164. elif line in type_map:
  165. attribute_type = line
  166. line = ''
  167. elif line.startswith('int, RandomState instance or None,'):
  168. line = line[len('int, RandomState instance or None,'):]
  169. elif line.find('|') != -1:
  170. line = ''
  171. else:
  172. space = line.find(' {')
  173. if space != -1 and line[0:space] in type_map and line[space:].find('}') != -1:
  174. attribute_type = line[0:space]
  175. end = line[space:].find('}')
  176. line = line[space+end+1:]
  177. else:
  178. comma = line.find(',')
  179. if comma == -1:
  180. comma = line.find(' (')
  181. if comma == -1:
  182. raise Exception("Expected ',' in parameter.")
  183. attribute_type = line[0:comma]
  184. line = line[comma + 1:].strip(' ')
  185. if attribute_type in type_map:
  186. attribute_type = type_map[attribute_type]
  187. else:
  188. attribute_type = None
  189. # elif type == "{dict, 'balanced'}":
  190. # v = 'map'
  191. # else:
  192. # raise Exception("Unknown attribute type '" + attribute_type + "'.")
  193. option = None
  194. default = None
  195. while len(line.strip(' ')) > 0:
  196. line = line.strip(' ')
  197. if line.startswith('optional ') or line.startswith('optional,'):
  198. option = 'optional'
  199. line = line[9:]
  200. elif line.startswith('optional'):
  201. option = 'optional'
  202. line = ''
  203. elif line.startswith('('):
  204. close = line.index(')')
  205. if (close == -1):
  206. raise Exception("Expected ')' in parameter.")
  207. line = line[1:close]
  208. elif line.endswith(' by default'):
  209. default = line[0:-11]
  210. line = ''
  211. elif line.startswith('default =') or line.startswith('default :'):
  212. default = line[9:].strip(' ')
  213. line = ''
  214. elif line.startswith('default ') or line.startswith('default=') or line.startswith('default:'):
  215. default = line[8:].strip(' ')
  216. line = ''
  217. else:
  218. comma = line.index(',')
  219. if comma == -1:
  220. raise Exception("Expected ',' in parameter.")
  221. line = line[comma+1:]
  222. index = index + 1
  223. attribute_lines = []
  224. while index < len(lines) and (len(lines[index].strip(' ')) == 0 or lines[index].startswith(' ')):
  225. attribute_lines.append(lines[index].lstrip(' '))
  226. index = index + 1
  227. description = '\n'.join(attribute_lines)
  228. update_attribute(schema, name, description, attribute_type, option, default)
  229. for schema in json_root:
  230. name = schema['name']
  231. skip_modules = [
  232. 'lightgbm.',
  233. 'sklearn.svm.classes',
  234. 'sklearn.ensemble.forest.',
  235. 'sklearn.ensemble.weight_boosting.',
  236. 'sklearn.neural_network.multilayer_perceptron.',
  237. 'sklearn.tree.tree.'
  238. ]
  239. if not any(name.startswith(module) for module in skip_modules):
  240. class_definition = pydoc.locate(name)
  241. if not class_definition:
  242. raise Exception('\'' + name + '\' not found.')
  243. docstring = class_definition.__doc__
  244. if not docstring:
  245. raise Exception('\'' + name + '\' missing __doc__.')
  246. headers = split_docstring(docstring)
  247. if '' in headers:
  248. update_description(schema, headers[''])
  249. if 'Parameters' in headers:
  250. update_attributes(schema, headers['Parameters'])
  251. with io.open(json_file, 'w', newline='') as fout:
  252. json_data = json.dumps(json_root, sort_keys=False, indent=2)
  253. for line in json_data.splitlines():
  254. fout.write(line.rstrip())
  255. fout.write('\n')