keras-script.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. from __future__ import unicode_literals
  2. from __future__ import print_function
  3. import io
  4. import json
  5. import os
  6. import pydoc
  7. import re
  8. import sys
  9. stderr = sys.stderr
  10. sys.stderr = open(os.devnull, 'w')
  11. import keras
  12. sys.stderr = stderr
  13. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  14. def count_leading_spaces(s):
  15. ws = re.search(r'\S', s)
  16. if ws:
  17. return ws.start()
  18. else:
  19. return 0
  20. def process_list_block(docstring, starting_point, leading_spaces, marker):
  21. ending_point = docstring.find('\n\n', starting_point)
  22. block = docstring[starting_point:(None if ending_point == -1 else
  23. ending_point - 1)]
  24. # Place marker for later reinjection.
  25. docstring = docstring.replace(block, marker)
  26. lines = block.split('\n')
  27. # Remove the computed number of leading white spaces from each line.
  28. lines = [re.sub('^' + ' ' * leading_spaces, '', line) for line in lines]
  29. # Usually lines have at least 4 additional leading spaces.
  30. # These have to be removed, but first the list roots have to be detected.
  31. top_level_regex = r'^ ([^\s\\\(]+):(.*)'
  32. top_level_replacement = r'- __\1__:\2'
  33. lines = [re.sub(top_level_regex, top_level_replacement, line) for line in lines]
  34. # All the other lines get simply the 4 leading space (if present) removed
  35. lines = [re.sub(r'^ ', '', line) for line in lines]
  36. # Fix text lines after lists
  37. indent = 0
  38. text_block = False
  39. for i in range(len(lines)):
  40. line = lines[i]
  41. spaces = re.search(r'\S', line)
  42. if spaces:
  43. # If it is a list element
  44. if line[spaces.start()] == '-':
  45. indent = spaces.start() + 1
  46. if text_block:
  47. text_block = False
  48. lines[i] = '\n' + line
  49. elif spaces.start() < indent:
  50. text_block = True
  51. indent = spaces.start()
  52. lines[i] = '\n' + line
  53. else:
  54. text_block = False
  55. indent = 0
  56. block = '\n'.join(lines)
  57. return docstring, block
  58. def process_docstring(docstring):
  59. # First, extract code blocks and process them.
  60. code_blocks = []
  61. if '```' in docstring:
  62. tmp = docstring[:]
  63. while '```' in tmp:
  64. tmp = tmp[tmp.find('```'):]
  65. index = tmp[3:].find('```') + 6
  66. snippet = tmp[:index]
  67. # Place marker in docstring for later reinjection.
  68. docstring = docstring.replace(
  69. snippet, '$CODE_BLOCK_%d' % len(code_blocks))
  70. snippet_lines = snippet.split('\n')
  71. # Remove leading spaces.
  72. num_leading_spaces = snippet_lines[-1].find('`')
  73. snippet_lines = ([snippet_lines[0]] +
  74. [line[num_leading_spaces:]
  75. for line in snippet_lines[1:]])
  76. # Most code snippets have 3 or 4 more leading spaces
  77. # on inner lines, but not all. Remove them.
  78. inner_lines = snippet_lines[1:-1]
  79. leading_spaces = None
  80. for line in inner_lines:
  81. if not line or line[0] == '\n':
  82. continue
  83. spaces = count_leading_spaces(line)
  84. if leading_spaces is None:
  85. leading_spaces = spaces
  86. if spaces < leading_spaces:
  87. leading_spaces = spaces
  88. if leading_spaces:
  89. snippet_lines = ([snippet_lines[0]] +
  90. [line[leading_spaces:]
  91. for line in snippet_lines[1:-1]] +
  92. [snippet_lines[-1]])
  93. snippet = '\n'.join(snippet_lines)
  94. code_blocks.append(snippet)
  95. tmp = tmp[index:]
  96. # Format docstring lists.
  97. section_regex = r'\n( +)# (.*)\n'
  98. section_idx = re.search(section_regex, docstring)
  99. shift = 0
  100. sections = {}
  101. while section_idx and section_idx.group(2):
  102. anchor = section_idx.group(2)
  103. leading_spaces = len(section_idx.group(1))
  104. shift += section_idx.end()
  105. marker = '$' + anchor.replace(' ', '_') + '$'
  106. docstring, content = process_list_block(docstring,
  107. shift,
  108. leading_spaces,
  109. marker)
  110. sections[marker] = content
  111. section_idx = re.search(section_regex, docstring[shift:])
  112. # Format docstring section titles.
  113. docstring = re.sub(r'\n(\s+)# (.*)\n',
  114. r'\n\1__\2__\n\n',
  115. docstring)
  116. # Strip all remaining leading spaces.
  117. lines = docstring.split('\n')
  118. docstring = '\n'.join([line.lstrip(' ') for line in lines])
  119. # Reinject list blocks.
  120. for marker, content in sections.items():
  121. docstring = docstring.replace(marker, content)
  122. # Reinject code blocks.
  123. for i, code_block in enumerate(code_blocks):
  124. docstring = docstring.replace(
  125. '$CODE_BLOCK_%d' % i, code_block)
  126. return docstring
  127. def split_docstring(docstring):
  128. headers = {}
  129. current_header = ''
  130. current_lines = []
  131. lines = docstring.split('\n')
  132. for line in lines:
  133. if line.startswith('__') and line.endswith('__'):
  134. headers[current_header] = current_lines
  135. current_lines = []
  136. current_header = line[2:-2]
  137. if current_header == 'Masking' or current_header.startswith('Note '):
  138. headline = '**' + current_header + '**'
  139. current_lines = headers['']
  140. current_header = ''
  141. current_lines.append(headline)
  142. else:
  143. current_lines.append(line)
  144. if len(current_lines) > 0:
  145. headers[current_header] = current_lines
  146. return headers
  147. def update_hyperlink(description):
  148. def replace_hyperlink(match):
  149. name = match.group(1)
  150. link = match.group(2)
  151. if link.endswith('.md'):
  152. if link.startswith('../'):
  153. link = link.replace('../', 'https://keras.io/').rstrip('.md')
  154. else:
  155. link = 'https://keras.io/layers/' + link.rstrip('.md')
  156. return '[' + name + '](' + link + ')'
  157. return match.group(0)
  158. return re.sub(r'\[(.*?)\]\((.*?)\)', replace_hyperlink, description)
  159. def update_argument(schema, name, lines):
  160. attribute = None
  161. if not 'attributes' in schema:
  162. schema['attributes'] = []
  163. for current_attribute in schema['attributes']:
  164. if 'name' in current_attribute and current_attribute['name'] == name:
  165. attribute = current_attribute
  166. break
  167. if not attribute:
  168. attribute = {}
  169. attribute['name'] = name
  170. schema['attributes'].append(attribute)
  171. description = '\n'.join(lines)
  172. description = update_hyperlink(description)
  173. attribute['description'] = description
  174. def update_arguments(schema, lines):
  175. argument_name = None
  176. argument_lines = []
  177. for line in lines:
  178. if line.startswith('- __'):
  179. line = line.lstrip('- ')
  180. colon = line.index(':')
  181. if colon > 0:
  182. name = line[0:colon]
  183. line = line[colon+1:].lstrip(' ')
  184. if name.startswith('__') and name.endswith('__'):
  185. if argument_name:
  186. update_argument(schema, argument_name, argument_lines)
  187. argument_name = name[2:-2]
  188. argument_lines = []
  189. if argument_name:
  190. argument_lines.append(line)
  191. if argument_name:
  192. update_argument(schema, argument_name, argument_lines)
  193. return
  194. def update_examples(schema, lines):
  195. if 'examples' in schema:
  196. del schema['examples']
  197. summary_lines = []
  198. code_lines = None
  199. for line in lines:
  200. if line.startswith('```'):
  201. if code_lines != None:
  202. example = {}
  203. example['code'] = '\n'.join(code_lines)
  204. if len(summary_lines) > 0:
  205. example['summary'] = '\n'.join(summary_lines)
  206. if not 'examples' in schema:
  207. schema['examples'] = []
  208. schema['examples'].append(example)
  209. summary_lines = []
  210. code_lines = None
  211. else:
  212. code_lines = [ ]
  213. else:
  214. if code_lines != None:
  215. code_lines.append(line)
  216. elif line != '':
  217. summary_lines.append(line)
  218. def update_references(schema, lines):
  219. if 'references' in schema:
  220. del schema['references']
  221. references = []
  222. reference = ''
  223. for line in lines:
  224. if line.startswith('- '):
  225. if len(reference) > 0:
  226. references.append(reference)
  227. reference = line.lstrip('- ')
  228. else:
  229. if line.startswith(' '):
  230. line = line[2:]
  231. reference = reference + line
  232. if len(reference) > 0:
  233. references.append(reference)
  234. for reference in references:
  235. if not 'references' in schema:
  236. schema['references'] = []
  237. schema['references'].append({ 'description': reference })
  238. def update_input(schema, description):
  239. entry = None
  240. if 'inputs' in schema:
  241. for current_input in schema['inputs']:
  242. if current_input['name'] == 'input':
  243. entry = current_input
  244. break
  245. else:
  246. entry = {}
  247. entry['name'] = 'input'
  248. schema['inputs'] = []
  249. schema['inputs'].append(entry)
  250. if entry:
  251. entry['description'] = description
  252. def update_output(schema, description):
  253. entry = None
  254. if 'outputs' in schema:
  255. for current_output in schema['outputs']:
  256. if current_output['name'] == 'output':
  257. entry = current_output
  258. break
  259. else:
  260. entry = {}
  261. entry['name'] = 'output'
  262. schema['outputs'] = []
  263. schema['outputs'].append(entry)
  264. if entry:
  265. entry['description'] = description
  266. def metadata():
  267. json_file = os.path.join(os.path.dirname(__file__), '../src/keras-metadata.json')
  268. json_data = open(json_file).read()
  269. json_root = json.loads(json_data)
  270. for entry in json_root:
  271. name = entry['name']
  272. schema = entry['schema']
  273. if 'package' in schema:
  274. class_name = schema['package'] + '.' + name
  275. class_definition = pydoc.locate(class_name)
  276. if not class_definition:
  277. raise Exception('\'' + class_name + '\' not found.')
  278. docstring = class_definition.__doc__
  279. if not docstring:
  280. raise Exception('\'' + class_name + '\' missing __doc__.')
  281. docstring = process_docstring(docstring)
  282. headers = split_docstring(docstring)
  283. if '' in headers:
  284. schema['description'] = '\n'.join(headers[''])
  285. del headers['']
  286. if 'Arguments' in headers:
  287. update_arguments(schema, headers['Arguments'])
  288. del headers['Arguments']
  289. if 'Input shape' in headers:
  290. update_input(schema, '\n'.join(headers['Input shape']))
  291. del headers['Input shape']
  292. if 'Output shape' in headers:
  293. update_output(schema, '\n'.join(headers['Output shape']))
  294. del headers['Output shape']
  295. if 'Examples' in headers:
  296. update_examples(schema, headers['Examples'])
  297. del headers['Examples']
  298. if 'Example' in headers:
  299. update_examples(schema, headers['Example'])
  300. del headers['Example']
  301. if 'References' in headers:
  302. update_references(schema, headers['References'])
  303. del headers['References']
  304. if 'Raises' in headers:
  305. del headers['Raises']
  306. if len(headers) > 0:
  307. raise Exception('\'' + class_name + '.__doc__\' contains unprocessed headers.')
  308. with io.open(json_file, 'w', newline='') as fout:
  309. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  310. for line in json_data.splitlines():
  311. line = line.rstrip()
  312. if sys.version_info[0] < 3:
  313. line = unicode(line)
  314. fout.write(line)
  315. fout.write('\n')
  316. def download_model(type, file):
  317. file = os.path.expandvars(file)
  318. if not os.path.exists(file):
  319. folder = os.path.dirname(file);
  320. if not os.path.exists(folder):
  321. os.makedirs(folder)
  322. model = pydoc.locate(type)()
  323. model.save(file);
  324. def zoo():
  325. if not os.environ.get('test'):
  326. os.environ['test'] = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../test'))
  327. download_model('keras.applications.densenet.DenseNet121', '${test}/data/keras/DenseNet121.h5')
  328. download_model('keras.applications.inception_resnet_v2.InceptionResNetV2', '${test}/data/keras/InceptionResNetV2.h5')
  329. download_model('keras.applications.inception_v3.InceptionV3', '${test}/data/keras/InceptionV3.h5')
  330. download_model('keras.applications.mobilenet_v2.MobileNetV2', '${test}/data/keras/MobileNetV2.h5')
  331. download_model('keras.applications.nasnet.NASNetMobile', '${test}/data/keras/NASNetMobile.h5')
  332. download_model('keras.applications.resnet50.ResNet50', '${test}/data/keras/ResNet50.h5')
  333. download_model('keras.applications.vgg19.VGG19', '${test}/data/keras/VGG19.h5')
  334. download_model('keras.applications.xception.Xception', '${test}/data/keras/Xception.h5')
  335. if __name__ == '__main__':
  336. command_table = { 'metadata': metadata, 'zoo': zoo }
  337. command = sys.argv[1];
  338. command_table[command]()