2
0

tf-script.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. from __future__ import unicode_literals
  2. import json
  3. import io
  4. import sys
  5. import os
  6. from tensorflow.core.framework import api_def_pb2
  7. from tensorflow.core.framework import op_def_pb2
  8. from tensorflow.core.framework import types_pb2
  9. from google.protobuf import text_format
  def metadata():
      """Generate ``../src/tf-metadata.json`` from TensorFlow op definitions.

      Reads ``tensorflow/core/ops/ops.pbtxt`` and every file found in
      ``tensorflow/core/api_def/base_api`` below
      ``../third_party/src/tensorflow`` (all paths are relative to this
      script), merges the op definitions with their API documentation and
      writes one ``{'name': ..., 'schema': ...}`` entry per op as sorted,
      two-space indented JSON.

      Raises:
          Exception: if an attribute value or data type is encountered that
              the converters below do not support.
      """
      # Display category for the ops that have one; other ops get none.
      categories = {
          'Assign': 'Control',
          'AvgPool': 'Pool',
          'BatchNormWithGlobalNormalization': 'Normalization',
          'BiasAdd': 'Layer',
          'ConcatV2': 'Tensor',
          'Const': 'Constant',
          'Conv2D': 'Layer',
          'DepthwiseConv2dNative': 'Layer',
          'Dequantize': 'Tensor',
          'Elu': 'Activation',
          'FusedBatchNorm': 'Normalization',
          'FusedBatchNormV2': 'Normalization',
          'FusedBatchNormV3': 'Normalization',
          'Identity': 'Control',
          'LeakyRelu': 'Activation',
          'LRN': 'Normalization',
          'MaxPool': 'Pool',
          'MaxPoolV2': 'Pool',
          'Pad': 'Tensor',
          'Relu': 'Activation',
          'Relu6': 'Activation',
          'Reshape': 'Shape',
          'Sigmoid': 'Activation',
          'Slice': 'Tensor',
          'Softmax': 'Activation',
          'Split': 'Tensor',
          'Squeeze': 'Shape',
          'StridedSlice': 'Tensor',
          'swish_f32': 'Activation',
          'Variable': 'Control',
          'VariableV2': 'Control',
      }

      def find_multiline(line, colon):
          """Return the heredoc end marker announced by ``line``, or None.

          ``colon`` is the index of the first ':' in ``line`` (-1 if absent).
          A marker is announced when the text after the colon, ignoring
          leading spaces, starts with '<<'; the marker is whatever follows.
          """
          if colon == -1:
              return None
          rest = line[colon + 1:].lstrip(' ')
          if rest.startswith('<<'):
              return rest[2:]
          return None

      # Characters that must be backslash-escaped inside a quoted string.
      escapes = {
          '\n': '\\n',
          '\r': '\\r',
          '\t': '\\t',
          '"': '\\"',
          '\'': '\\\'',
          '\\': '\\\\',
      }

      def str_escape(text):
          """Return ``text`` with newlines, tabs, quotes and backslashes escaped."""
          return ''.join(escapes.get(c, c) for c in text)

      def pbtxt_from_multiline(multiline_pbtxt):
          """Rewrite ``field: <<END ... END`` heredocs as quoted strings.

          Lines without a heredoc are copied unchanged.  The lines of a
          heredoc body are joined with '\\n', escaped and emitted as a single
          double-quoted value; any text following the end marker on its line
          is kept after the closing quote.
          """
          lines = multiline_pbtxt.split('\n')
          # split() leaves the text after the final newline (possibly empty)
          # as the last element; that text is a line without a terminator.
          tail = lines.pop()
          terminated = len(lines)
          if tail:
              lines.append(tail)
          output = []
          position = 0
          while position < len(lines):
              line = lines[position]
              position += 1
              if position > terminated:
                  # A final unterminated line is copied verbatim.
                  output.append(line)
                  break
              colon = line.find(':')
              end = find_multiline(line, colon)
              if end is None:
                  output.append(line + '\n')
                  continue
              output.append(line[:colon + 1])
              body = []
              remainder = ''
              # Consume body lines up to the end marker.  Running out of
              # input (including an unterminated last line) simply ends the
              # body instead of truncating the line or looping forever.
              while position < len(lines):
                  candidate = lines[position]
                  position += 1
                  if candidate.startswith(end):
                      remainder = candidate[len(end):]
                      break
                  body.append(candidate)
              output.append('"' + str_escape('\n'.join(body)) + '"' + remainder + '\n')
          return ''.join(output)

      def read_api_def_map(folder):
          """Parse every file in ``folder`` and map graph op name to its ApiDef.

          Files are processed in sorted name order, so when two files define
          the same op the later one wins.
          """
          api_def_map = {}
          for filename in sorted(os.listdir(folder)):
              api_defs = api_def_pb2.ApiDefs()
              # Decode explicitly so the result does not depend on the locale.
              with io.open(os.path.join(folder, filename), 'r', encoding='utf-8') as handle:
                  multiline_pbtxt = handle.read()
              text_format.Merge(pbtxt_from_multiline(multiline_pbtxt), api_defs)
              for api_def in api_defs.op:
                  api_def_map[api_def.graph_op_name] = api_def
          return api_def_map

      def convert_type(data_type):
          """Wrap a data type value in the JSON representation for types."""
          return {'type': 'type', 'value': data_type}

      def convert_tensor(tensor):
          """Return the JSON placeholder used for tensor values (content is not exported)."""
          return {'type': 'tensor', 'value': '?'}

      def convert_shape(shape):
          """Return the JSON placeholder used for shape values (content is not exported)."""
          return {'type': 'shape', 'value': '?'}

      def convert_number(number):
          """Replace infinite floats by string markers; pass other numbers through."""
          # NOTE(review): infinities are emitted as 'NaN' / '-NaN'.  That looks
          # like it should be an infinity marker, but the consumer of the JSON
          # may rely on these exact strings - confirm before changing.
          if number == float('inf'):
              return 'NaN'
          if number == float('-inf'):
              return '-NaN'
          return number

      # OpDef attribute type names mapped to the names used in the JSON output.
      attr_type_table = {
          'type': 'type', 'list(type)': 'type[]',
          'bool': 'boolean',
          'int': 'int64', 'list(int)': 'int64[]',
          'float': 'float32', 'list(float)': 'float32[]',
          'string': 'string', 'list(string)': 'string[]',
          'shape': 'shape', 'list(shape)': 'shape[]',
          'tensor': 'tensor',
          'func': 'function', 'list(func)': 'function[]'
      }

      def convert_attr_type(attr_type):
          """Translate an attribute type name; unknown names are printed and kept."""
          if attr_type in attr_type_table:
              return attr_type_table[attr_type]
          print(attr_type)
          return attr_type

      def convert_attr_value(attr_value):
          """Convert an attribute value message to a JSON-serializable value.

          A list value becomes a Python list built from its string, int,
          float and type entries; scalar string, int, float, bool, type,
          tensor and shape values are converted individually.

          Raises:
              Exception: for a list holding only other kinds of entries, or
                  for a scalar of any other kind.
          """
          if attr_value.HasField('list'):
              attr_value_list = attr_value.list
              values = []
              values.extend(s.decode('utf8') for s in attr_value_list.s)
              values.extend(attr_value_list.i)
              values.extend(convert_number(f) for f in attr_value_list.f)
              values.extend(convert_type(t) for t in attr_value_list.type)
              if not values:
                  # Nothing converted: make sure no unsupported entries were dropped.
                  for descriptor, value in attr_value_list.ListFields():
                      if len(value) > 0:
                          raise Exception("Unsupported list attribute field '" + descriptor.name + "'.")
              return values
          if attr_value.HasField('s'):
              return attr_value.s.decode('utf8')
          if attr_value.HasField('i'):
              return attr_value.i
          if attr_value.HasField('f'):
              return convert_number(attr_value.f)
          if attr_value.HasField('b'):
              return attr_value.b
          if attr_value.HasField('type'):
              return convert_type(attr_value.type)
          if attr_value.HasField('tensor'):
              return convert_tensor(attr_value.tensor)
          if attr_value.HasField('shape'):
              return convert_shape(attr_value.shape)
          raise Exception('Unsupported attribute value.')

      _TYPE_TO_STRING = {
          types_pb2.DataType.DT_HALF: "float16",
          types_pb2.DataType.DT_FLOAT: "float32",
          types_pb2.DataType.DT_DOUBLE: "float64",
          types_pb2.DataType.DT_INT32: "int32",
          types_pb2.DataType.DT_UINT8: "uint8",
          types_pb2.DataType.DT_UINT16: "uint16",
          types_pb2.DataType.DT_UINT32: "uint32",
          types_pb2.DataType.DT_UINT64: "uint64",
          types_pb2.DataType.DT_INT16: "int16",
          types_pb2.DataType.DT_INT8: "int8",
          types_pb2.DataType.DT_STRING: "string",
          types_pb2.DataType.DT_COMPLEX64: "complex64",
          types_pb2.DataType.DT_COMPLEX128: "complex128",
          types_pb2.DataType.DT_INT64: "int64",
          types_pb2.DataType.DT_BOOL: "bool",
          types_pb2.DataType.DT_QINT8: "qint8",
          types_pb2.DataType.DT_QUINT8: "quint8",
          types_pb2.DataType.DT_QINT16: "qint16",
          types_pb2.DataType.DT_QUINT16: "quint16",
          types_pb2.DataType.DT_QINT32: "qint32",
          types_pb2.DataType.DT_BFLOAT16: "bfloat16",
          types_pb2.DataType.DT_RESOURCE: "resource",
          types_pb2.DataType.DT_VARIANT: "variant",
          types_pb2.DataType.DT_HALF_REF: "float16_ref",
          types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
          types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
          types_pb2.DataType.DT_INT32_REF: "int32_ref",
          types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
          types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
          types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
          types_pb2.DataType.DT_INT16_REF: "int16_ref",
          types_pb2.DataType.DT_INT8_REF: "int8_ref",
          types_pb2.DataType.DT_STRING_REF: "string_ref",
          types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
          types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
          types_pb2.DataType.DT_INT64_REF: "int64_ref",
          types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
          types_pb2.DataType.DT_BOOL_REF: "bool_ref",
          types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
          types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
          types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
          types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
          types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
          types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
          types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
          types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
      }

      def format_data_type(data_type):
          """Return the display name of a data type value; raise if unknown."""
          if data_type in _TYPE_TO_STRING:
              return _TYPE_TO_STRING[data_type]
          raise Exception("Unsupported data type '" + str(data_type) + "'.")

      # type('') is the text type of string literals in this file, which is
      # `unicode` on Python 2 because of the unicode_literals import.
      text_types = (str, type(''))

      def format_attribute_value(value):
          """Format a converted attribute value for use in a description."""
          if isinstance(value, dict) and 'value' in value and value.get('type') == 'type':
              return format_data_type(value['value'])
          if isinstance(value, text_types):
              return value
          # Test bool before numbers: bool is an int subclass, and plain
          # numbers such as 1 must not be rendered as 'true'.
          if isinstance(value, bool):
              return 'true' if value else 'false'
          if isinstance(value, (int, float)):
              return str(value)
          raise Exception('Unsupported attribute value for formatting.')

      def describe_arg(arg, api_def_arg_map):
          """Return the JSON fields shared by input and output arguments."""
          json_arg = {'name': arg.name}
          if arg.name in api_def_arg_map:
              api_def_arg = api_def_arg_map[arg.name]
              if api_def_arg.description:
                  json_arg['description'] = api_def_arg.description
          if arg.number_attr:
              json_arg['numberAttr'] = arg.number_attr
          if arg.is_ref:
              json_arg['isRef'] = True
          return json_arg

      tensorflow_repo_dir = os.path.join(os.path.dirname(__file__), '../third_party/src/tensorflow')
      api_def_map = read_api_def_map(os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
      input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
      ops_list = op_def_pb2.OpList()
      with io.open(input_file, 'r', encoding='utf-8') as input_handle:
          text_format.Merge(input_handle.read(), ops_list)

      json_root = []
      for op in ops_list.op:
          json_schema = {}
          if op.name in categories:
              json_schema['category'] = categories[op.name]
          # Ops without API documentation fall back to an empty ApiDef.
          api_def = api_def_map[op.name] if op.name in api_def_map else api_def_pb2.ApiDef()
          api_def_attr_map = {attr.name: attr for attr in api_def.attr}
          api_def_in_arg_map = {in_arg.name: in_arg for in_arg in api_def.in_arg}
          api_def_out_arg_map = {out_arg.name: out_arg for out_arg in api_def.out_arg}
          if api_def.summary:
              json_schema['summary'] = api_def.summary
          if api_def.description:
              json_schema['description'] = api_def.description

          for attr in op.attr:
              json_attribute = {'name': attr.name}
              attr_type = convert_attr_type(attr.type)
              # An empty type is simply omitted (deleting the never-set key
              # used to raise KeyError here).
              if attr_type:
                  json_attribute['type'] = attr_type
              if attr.name in api_def_attr_map:
                  api_def_attr = api_def_attr_map[attr.name]
                  if api_def_attr.description:
                      json_attribute['description'] = api_def_attr.description
              if attr.has_minimum:
                  json_attribute['minimum'] = attr.minimum
              if attr.HasField('allowed_values'):
                  allowed_values = convert_attr_value(attr.allowed_values)
                  formatted = ', '.join('`' + format_attribute_value(x) + '`' for x in allowed_values)
                  prefix = json_attribute['description'] + ' ' if 'description' in json_attribute else ''
                  json_attribute['description'] = prefix + 'Must be one of the following: ' + formatted + '.'
              if attr.HasField('default_value'):
                  json_attribute['default'] = convert_attr_value(attr.default_value)
              json_schema.setdefault('attributes', []).append(json_attribute)

          for input_arg in op.input_arg:
              json_input = describe_arg(input_arg, api_def_in_arg_map)
              # Inputs record every type field that is set.
              if input_arg.type:
                  json_input['type'] = input_arg.type
              if input_arg.type_attr:
                  json_input['typeAttr'] = input_arg.type_attr
              if input_arg.type_list_attr:
                  json_input['typeListAttr'] = input_arg.type_list_attr
              json_schema.setdefault('inputs', []).append(json_input)

          for output_arg in op.output_arg:
              json_output = describe_arg(output_arg, api_def_out_arg_map)
              # Outputs record only the first type field that is set.
              if output_arg.type:
                  json_output['type'] = output_arg.type
              elif output_arg.type_attr:
                  json_output['typeAttr'] = output_arg.type_attr
              elif output_arg.type_list_attr:
                  json_output['typeListAttr'] = output_arg.type_list_attr
              json_schema.setdefault('outputs', []).append(json_output)

          json_root.append({
              'name': op.name,
              'schema': json_schema
          })

      json_file = os.path.join(os.path.dirname(__file__), '../src/tf-metadata.json')
      json_data = json.dumps(json_root, sort_keys=True, indent=2)
      # newline='' disables newline translation so the file always uses '\n'.
      with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
          for line in json_data.splitlines():
              # Strip trailing whitespace from each serialized line.
              line = line.rstrip()
              if sys.version_info[0] < 3:
                  # io.open() text files only accept unicode on Python 2.
                  line = unicode(line)  # noqa: F821
              fout.write(line)
              fout.write('\n')
  if __name__ == '__main__':
      # Map each command-line command name to the function implementing it.
      command_table = {'metadata': metadata}
      command = sys.argv[1] if len(sys.argv) > 1 else None
      if command not in command_table:
          # A missing or unknown command gets a usage message and a non-zero
          # exit status instead of an IndexError / KeyError traceback.
          sys.stderr.write('Usage: ' + os.path.basename(sys.argv[0]) + ' <' + '|'.join(sorted(command_table)) + '>\n')
          sys.exit(1)
      command_table[command]()