2
0

keras_metadata.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. ''' Keras metadata script '''
  2. import json
  3. import os
  4. import pydoc
  5. import re
# Select the backend through the environment at import time of this script,
# i.e. before pydoc.locate() in _metadata() resolves any 'keras.*' name.
# NOTE(review): presumably Keras reads KERAS_BACKEND when first imported - confirm.
os.environ["KERAS_BACKEND"] = "jax"
  7. def _read(path):
  8. with open(path, 'r', encoding='utf-8') as file:
  9. return file.read()
  10. def _parse_docstring(docstring):
  11. headers = []
  12. lines = docstring.splitlines()
  13. indentation = min(filter(lambda s: s > 0, map(lambda s: len(s) - len(s.lstrip()), lines)))
  14. lines = list((s[indentation:] if len(s) > len(s.lstrip()) else s) for s in lines)
  15. docstring = '\n'.join(lines)
  16. labels = [
  17. 'Args', 'Arguments', 'Variables', 'Fields', 'Yields', 'Call arguments', 'Raises',
  18. 'Examples', 'Example', 'Usage', 'Input shape', 'Output shape', 'Returns', 'References'
  19. ]
  20. tag_re = re.compile('(?<=\n)(' + '|'.join(labels) + '):\n', re.MULTILINE)
  21. parts = tag_re.split(docstring)
  22. headers.append(('', parts.pop(0)))
  23. while len(parts) > 0:
  24. headers.append((parts.pop(0), parts.pop(0)))
  25. return headers
  26. def _parse_arguments(arguments):
  27. result = []
  28. item_re = re.compile(r'^ ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE)
  29. content = item_re.split(arguments)
  30. if content.pop(0) != '':
  31. raise Exception('') # pylint: disable=broad-exception-raised
  32. while len(content) > 0:
  33. result.append((content.pop(0), content.pop(0)))
  34. return result
  35. def _convert_code_blocks(description):
  36. lines = description.splitlines()
  37. output = []
  38. while len(lines) > 0:
  39. line = lines.pop(0)
  40. if line.startswith('>>>') and len(lines) > 0 and \
  41. (lines[0].startswith('>>>') or lines[0].startswith('...')):
  42. output.append('```')
  43. output.append(line)
  44. while len(lines) > 0 and lines[0] != '':
  45. output.append(lines.pop(0))
  46. output.append('```')
  47. else:
  48. output.append(line)
  49. return '\n'.join(output)
  50. def _remove_indentation(value):
  51. lines = value.splitlines()
  52. indentation = min(map(lambda s: len(s) - len(s.lstrip()), \
  53. filter(lambda s: len(s) > 0, lines)))
  54. lines = list((s[indentation:] if len(s) > 0 else s) for s in lines)
  55. return '\n'.join(lines).strip()
  56. def _update_argument(schema, name, description):
  57. if not 'attributes' in schema:
  58. schema['attributes'] = []
  59. attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
  60. if not attribute:
  61. attribute = {}
  62. attribute['name'] = name
  63. schema['attributes'].append(attribute)
  64. attribute['description'] = _remove_indentation(description)
  65. def _update_input(schema, description):
  66. if not 'inputs' in schema:
  67. schema['inputs'] = [ { 'name': 'input' } ]
  68. parameter = next((_ for _ in schema['inputs'] \
  69. if (_['name'] == 'input' or _['name'] == 'inputs')), None)
  70. if parameter:
  71. parameter['description'] = _remove_indentation(description)
  72. else:
  73. raise Exception('') # pylint: disable=broad-exception-raised
  74. def _update_output(schema, description):
  75. if not 'outputs' in schema:
  76. schema['outputs'] = [ { 'name': 'output' } ]
  77. parameter = next((param for param in schema['outputs'] if param['name'] == 'output'), None)
  78. if parameter:
  79. parameter['description'] = _remove_indentation(description)
  80. else:
  81. raise Exception('') # pylint: disable=broad-exception-raised
  82. def _update_examples(schema, value):
  83. if 'examples' in schema:
  84. del schema['examples']
  85. value = _convert_code_blocks(value)
  86. lines = value.splitlines()
  87. code = []
  88. summary = []
  89. while len(lines) > 0:
  90. line = lines.pop(0)
  91. if len(line) > 0:
  92. if line.startswith('```'):
  93. while len(lines) > 0:
  94. line = lines.pop(0)
  95. if line == '```':
  96. break
  97. code.append(line)
  98. else:
  99. summary.append(line)
  100. if len(code) > 0:
  101. example = {}
  102. if len(summary) > 0:
  103. example['summary'] = '\n'.join(summary)
  104. example['code'] = '\n'.join(code)
  105. if not 'examples' in schema:
  106. schema['examples'] = []
  107. schema['examples'].append(example)
  108. code = []
  109. summary = []
  110. def _update_references(schema, value):
  111. if 'references' in schema:
  112. del schema['references']
  113. references = []
  114. reference = ''
  115. lines = value.splitlines()
  116. for line in lines:
  117. if line.lstrip().startswith('- '):
  118. if len(reference) > 0:
  119. references.append(reference)
  120. reference = line.lstrip().lstrip('- ')
  121. else:
  122. if line.startswith(' '):
  123. line = line[2:]
  124. reference = ' '.join([ reference, line.strip() ])
  125. if len(reference) > 0:
  126. references.append(reference)
  127. for reference in references:
  128. if not 'references' in schema:
  129. schema['references'] = []
  130. schema['references'].append({ 'description': reference })
  131. def _update_headers(schema, docstring):
  132. headers = _parse_docstring(docstring)
  133. for header in headers:
  134. key = header[0]
  135. value = header[1]
  136. if key == '':
  137. description = _convert_code_blocks(value)
  138. schema['description'] = _remove_indentation(description)
  139. elif key in ('Args', 'Arguments'):
  140. arguments = _parse_arguments(value)
  141. for argument in arguments:
  142. _update_argument(schema, argument[0], argument[1])
  143. elif key == 'Input shape':
  144. _update_input(schema, value)
  145. elif key == 'Output shape':
  146. _update_output(schema, value)
  147. elif key in ('Example', 'Examples', 'Usage'):
  148. _update_examples(schema, value)
  149. elif key == 'References':
  150. _update_references(schema, value)
  151. elif key in ('Call arguments', 'Returns', 'Variables', 'Raises'):
  152. pass
  153. else:
  154. raise Exception('') # pylint: disable=broad-exception-raised
  155. def _metadata():
  156. root = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
  157. json_path = os.path.join(root, 'source', 'keras-metadata.json')
  158. json_root = json.loads(_read(json_path))
  159. skip_names = set([
  160. 'keras.layers.InputLayer',
  161. 'keras.layers.ThresholdedReLU',
  162. 'keras.layers.LocallyConnected1D',
  163. 'keras.layers.LocallyConnected2D'
  164. ])
  165. for metadata in json_root:
  166. if 'module' in metadata:
  167. name = metadata['module'] + '.' + metadata['name']
  168. if not name in skip_names:
  169. cls = pydoc.locate(name)
  170. if not cls:
  171. raise KeyError('\'' + name + '\' not found.')
  172. if not cls.__doc__:
  173. raise AttributeError("'" + name + "' missing __doc__.")
  174. if cls.__doc__ == 'DEPRECATED.':
  175. raise DeprecationWarning("'" + name + "' __doc__ string is deprecated.'")
  176. _update_headers(metadata, cls.__doc__)
  177. with open(json_path, 'w', encoding='utf-8') as file:
  178. content = json.dumps(json_root, sort_keys=False, indent=2)
  179. for line in content.splitlines():
  180. file.write(line.rstrip() + '\n')
def main():
    """Script entry point: run the metadata update performed by _metadata()."""
    _metadata()
# Run the update only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()