2
0

tf-metadata.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. #!/usr/bin/env python
  2. from __future__ import unicode_literals
  3. import json
  4. import io
  5. import sys
  6. import os
  7. from tensorflow.core.framework import api_def_pb2
  8. from tensorflow.core.framework import op_def_pb2
  9. from google.protobuf import text_format
  10. categories = {
  11. 'Const': 'Constant',
  12. 'Conv2D': 'Layer',
  13. 'BiasAdd': 'Layer',
  14. 'DepthwiseConv2dNative': 'Layer',
  15. 'Relu': 'Activation',
  16. 'Relu6': 'Activation',
  17. 'Softmax': 'Activation',
  18. 'Sigmoid': 'Activation',
  19. 'LRN': 'Normalization',
  20. 'MaxPool': 'Pool',
  21. 'MaxPoolV2': 'Pool',
  22. 'AvgPool': 'Pool',
  23. 'Reshape': 'Shape',
  24. 'Squeeze': 'Shape',
  25. 'ConcatV2': 'Tensor',
  26. 'Split': 'Tensor',
  27. 'Dequantize': 'Tensor',
  28. 'Identity': 'Control',
  29. 'Variable': 'Control',
  30. 'VariableV2': 'Control',
  31. 'Assign': 'Control',
  32. 'BatchNormWithGlobalNormalization': 'Normalization',
  33. 'FusedBatchNorm': 'Normalization',
  34. # 'VariableV2':
  35. # 'Assign':
  36. # 'BiasAdd':
  37. }
  38. def find_multiline(line, colon):
  39. if colon == -1:
  40. return None
  41. line = line[colon+1:]
  42. while line.startswith(' '):
  43. line = line[1:]
  44. if line.startswith('<<'):
  45. line = line[2:]
  46. return line
  47. return None
  48. def str_escape(text):
  49. result = ''
  50. for c in text:
  51. if (c == '\n'):
  52. result += '\\n'
  53. elif (c == '\r'):
  54. result += "\\r"
  55. elif (c == '\t'):
  56. result += "\\t"
  57. elif (c == '\"'):
  58. result += "\\\""
  59. elif (c == '\''):
  60. result += "\\'"
  61. elif (c == '\\'):
  62. result += "\\\\"
  63. else:
  64. result += c
  65. return result
  66. def pbtxt_from_multiline(multiline_pbtxt):
  67. pbtxt = ''
  68. while len(multiline_pbtxt) > 0:
  69. index = multiline_pbtxt.find('\n')
  70. if index == -1:
  71. pbtxt = pbtxt + multiline_pbtxt
  72. multiline_pbtxt = ''
  73. break
  74. line = multiline_pbtxt[0:index]
  75. multiline_pbtxt = multiline_pbtxt[index+1:]
  76. colon = line.find(':')
  77. end = find_multiline(line, colon)
  78. if end == None:
  79. pbtxt = pbtxt + line + '\n'
  80. continue
  81. pbtxt = pbtxt + line[0:colon+1]
  82. unescaped = ''
  83. newline = False
  84. line = ''
  85. while len(multiline_pbtxt) > 0:
  86. index = multiline_pbtxt.find('\n')
  87. line = multiline_pbtxt[0:index]
  88. multiline_pbtxt = multiline_pbtxt[index+1:]
  89. if line.startswith(end):
  90. line = line[len(end):]
  91. break
  92. if newline:
  93. unescaped = unescaped + '\n'
  94. newline = True
  95. unescaped = unescaped + line
  96. line = ''
  97. pbtxt = pbtxt + '\"' + str_escape(unescaped) + '\"' + line + '\n'
  98. return pbtxt
  99. def read_api_def_map(folder):
  100. api_def_map = {}
  101. file_list = os.listdir(folder)
  102. file_list = sorted(file_list)
  103. for filename in file_list:
  104. api_defs = api_def_pb2.ApiDefs()
  105. filename = folder + '/' + filename
  106. with open(filename) as handle:
  107. multiline_pbtxt = handle.read()
  108. pbtxt = pbtxt_from_multiline(multiline_pbtxt)
  109. text_format.Merge(pbtxt, api_defs)
  110. for api_def in api_defs.op:
  111. api_def_map[api_def.graph_op_name] = api_def
  112. return api_def_map
  113. def convert_type(type):
  114. return { 'type': 'type', 'value': type }
  115. def convert_tensor(tensor):
  116. return { 'type': 'tensor', 'value': '?' }
  117. def convert_shape(shape):
  118. return { 'type': 'shape', 'value': '?' }
  119. def convert_number(number):
  120. if number == float('inf'):
  121. return 'NaN'
  122. if number == float('-inf'):
  123. return '-NaN'
  124. return number
  125. def convert_attr_value(attr_value):
  126. if attr_value.HasField('list'):
  127. list = []
  128. attr_value_list = attr_value.list
  129. if len(attr_value_list.s) > 0:
  130. for s in attr_value_list.s:
  131. list.append(s.decode('utf8'))
  132. if len(attr_value_list.i) > 0:
  133. for i in attr_value_list.i:
  134. list.append(i)
  135. if len(attr_value_list.f) > 0:
  136. for f in attr_value_list.f:
  137. list.append(convert_number(f))
  138. if len(attr_value_list.type) > 0:
  139. for type in attr_value_list.type:
  140. list.append(convert_type(type))
  141. if len(list) == 0:
  142. for _, value in attr_value_list.ListFields():
  143. if len(value) > 0:
  144. raise Exception()
  145. return list
  146. if attr_value.HasField('s'):
  147. return attr_value.s.decode('utf8')
  148. if attr_value.HasField('i'):
  149. return attr_value.i
  150. if attr_value.HasField('f'):
  151. return convert_number(attr_value.f)
  152. if attr_value.HasField('b'):
  153. return attr_value.b
  154. if attr_value.HasField('type'):
  155. return convert_type(attr_value.type)
  156. if attr_value.HasField('tensor'):
  157. return convert_tensor(attr_value.tensor)
  158. if attr_value.HasField('shape'):
  159. return convert_shape(attr_value.shape)
  160. raise Exception()
  161. api_def_map = read_api_def_map('../third_party/tensorflow/tensorflow/core/api_def/base_api')
  162. input_file = '../third_party/tensorflow/tensorflow/core/ops/ops.pbtxt';
  163. ops_list = op_def_pb2.OpList()
  164. with open(input_file) as input_handle:
  165. text_format.Merge(input_handle.read(), ops_list)
  166. json_root = []
  167. for op in ops_list.op:
  168. # print(op.name)
  169. json_schema = {}
  170. if op.name in categories:
  171. json_schema['category'] = categories[op.name]
  172. api_def = api_def_pb2.ApiDef()
  173. if op.name in api_def_map:
  174. api_def = api_def_map[op.name]
  175. # if op.deprecation.version != 0:
  176. # print('[' + op.name + ']')
  177. # print(op.deprecation.version)
  178. # print(op.deprecation.explanation)
  179. api_def_attr_map = {}
  180. for attr in api_def.attr:
  181. api_def_attr_map[attr.name] = attr
  182. api_def_in_arg_map = {}
  183. for in_arg in api_def.in_arg:
  184. api_def_in_arg_map[in_arg.name] = in_arg
  185. api_def_out_arg_map = {}
  186. for out_arg in api_def.out_arg:
  187. api_def_out_arg_map[out_arg.name] = out_arg
  188. if api_def.summary:
  189. json_schema['summary'] = api_def.summary
  190. if api_def.description:
  191. json_schema['description'] = api_def.description
  192. for attr in op.attr:
  193. if not 'attributes' in json_schema:
  194. json_schema['attributes'] = []
  195. json_attribute = {}
  196. json_attribute['name'] = attr.name
  197. if attr.type:
  198. json_attribute['type'] = attr.type
  199. if attr.name in api_def_attr_map:
  200. api_def_attr = api_def_attr_map[attr.name]
  201. if api_def_attr.description:
  202. json_attribute['description'] = api_def_attr.description
  203. if attr.has_minimum:
  204. json_attribute['minimum'] = attr.minimum
  205. if attr.HasField('allowed_values'):
  206. json_attribute['allowedValues'] = convert_attr_value(attr.allowed_values)
  207. if attr.HasField('default_value'):
  208. json_attribute['defaultValue'] = convert_attr_value(attr.default_value)
  209. json_schema['attributes'].append(json_attribute)
  210. for input_arg in op.input_arg:
  211. if not 'inputs' in json_schema:
  212. json_schema['inputs'] = []
  213. json_input = {}
  214. json_input['name'] = input_arg.name
  215. if input_arg.name in api_def_in_arg_map:
  216. api_def_in_arg = api_def_in_arg_map[input_arg.name]
  217. if api_def_in_arg.description:
  218. json_input['description'] = api_def_in_arg.description
  219. if input_arg.number_attr:
  220. json_input['numberAttr'] = input_arg.number_attr
  221. if input_arg.type:
  222. json_input['type'] = input_arg.type
  223. if input_arg.type_attr:
  224. json_input['typeAttr'] = input_arg.type_attr
  225. if input_arg.type_list_attr:
  226. json_input['typeListAttr'] = input_arg.type_list_attr
  227. if input_arg.is_ref:
  228. json_input['isRef'] = True
  229. json_schema['inputs'].append(json_input)
  230. for output_arg in op.output_arg:
  231. if not 'outputs' in json_schema:
  232. json_schema['outputs'] = []
  233. json_output = {}
  234. json_output['name'] = output_arg.name
  235. if output_arg.name in api_def_out_arg_map:
  236. api_def_out_arg = api_def_out_arg_map[output_arg.name]
  237. if api_def_out_arg.description:
  238. json_output['description'] = api_def_out_arg.description
  239. if output_arg.type:
  240. json_output['type'] = output_arg.type
  241. elif output_arg.type_attr:
  242. json_output['typeAttr'] = output_arg.type_attr
  243. elif output_arg.type_list_attr:
  244. json_output['typeListAttr'] = output_arg.type_list_attr
  245. if output_arg.is_ref:
  246. json_output['isRef'] = True
  247. json_schema['outputs'].append(json_output)
  248. json_root.append({
  249. 'name': op.name,
  250. 'schema': json_schema
  251. })
  252. json_file = '../src/tf-metadata.json'
  253. with io.open(json_file, 'w', newline='') as fout:
  254. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  255. for line in json_data.splitlines():
  256. line = line.rstrip()
  257. if sys.version_info[0] < 3:
  258. line = unicode(line)
  259. fout.write(line)
  260. fout.write('\n')