tf-script.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. from __future__ import unicode_literals
  2. import json
  3. import io
  4. import sys
  5. import os
  6. from tensorflow.core.framework import api_def_pb2
  7. from tensorflow.core.framework import op_def_pb2
  8. from tensorflow.core.framework import types_pb2
  9. from google.protobuf import text_format
  10. def metadata():
  11. categories = {
  12. 'Const': 'Constant',
  13. 'Conv2D': 'Layer',
  14. 'BiasAdd': 'Layer',
  15. 'DepthwiseConv2dNative': 'Layer',
  16. 'Relu': 'Activation',
  17. 'Relu6': 'Activation',
  18. 'Elu': 'Activation',
  19. 'Softmax': 'Activation',
  20. 'Sigmoid': 'Activation',
  21. 'LRN': 'Normalization',
  22. 'MaxPool': 'Pool',
  23. 'MaxPoolV2': 'Pool',
  24. 'AvgPool': 'Pool',
  25. 'Reshape': 'Shape',
  26. 'Squeeze': 'Shape',
  27. 'ConcatV2': 'Tensor',
  28. 'Split': 'Tensor',
  29. 'Dequantize': 'Tensor',
  30. 'Identity': 'Control',
  31. 'Variable': 'Control',
  32. 'VariableV2': 'Control',
  33. 'Assign': 'Control',
  34. 'BatchNormWithGlobalNormalization': 'Normalization',
  35. 'FusedBatchNorm': 'Normalization',
  36. # 'VariableV2':
  37. # 'Assign':
  38. # 'BiasAdd':
  39. }
  40. def find_multiline(line, colon):
  41. if colon == -1:
  42. return None
  43. line = line[colon+1:]
  44. while line.startswith(' '):
  45. line = line[1:]
  46. if line.startswith('<<'):
  47. line = line[2:]
  48. return line
  49. return None
  50. def str_escape(text):
  51. result = ''
  52. for c in text:
  53. if (c == '\n'):
  54. result += '\\n'
  55. elif (c == '\r'):
  56. result += "\\r"
  57. elif (c == '\t'):
  58. result += "\\t"
  59. elif (c == '\"'):
  60. result += "\\\""
  61. elif (c == '\''):
  62. result += "\\'"
  63. elif (c == '\\'):
  64. result += "\\\\"
  65. else:
  66. result += c
  67. return result
  68. def pbtxt_from_multiline(multiline_pbtxt):
  69. pbtxt = ''
  70. while len(multiline_pbtxt) > 0:
  71. index = multiline_pbtxt.find('\n')
  72. if index == -1:
  73. pbtxt = pbtxt + multiline_pbtxt
  74. multiline_pbtxt = ''
  75. break
  76. line = multiline_pbtxt[0:index]
  77. multiline_pbtxt = multiline_pbtxt[index+1:]
  78. colon = line.find(':')
  79. end = find_multiline(line, colon)
  80. if end == None:
  81. pbtxt = pbtxt + line + '\n'
  82. continue
  83. pbtxt = pbtxt + line[0:colon+1]
  84. unescaped = ''
  85. newline = False
  86. line = ''
  87. while len(multiline_pbtxt) > 0:
  88. index = multiline_pbtxt.find('\n')
  89. line = multiline_pbtxt[0:index]
  90. multiline_pbtxt = multiline_pbtxt[index+1:]
  91. if line.startswith(end):
  92. line = line[len(end):]
  93. break
  94. if newline:
  95. unescaped = unescaped + '\n'
  96. newline = True
  97. unescaped = unescaped + line
  98. line = ''
  99. pbtxt = pbtxt + '\"' + str_escape(unescaped) + '\"' + line + '\n'
  100. return pbtxt
  101. def read_api_def_map(folder):
  102. api_def_map = {}
  103. file_list = os.listdir(folder)
  104. file_list = sorted(file_list)
  105. for filename in file_list:
  106. api_defs = api_def_pb2.ApiDefs()
  107. filename = folder + '/' + filename
  108. with open(filename) as handle:
  109. multiline_pbtxt = handle.read()
  110. pbtxt = pbtxt_from_multiline(multiline_pbtxt)
  111. text_format.Merge(pbtxt, api_defs)
  112. for api_def in api_defs.op:
  113. api_def_map[api_def.graph_op_name] = api_def
  114. return api_def_map
  115. def convert_type(type):
  116. return { 'type': 'type', 'value': type }
  117. def convert_tensor(tensor):
  118. return { 'type': 'tensor', 'value': '?' }
  119. def convert_shape(shape):
  120. return { 'type': 'shape', 'value': '?' }
  121. def convert_number(number):
  122. if number == float('inf'):
  123. return 'NaN'
  124. if number == float('-inf'):
  125. return '-NaN'
  126. return number
  127. attr_type_table = {
  128. 'type': 'type', 'list(type)': 'type[]',
  129. 'bool': 'boolean',
  130. 'int': 'int64', 'list(int)': 'int64[]',
  131. 'float': 'float32', 'list(float)': 'float32[]',
  132. 'string': 'string', 'list(string)': 'string[]',
  133. 'shape': 'shape', 'list(shape)': 'shape[]',
  134. 'tensor': 'tensor',
  135. 'func': 'function', 'list(func)': 'function[]'
  136. }
  137. def convert_attr_type(type):
  138. if type in attr_type_table:
  139. return attr_type_table[type]
  140. print(type)
  141. return type
  142. def convert_attr_value(attr_value):
  143. if attr_value.HasField('list'):
  144. list = []
  145. attr_value_list = attr_value.list
  146. if len(attr_value_list.s) > 0:
  147. for s in attr_value_list.s:
  148. list.append(s.decode('utf8'))
  149. if len(attr_value_list.i) > 0:
  150. for i in attr_value_list.i:
  151. list.append(i)
  152. if len(attr_value_list.f) > 0:
  153. for f in attr_value_list.f:
  154. list.append(convert_number(f))
  155. if len(attr_value_list.type) > 0:
  156. for type in attr_value_list.type:
  157. list.append(convert_type(type))
  158. if len(list) == 0:
  159. for _, value in attr_value_list.ListFields():
  160. if len(value) > 0:
  161. raise Exception()
  162. return list
  163. if attr_value.HasField('s'):
  164. return attr_value.s.decode('utf8')
  165. if attr_value.HasField('i'):
  166. return attr_value.i
  167. if attr_value.HasField('f'):
  168. return convert_number(attr_value.f)
  169. if attr_value.HasField('b'):
  170. return attr_value.b
  171. if attr_value.HasField('type'):
  172. return convert_type(attr_value.type)
  173. if attr_value.HasField('tensor'):
  174. return convert_tensor(attr_value.tensor)
  175. if attr_value.HasField('shape'):
  176. return convert_shape(attr_value.shape)
  177. raise Exception()
  178. _TYPE_TO_STRING = {
  179. types_pb2.DataType.DT_HALF: "float16",
  180. types_pb2.DataType.DT_FLOAT: "float32",
  181. types_pb2.DataType.DT_DOUBLE: "float64",
  182. types_pb2.DataType.DT_INT32: "int32",
  183. types_pb2.DataType.DT_UINT8: "uint8",
  184. types_pb2.DataType.DT_UINT16: "uint16",
  185. types_pb2.DataType.DT_UINT32: "uint32",
  186. types_pb2.DataType.DT_UINT64: "uint64",
  187. types_pb2.DataType.DT_INT16: "int16",
  188. types_pb2.DataType.DT_INT8: "int8",
  189. types_pb2.DataType.DT_STRING: "string",
  190. types_pb2.DataType.DT_COMPLEX64: "complex64",
  191. types_pb2.DataType.DT_COMPLEX128: "complex128",
  192. types_pb2.DataType.DT_INT64: "int64",
  193. types_pb2.DataType.DT_BOOL: "bool",
  194. types_pb2.DataType.DT_QINT8: "qint8",
  195. types_pb2.DataType.DT_QUINT8: "quint8",
  196. types_pb2.DataType.DT_QINT16: "qint16",
  197. types_pb2.DataType.DT_QUINT16: "quint16",
  198. types_pb2.DataType.DT_QINT32: "qint32",
  199. types_pb2.DataType.DT_BFLOAT16: "bfloat16",
  200. types_pb2.DataType.DT_RESOURCE: "resource",
  201. types_pb2.DataType.DT_VARIANT: "variant",
  202. types_pb2.DataType.DT_HALF_REF: "float16_ref",
  203. types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
  204. types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
  205. types_pb2.DataType.DT_INT32_REF: "int32_ref",
  206. types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
  207. types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
  208. types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
  209. types_pb2.DataType.DT_INT16_REF: "int16_ref",
  210. types_pb2.DataType.DT_INT8_REF: "int8_ref",
  211. types_pb2.DataType.DT_STRING_REF: "string_ref",
  212. types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
  213. types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
  214. types_pb2.DataType.DT_INT64_REF: "int64_ref",
  215. types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
  216. types_pb2.DataType.DT_BOOL_REF: "bool_ref",
  217. types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
  218. types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
  219. types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
  220. types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
  221. types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
  222. types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
  223. types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
  224. types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
  225. }
  226. def format_data_type(data_type):
  227. if data_type in _TYPE_TO_STRING:
  228. return _TYPE_TO_STRING[data_type]
  229. raise Exception()
  230. def format_attribute_value(value):
  231. if type(value) is dict and 'type' in value and 'value' in value and value['type'] == 'type':
  232. return format_data_type(value['value'])
  233. if type(value) is str:
  234. return value
  235. if value == True:
  236. return 'true'
  237. if value == False:
  238. return 'false'
  239. raise Exception()
  240. tensorflow_repo_dir = os.path.join(os.path.dirname(__file__), '../third_party/src/tensorflow')
  241. api_def_map = read_api_def_map(os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
  242. input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
  243. ops_list = op_def_pb2.OpList()
  244. with open(input_file) as input_handle:
  245. text_format.Merge(input_handle.read(), ops_list)
  246. json_root = []
  247. for op in ops_list.op:
  248. # print(op.name)
  249. json_schema = {}
  250. if op.name in categories:
  251. json_schema['category'] = categories[op.name]
  252. api_def = api_def_pb2.ApiDef()
  253. if op.name in api_def_map:
  254. api_def = api_def_map[op.name]
  255. # if op.deprecation.version != 0:
  256. # print('[' + op.name + ']')
  257. # print(op.deprecation.version)
  258. # print(op.deprecation.explanation)
  259. api_def_attr_map = {}
  260. for attr in api_def.attr:
  261. api_def_attr_map[attr.name] = attr
  262. api_def_in_arg_map = {}
  263. for in_arg in api_def.in_arg:
  264. api_def_in_arg_map[in_arg.name] = in_arg
  265. api_def_out_arg_map = {}
  266. for out_arg in api_def.out_arg:
  267. api_def_out_arg_map[out_arg.name] = out_arg
  268. if api_def.summary:
  269. json_schema['summary'] = api_def.summary
  270. if api_def.description:
  271. json_schema['description'] = api_def.description
  272. for attr in op.attr:
  273. if not 'attributes' in json_schema:
  274. json_schema['attributes'] = []
  275. json_attribute = {}
  276. json_attribute['name'] = attr.name
  277. attr_type = convert_attr_type(attr.type)
  278. if attr_type:
  279. json_attribute['type'] = attr_type
  280. else:
  281. del json_attribute['type']
  282. if attr.name in api_def_attr_map:
  283. api_def_attr = api_def_attr_map[attr.name]
  284. if api_def_attr.description:
  285. json_attribute['description'] = api_def_attr.description
  286. if attr.has_minimum:
  287. json_attribute['minimum'] = attr.minimum
  288. if attr.HasField('allowed_values'):
  289. allowed_values = convert_attr_value(attr.allowed_values)
  290. description = json_attribute['description'] + ' ' if 'description' in json_attribute else ''
  291. description = description + 'Must be one of the following: ' + ', '.join(list(map(lambda x: "`" + format_attribute_value(x) + "`", allowed_values))) + '.'
  292. json_attribute['description'] = description
  293. if attr.HasField('default_value'):
  294. default_value = convert_attr_value(attr.default_value)
  295. json_attribute['default'] = default_value
  296. json_schema['attributes'].append(json_attribute)
  297. for input_arg in op.input_arg:
  298. if not 'inputs' in json_schema:
  299. json_schema['inputs'] = []
  300. json_input = {}
  301. json_input['name'] = input_arg.name
  302. if input_arg.name in api_def_in_arg_map:
  303. api_def_in_arg = api_def_in_arg_map[input_arg.name]
  304. if api_def_in_arg.description:
  305. json_input['description'] = api_def_in_arg.description
  306. if input_arg.number_attr:
  307. json_input['numberAttr'] = input_arg.number_attr
  308. if input_arg.type:
  309. json_input['type'] = input_arg.type
  310. if input_arg.type_attr:
  311. json_input['typeAttr'] = input_arg.type_attr
  312. if input_arg.type_list_attr:
  313. json_input['typeListAttr'] = input_arg.type_list_attr
  314. if input_arg.is_ref:
  315. json_input['isRef'] = True
  316. json_schema['inputs'].append(json_input)
  317. for output_arg in op.output_arg:
  318. if not 'outputs' in json_schema:
  319. json_schema['outputs'] = []
  320. json_output = {}
  321. json_output['name'] = output_arg.name
  322. if output_arg.name in api_def_out_arg_map:
  323. api_def_out_arg = api_def_out_arg_map[output_arg.name]
  324. if api_def_out_arg.description:
  325. json_output['description'] = api_def_out_arg.description
  326. if output_arg.number_attr:
  327. json_output['numberAttr'] = output_arg.number_attr
  328. if output_arg.type:
  329. json_output['type'] = output_arg.type
  330. elif output_arg.type_attr:
  331. json_output['typeAttr'] = output_arg.type_attr
  332. elif output_arg.type_list_attr:
  333. json_output['typeListAttr'] = output_arg.type_list_attr
  334. if output_arg.is_ref:
  335. json_output['isRef'] = True
  336. json_schema['outputs'].append(json_output)
  337. json_root.append({
  338. 'name': op.name,
  339. 'schema': json_schema
  340. })
  341. json_file = os.path.join(os.path.dirname(__file__), '../src/tf-metadata.json')
  342. with io.open(json_file, 'w', newline='') as fout:
  343. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  344. for line in json_data.splitlines():
  345. line = line.rstrip()
  346. if sys.version_info[0] < 3:
  347. line = unicode(line)
  348. fout.write(line)
  349. fout.write('\n')
  350. if __name__ == '__main__':
  351. command_table = { 'metadata': metadata }
  352. command = sys.argv[1];
  353. command_table[command]()