# keras_script.py
  1. ''' Keras metadata script '''
  2. import json
  3. import os
  4. import pydoc
  5. import re
  6. import sys
  7. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  8. def _metadata():
  9. def parse_docstring(docstring):
  10. headers = []
  11. lines = docstring.splitlines()
  12. indentation = min(filter(lambda s: s > 0, map(lambda s: len(s) - len(s.lstrip()), lines)))
  13. lines = list((s[indentation:] if len(s) > len(s.lstrip()) else s) for s in lines)
  14. docstring = '\n'.join(lines)
  15. tag_re = re.compile('(?<=\n)(Args|Arguments|Variables|Fields|Yields|Call arguments|Raises|Examples|Example|Usage|Input shape|Output shape|Returns|References):\n', re.MULTILINE)
  16. parts = tag_re.split(docstring)
  17. headers.append(('', parts.pop(0)))
  18. while len(parts) > 0:
  19. headers.append((parts.pop(0), parts.pop(0)))
  20. return headers
  21. def parse_arguments(arguments):
  22. result = []
  23. item_re = re.compile(r'^ ? ?(\*?\*?\w[\w.]*?\s*):\s', re.MULTILINE)
  24. content = item_re.split(arguments)
  25. if content.pop(0) != '':
  26. raise Exception('')
  27. while len(content) > 0:
  28. result.append((content.pop(0), content.pop(0)))
  29. return result
  30. def convert_code_blocks(description):
  31. lines = description.splitlines()
  32. output = []
  33. while len(lines) > 0:
  34. line = lines.pop(0)
  35. if line.startswith('>>>') and len(lines) > 0 and \
  36. (lines[0].startswith('>>>') or lines[0].startswith('...')):
  37. output.append('```')
  38. output.append(line)
  39. while len(lines) > 0 and lines[0] != '':
  40. output.append(lines.pop(0))
  41. output.append('```')
  42. else:
  43. output.append(line)
  44. return '\n'.join(output)
  45. def remove_indentation(value):
  46. lines = value.splitlines()
  47. indentation = min(map(lambda s: len(s) - len(s.lstrip()), \
  48. filter(lambda s: len(s) > 0, lines)))
  49. lines = list((s[indentation:] if len(s) > 0 else s) for s in lines)
  50. return '\n'.join(lines).strip()
  51. def update_argument(schema, name, description):
  52. if not 'attributes' in schema:
  53. schema['attributes'] = []
  54. attribute = next((_ for _ in schema['attributes'] if _['name'] == name), None)
  55. if not attribute:
  56. attribute = {}
  57. attribute['name'] = name
  58. schema['attributes'].append(attribute)
  59. attribute['description'] = remove_indentation(description)
  60. def update_input(schema, description):
  61. if not 'inputs' in schema:
  62. schema['inputs'] = [ { 'name': 'input' } ]
  63. parameter = next((_ for _ in schema['inputs'] \
  64. if (_['name'] == 'input' or _['name'] == 'inputs')), None)
  65. if parameter:
  66. parameter['description'] = remove_indentation(description)
  67. else:
  68. raise Exception('')
  69. def update_output(schema, description):
  70. if not 'outputs' in schema:
  71. schema['outputs'] = [ { 'name': 'output' } ]
  72. parameter = next((param for param in schema['outputs'] if param['name'] == 'output'), None)
  73. if parameter:
  74. parameter['description'] = remove_indentation(description)
  75. else:
  76. raise Exception('')
  77. def update_examples(schema, value):
  78. if 'examples' in schema:
  79. del schema['examples']
  80. value = convert_code_blocks(value)
  81. lines = value.splitlines()
  82. code = []
  83. summary = []
  84. while len(lines) > 0:
  85. line = lines.pop(0)
  86. if len(line) > 0:
  87. if line.startswith('```'):
  88. while len(lines) > 0:
  89. line = lines.pop(0)
  90. if line == '```':
  91. break
  92. code.append(line)
  93. else:
  94. summary.append(line)
  95. if len(code) > 0:
  96. example = {}
  97. if len(summary) > 0:
  98. example['summary'] = '\n'.join(summary)
  99. example['code'] = '\n'.join(code)
  100. if not 'examples' in schema:
  101. schema['examples'] = []
  102. schema['examples'].append(example)
  103. code = []
  104. summary = []
  105. def update_references(schema, value):
  106. if 'references' in schema:
  107. del schema['references']
  108. references = []
  109. reference = ''
  110. lines = value.splitlines()
  111. for line in lines:
  112. if line.lstrip().startswith('- '):
  113. if len(reference) > 0:
  114. references.append(reference)
  115. reference = line.lstrip().lstrip('- ')
  116. else:
  117. if line.startswith(' '):
  118. line = line[2:]
  119. reference = ' '.join([ reference, line.strip() ])
  120. if len(reference) > 0:
  121. references.append(reference)
  122. for reference in references:
  123. if not 'references' in schema:
  124. schema['references'] = []
  125. schema['references'].append({ 'description': reference })
  126. json_path = os.path.join(os.path.dirname(__file__), '../source/keras-metadata.json')
  127. with open(json_path, 'r', encoding='utf-8') as file:
  128. json_root = json.loads(file.read())
  129. for schema in json_root:
  130. name = schema['name']
  131. if 'module' in schema:
  132. class_name = schema['module'] + '.' + name
  133. class_definition = pydoc.locate(class_name)
  134. if not class_definition:
  135. raise Exception('\'' + class_name + '\' not found.')
  136. if not class_definition.__doc__:
  137. raise Exception('\'' + class_name + '\' missing __doc__.')
  138. headers = parse_docstring(class_definition.__doc__)
  139. for header in headers:
  140. key = header[0]
  141. value = header[1]
  142. if key == '':
  143. description = convert_code_blocks(value)
  144. schema['description'] = remove_indentation(description)
  145. elif key in ('Args', 'Arguments'):
  146. arguments = parse_arguments(value)
  147. for argument in arguments:
  148. update_argument(schema, argument[0], argument[1])
  149. elif key == 'Call arguments':
  150. pass
  151. elif key == 'Returns':
  152. pass
  153. elif key == 'Input shape':
  154. update_input(schema, value)
  155. elif key == 'Output shape':
  156. update_output(schema, value)
  157. elif key in ('Example', 'Examples', 'Usage'):
  158. update_examples(schema, value)
  159. elif key == 'References':
  160. update_references(schema, value)
  161. elif key == 'Variables':
  162. pass
  163. elif key == 'Raises':
  164. pass
  165. else:
  166. raise Exception('')
  167. with open(json_path, 'w', encoding='utf-8') as file:
  168. json_data = json.dumps(json_root, sort_keys=False, indent=2)
  169. for line in json_data.splitlines():
  170. file.write(line.rstrip() + '\n')
  171. def main(): # pylint: disable=missing-function-docstring
  172. command_table = { 'metadata': _metadata }
  173. command = sys.argv[1]
  174. command_table[command]()
  175. if __name__ == '__main__':
  176. main()