keras-metadata.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. from __future__ import print_function
  4. import io
  5. import json
  6. import pydoc
  7. import re
  8. import sys
  9. def count_leading_spaces(s):
  10. ws = re.search(r'\S', s)
  11. if ws:
  12. return ws.start()
  13. else:
  14. return 0
  15. def process_docstring(docstring):
  16. # First, extract code blocks and process them.
  17. code_blocks = []
  18. if '```' in docstring:
  19. tmp = docstring[:]
  20. while '```' in tmp:
  21. tmp = tmp[tmp.find('```'):]
  22. index = tmp[3:].find('```') + 6
  23. snippet = tmp[:index]
  24. # Place marker in docstring for later reinjection.
  25. docstring = docstring.replace(
  26. snippet, '$CODE_BLOCK_%d' % len(code_blocks))
  27. snippet_lines = snippet.split('\n')
  28. # Remove leading spaces.
  29. num_leading_spaces = snippet_lines[-1].find('`')
  30. snippet_lines = ([snippet_lines[0]] +
  31. [line[num_leading_spaces:]
  32. for line in snippet_lines[1:]])
  33. # Most code snippets have 3 or 4 more leading spaces
  34. # on inner lines, but not all. Remove them.
  35. inner_lines = snippet_lines[1:-1]
  36. leading_spaces = None
  37. for line in inner_lines:
  38. if not line or line[0] == '\n':
  39. continue
  40. spaces = count_leading_spaces(line)
  41. if leading_spaces is None:
  42. leading_spaces = spaces
  43. if spaces < leading_spaces:
  44. leading_spaces = spaces
  45. if leading_spaces:
  46. snippet_lines = ([snippet_lines[0]] +
  47. [line[leading_spaces:]
  48. for line in snippet_lines[1:-1]] +
  49. [snippet_lines[-1]])
  50. snippet = '\n'.join(snippet_lines)
  51. code_blocks.append(snippet)
  52. tmp = tmp[index:]
  53. # Format docstring section titles.
  54. docstring = re.sub(r'\n(\s+)# (.*)\n',
  55. r'\n\1__\2__\n\n',
  56. docstring)
  57. # Format docstring lists.
  58. docstring = re.sub(r' ([^\s\\\(]+):(.*)\n',
  59. r' - __\1__:\2\n',
  60. docstring)
  61. # Strip all leading spaces.
  62. lines = docstring.split('\n')
  63. docstring = '\n'.join([line.lstrip(' ') for line in lines])
  64. # Reinject code blocks.
  65. for i, code_block in enumerate(code_blocks):
  66. docstring = docstring.replace(
  67. '$CODE_BLOCK_%d' % i, code_block)
  68. return docstring
  69. def split_docstring(docstring):
  70. headers = {}
  71. current_header = ''
  72. current_lines = []
  73. lines = docstring.split('\n')
  74. for line in lines:
  75. if line.startswith('__') and line.endswith('__'):
  76. headers[current_header] = current_lines
  77. current_lines = []
  78. current_header = line[2:-2]
  79. if current_header == 'Masking' or current_header.startswith('Note '):
  80. headline = '**' + current_header + '**'
  81. current_lines = headers['']
  82. current_header = ''
  83. current_lines.append(headline)
  84. else:
  85. current_lines.append(line)
  86. if len(current_lines) > 0:
  87. headers[current_header] = current_lines
  88. return headers
  89. def update_hyperlink(description):
  90. def replace_hyperlink(match):
  91. name = match.group(1)
  92. link = match.group(2)
  93. if link.endswith('.md'):
  94. if link.startswith('../'):
  95. link = link.replace('../', 'https://keras.io/').rstrip('.md')
  96. else:
  97. link = 'https://keras.io/layers/' + link.rstrip('.md')
  98. return '[' + name + '](' + link + ')'
  99. return match.group(0)
  100. return re.sub(r'\[(.*?)\]\((.*?)\)', replace_hyperlink, description)
  101. def update_argument(schema, name, lines):
  102. attribute = None
  103. if not 'attributes' in schema:
  104. schema['attributes'] = []
  105. for current_attribute in schema['attributes']:
  106. if 'name' in current_attribute and current_attribute['name'] == name:
  107. attribute = current_attribute
  108. break
  109. if not attribute:
  110. attribute = {}
  111. attribute['name'] = name
  112. schema['attributes'].append(attribute)
  113. description = '\n'.join(lines)
  114. description = update_hyperlink(description)
  115. attribute['description'] = description
  116. def update_arguments(schema, lines):
  117. argument_name = None
  118. argument_lines = []
  119. for line in lines:
  120. if line.startswith('- __'):
  121. line = line.lstrip('- ')
  122. colon = line.index(':')
  123. if colon > 0:
  124. name = line[0:colon]
  125. line = line[colon+1:].lstrip(' ')
  126. if name.startswith('__') and name.endswith('__'):
  127. if argument_name:
  128. update_argument(schema, argument_name, argument_lines)
  129. argument_name = name[2:-2]
  130. argument_lines = []
  131. if argument_name:
  132. argument_lines.append(line)
  133. if argument_name:
  134. update_argument(schema, argument_name, argument_lines)
  135. return
  136. def update_examples(schema, lines):
  137. if 'examples' in schema:
  138. del schema['examples']
  139. summary_lines = []
  140. code_lines = None
  141. for line in lines:
  142. if line.startswith('```'):
  143. if code_lines != None:
  144. example = {}
  145. example['code'] = '\n'.join(code_lines)
  146. if len(summary_lines) > 0:
  147. example['summary'] = '\n'.join(summary_lines)
  148. if not 'examples' in schema:
  149. schema['examples'] = []
  150. schema['examples'].append(example)
  151. summary_lines = []
  152. code_lines = None
  153. else:
  154. code_lines = [ ]
  155. else:
  156. if code_lines != None:
  157. code_lines.append(line)
  158. elif line != '':
  159. summary_lines.append(line)
  160. def update_references(schema, lines):
  161. if 'references' in schema:
  162. del schema['references']
  163. for line in lines:
  164. if line != '':
  165. line = line.lstrip('- ')
  166. if not 'references' in schema:
  167. schema['references'] = []
  168. schema['references'].append({ 'description': line })
  169. def update_input(schema, description):
  170. entry = None
  171. if 'inputs' in schema:
  172. for current_input in schema['inputs']:
  173. if current_input['name'] == 'input':
  174. entry = current_input
  175. break
  176. else:
  177. entry = {}
  178. entry['name'] = 'input'
  179. schema['inputs'] = []
  180. schema['inputs'].append(entry)
  181. if entry:
  182. entry['description'] = description
  183. def update_output(schema, description):
  184. entry = None
  185. if 'outputs' in schema:
  186. for current_output in schema['outputs']:
  187. if current_output['name'] == 'output':
  188. entry = current_output
  189. break
  190. else:
  191. entry = {}
  192. entry['name'] = 'output'
  193. schema['outputs'] = []
  194. schema['outputs'].append(entry)
  195. if entry:
  196. entry['description'] = description
  197. json_file = '../src/keras-metadata.json'
  198. json_data = open(json_file).read()
  199. json_root = json.loads(json_data)
  200. for entry in json_root:
  201. name = entry['name']
  202. schema = entry['schema']
  203. if 'package' in schema:
  204. class_name = schema['package'] + '.' + name
  205. class_definition = pydoc.locate(class_name)
  206. if not class_definition:
  207. raise Exception('\'' + class_name + '\' not found.')
  208. docstring = class_definition.__doc__
  209. if not docstring:
  210. raise Exception('\'' + class_name + '\' missing __doc__.')
  211. docstring = process_docstring(docstring)
  212. headers = split_docstring(docstring)
  213. if '' in headers:
  214. schema['description'] = '\n'.join(headers[''])
  215. del headers['']
  216. if 'Arguments' in headers:
  217. update_arguments(schema, headers['Arguments'])
  218. del headers['Arguments']
  219. if 'Input shape' in headers:
  220. update_input(schema, '\n'.join(headers['Input shape']))
  221. del headers['Input shape']
  222. if 'Output shape' in headers:
  223. update_output(schema, '\n'.join(headers['Output shape']))
  224. del headers['Output shape']
  225. if 'Examples' in headers:
  226. update_examples(schema, headers['Examples'])
  227. del headers['Examples']
  228. if 'Example' in headers:
  229. update_examples(schema, headers['Example'])
  230. del headers['Example']
  231. if 'References' in headers:
  232. update_references(schema, headers['References'])
  233. del headers['References']
  234. if 'Raises' in headers:
  235. del headers['Raises']
  236. if len(headers) > 0:
  237. raise Exception('\'' + class_name + '.__doc__\' contains unprocessed headers.')
  238. with io.open(json_file, 'w', newline='') as fout:
  239. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  240. for line in json_data.splitlines():
  241. line = line.rstrip()
  242. if sys.version_info[0] < 3:
  243. line = unicode(line)
  244. fout.write(line)
  245. fout.write('\n')