keras_metadata.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. ''' Keras metadata script '''
  2. import json
  3. import os
  4. import pydoc
  5. import re
  6. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  7. def _read(path):
  8. with open(path, 'r', encoding='utf-8') as file:
  9. return file.read()
  10. def _find_docstring(class_name):
  11. class_definition = pydoc.locate(class_name)
  12. if not class_definition:
  13. raise Exception('\'' + class_name + '\' not found.') # pylint: disable=broad-exception-raised
  14. if not class_definition.__doc__:
  15. raise Exception('\'' + class_name + '\' missing __doc__.') # pylint: disable=broad-exception-raised
  16. return class_definition.__doc__
  17. def _parse_docstring(docstring):
  18. headers = []
  19. lines = docstring.splitlines()
  20. indentation = min(filter(lambda s: s > 0, map(lambda s: len(s) - len(s.lstrip()), lines)))
  21. lines = list((s[indentation:] if len(s) > len(s.lstrip()) else s) for s in lines)
  22. docstring = '\n'.join(lines)
  23. labels = [
  24. 'Args', 'Arguments', 'Variables', 'Fields', 'Yields', 'Call arguments', 'Raises',
  25. 'Examples', 'Example', 'Usage', 'Input shape', 'Output shape', 'Returns', 'References'
  26. ]
  27. tag_re = re.compile('(?<=\n)(' + '|'.join(labels) + '):\n', re.MULTILINE)
  28. parts = tag_re.split(docstring)
  29. headers.append(('', parts.pop(0)))
  30. while len(parts) > 0:
  31. headers.append((parts.pop(0), parts.pop(0)))
  32. return headers
  33. def _parse_arguments(arguments):
  34. result = []
  35. item_re = re.compile(r'^ ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE)
  36. content = item_re.split(arguments)
  37. if content.pop(0) != '':
  38. raise Exception('') # pylint: disable=broad-exception-raised
  39. while len(content) > 0:
  40. result.append((content.pop(0), content.pop(0)))
  41. return result
  42. def _convert_code_blocks(description):
  43. lines = description.splitlines()
  44. output = []
  45. while len(lines) > 0:
  46. line = lines.pop(0)
  47. if line.startswith('>>>') and len(lines) > 0 and \
  48. (lines[0].startswith('>>>') or lines[0].startswith('...')):
  49. output.append('```')
  50. output.append(line)
  51. while len(lines) > 0 and lines[0] != '':
  52. output.append(lines.pop(0))
  53. output.append('```')
  54. else:
  55. output.append(line)
  56. return '\n'.join(output)
  57. def _remove_indentation(value):
  58. lines = value.splitlines()
  59. indentation = min(map(lambda s: len(s) - len(s.lstrip()), \
  60. filter(lambda s: len(s) > 0, lines)))
  61. lines = list((s[indentation:] if len(s) > 0 else s) for s in lines)
  62. return '\n'.join(lines).strip()
  63. def _update_argument(schema, name, description):
  64. if not 'attributes' in schema:
  65. schema['attributes'] = []
  66. attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
  67. if not attribute:
  68. attribute = {}
  69. attribute['name'] = name
  70. schema['attributes'].append(attribute)
  71. attribute['description'] = _remove_indentation(description)
  72. def _update_input(schema, description):
  73. if not 'inputs' in schema:
  74. schema['inputs'] = [ { 'name': 'input' } ]
  75. parameter = next((_ for _ in schema['inputs'] \
  76. if (_['name'] == 'input' or _['name'] == 'inputs')), None)
  77. if parameter:
  78. parameter['description'] = _remove_indentation(description)
  79. else:
  80. raise Exception('') # pylint: disable=broad-exception-raised
  81. def _update_output(schema, description):
  82. if not 'outputs' in schema:
  83. schema['outputs'] = [ { 'name': 'output' } ]
  84. parameter = next((param for param in schema['outputs'] if param['name'] == 'output'), None)
  85. if parameter:
  86. parameter['description'] = _remove_indentation(description)
  87. else:
  88. raise Exception('') # pylint: disable=broad-exception-raised
  89. def _update_examples(schema, value):
  90. if 'examples' in schema:
  91. del schema['examples']
  92. value = _convert_code_blocks(value)
  93. lines = value.splitlines()
  94. code = []
  95. summary = []
  96. while len(lines) > 0:
  97. line = lines.pop(0)
  98. if len(line) > 0:
  99. if line.startswith('```'):
  100. while len(lines) > 0:
  101. line = lines.pop(0)
  102. if line == '```':
  103. break
  104. code.append(line)
  105. else:
  106. summary.append(line)
  107. if len(code) > 0:
  108. example = {}
  109. if len(summary) > 0:
  110. example['summary'] = '\n'.join(summary)
  111. example['code'] = '\n'.join(code)
  112. if not 'examples' in schema:
  113. schema['examples'] = []
  114. schema['examples'].append(example)
  115. code = []
  116. summary = []
  117. def _update_references(schema, value):
  118. if 'references' in schema:
  119. del schema['references']
  120. references = []
  121. reference = ''
  122. lines = value.splitlines()
  123. for line in lines:
  124. if line.lstrip().startswith('- '):
  125. if len(reference) > 0:
  126. references.append(reference)
  127. reference = line.lstrip().lstrip('- ')
  128. else:
  129. if line.startswith(' '):
  130. line = line[2:]
  131. reference = ' '.join([ reference, line.strip() ])
  132. if len(reference) > 0:
  133. references.append(reference)
  134. for reference in references:
  135. if not 'references' in schema:
  136. schema['references'] = []
  137. schema['references'].append({ 'description': reference })
  138. def _update_headers(schema, docstring):
  139. headers = _parse_docstring(docstring)
  140. for header in headers:
  141. key = header[0]
  142. value = header[1]
  143. if key == '':
  144. description = _convert_code_blocks(value)
  145. schema['description'] = _remove_indentation(description)
  146. elif key in ('Args', 'Arguments'):
  147. arguments = _parse_arguments(value)
  148. for argument in arguments:
  149. _update_argument(schema, argument[0], argument[1])
  150. elif key == 'Input shape':
  151. _update_input(schema, value)
  152. elif key == 'Output shape':
  153. _update_output(schema, value)
  154. elif key in ('Example', 'Examples', 'Usage'):
  155. _update_examples(schema, value)
  156. elif key == 'References':
  157. _update_references(schema, value)
  158. elif key in ('Call arguments', 'Returns', 'Variables', 'Raises'):
  159. pass
  160. else:
  161. raise Exception('') # pylint: disable=broad-exception-raised
  162. def _metadata():
  163. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  164. json_path = os.path.join(root, 'source', 'keras-metadata.json')
  165. json_root = json.loads(_read(json_path))
  166. for schema in json_root:
  167. if 'module' in schema:
  168. class_name = schema['module'] + '.' + schema['name']
  169. docstring = _find_docstring(class_name)
  170. _update_headers(schema, docstring)
  171. with open(json_path, 'w', encoding='utf-8') as file:
  172. content = json.dumps(json_root, sort_keys=False, indent=2)
  173. for line in content.splitlines():
  174. file.write(line.rstrip() + '\n')
  175. def main(): # pylint: disable=missing-function-docstring
  176. _metadata()
  177. if __name__ == '__main__':
  178. main()