2
0

tf_script.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. ''' TensorFlow Metadata Script '''
  2. import io
  3. import json
  4. import os
  5. import sys
  6. import google.protobuf.text_format
  7. from tensorflow.core.framework import api_def_pb2
  8. from tensorflow.core.framework import op_def_pb2
  9. from tensorflow.core.framework import types_pb2
def _metadata():
    '''Generate source/tf-metadata.json from the TensorFlow op registry and api_def files.'''
    # Maps a TF graph op name to the UI category it is grouped under.
    # Ops absent from this table simply get no 'category' entry in the
    # emitted JSON schema.
    categories = {
        'Assign': 'Control',
        'AvgPool': 'Pool',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'BiasAdd': 'Layer',
        'Concat': 'Tensor',
        'ConcatV2': 'Tensor',
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Dequantize': 'Tensor',
        'Elu': 'Activation',
        'FusedBatchNorm': 'Normalization',
        'FusedBatchNormV2': 'Normalization',
        'FusedBatchNormV3': 'Normalization',
        'Gather': 'Transform',
        'Identity': 'Control',
        'LeakyRelu': 'Activation',
        'LRN': 'Normalization',
        'LSTMBlockCell': 'Layer',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'MaxPoolWithArgmax': 'Pool',
        'Pad': 'Tensor',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Reshape': 'Shape',
        'Sigmoid': 'Activation',
        'Slice': 'Tensor',
        'Softmax': 'Activation',
        'Split': 'Tensor',
        'Squeeze': 'Transform',
        'StridedSlice': 'Tensor',
        'swish_f32': 'Activation',
        'Variable': 'Control',
        'VariableV2': 'Control',
    }
  48. def find_multiline(line, colon):
  49. if colon == -1:
  50. return None
  51. line = line[colon+1:]
  52. while line.startswith(' '):
  53. line = line[1:]
  54. if line.startswith('<<'):
  55. line = line[2:]
  56. return line
  57. return None
  58. def str_escape(text):
  59. result = ''
  60. for value in text:
  61. if value == '\n':
  62. result += '\\n'
  63. elif value == '\r':
  64. result += "\\r"
  65. elif value == '\t':
  66. result += "\\t"
  67. elif value == '\"':
  68. result += "\\\""
  69. elif value == '\'':
  70. result += "\\'"
  71. elif value == '\\':
  72. result += "\\\\"
  73. else:
  74. result += value
  75. return result
def pbtxt_from_multiline(multiline_pbtxt):
    '''Convert api_def multiline pbtxt syntax into standard pbtxt.

    TensorFlow's api_def files use a 'field: <<END ... END' heredoc syntax.
    This rewrites each such block into a single quoted, escaped string
    ('field: "..."') so google.protobuf.text_format can parse it. Lines
    that do not start a multiline value are copied through unchanged.
    '''
    pbtxt = ''
    while len(multiline_pbtxt) > 0:
        index = multiline_pbtxt.find('\n')
        if index == -1:
            # Last (unterminated) line: copy the remainder verbatim.
            pbtxt = pbtxt + multiline_pbtxt
            multiline_pbtxt = ''
            break
        line = multiline_pbtxt[0:index]
        multiline_pbtxt = multiline_pbtxt[index+1:]
        colon = line.find(':')
        end = find_multiline(line, colon)
        if end is None:
            # Ordinary pbtxt line — pass through.
            pbtxt = pbtxt + line + '\n'
            continue
        # Keep 'field:' and replace the heredoc body with a quoted string.
        pbtxt = pbtxt + line[0:colon+1]
        unescaped = ''
        newline = False
        line = ''
        # Accumulate body lines until one starts with the end marker.
        # NOTE(review): if the body lacks a terminating marker and a final
        # newline, find() returns -1 here and line[0:-1] drops a character —
        # presumably the input files are always well-formed; verify.
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find('\n')
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index+1:]
            if line.startswith(end):
                # Anything after the marker on the same line is kept
                # (appended after the closing quote below).
                line = line[len(end):]
                break
            if newline:
                unescaped = unescaped + '\n'
            newline = True
            unescaped = unescaped + line
            line = ''
        pbtxt = pbtxt + '\"' + str_escape(unescaped) + '\"' + line + '\n'
    return pbtxt
  109. def read_api_def_map(folder):
  110. api_def_map = {}
  111. file_list = os.listdir(folder)
  112. file_list = sorted(file_list)
  113. for filename in file_list:
  114. if filename.endswith('.pbtxt'):
  115. api_defs = api_def_pb2.ApiDefs()
  116. filename = folder + '/' + filename
  117. with open(filename, 'r', encoding='utf-8') as file:
  118. multiline_pbtxt = file.read()
  119. pbtxt = pbtxt_from_multiline(multiline_pbtxt)
  120. google.protobuf.text_format.Merge(pbtxt, api_defs)
  121. for api_def in api_defs.op:
  122. api_def_map[api_def.graph_op_name] = api_def
  123. return api_def_map
  124. def convert_type(value):
  125. return { 'type': 'type', 'value': value }
  126. def convert_tensor(tensor):
  127. return { 'type': 'tensor', 'value': '?' }
  128. def convert_shape(shape):
  129. return { 'type': 'shape', 'value': '?' }
  130. def convert_number(number):
  131. if number == float('inf'):
  132. return 'NaN'
  133. if number == float('-inf'):
  134. return '-NaN'
  135. return number
  136. attr_type_table = {
  137. 'type': 'type', 'list(type)': 'type[]',
  138. 'bool': 'boolean',
  139. 'int': 'int64', 'list(int)': 'int64[]',
  140. 'float': 'float32', 'list(float)': 'float32[]',
  141. 'string': 'string', 'list(string)': 'string[]',
  142. 'shape': 'shape', 'list(shape)': 'shape[]',
  143. 'tensor': 'tensor',
  144. 'func': 'function', 'list(func)': 'function[]'
  145. }
  146. def convert_attr_type(attr_type):
  147. if attr_type in attr_type_table:
  148. return attr_type_table[attr_type]
  149. print(attr_type)
  150. return attr_type
def convert_attr_value(attr_value):
    '''Convert an AttrValue proto into a plain Python value for the JSON schema.

    Handles list payloads (strings, ints, floats, DataTypes) and scalar
    payloads (s, i, f, b, type, tensor, shape). Raises a bare Exception on
    any payload kind it does not recognize.
    '''
    if attr_value.HasField('list'):
        result = []
        attr_value_list = attr_value.list
        # A list payload only ever carries one element kind; each supported
        # kind is appended in the order below.
        if len(attr_value_list.s) > 0:
            for value in attr_value_list.s:
                result.append(value.decode('utf8'))
        if len(attr_value_list.i) > 0:
            for i in attr_value_list.i:
                result.append(i)
        if len(attr_value_list.f) > 0:
            for value in attr_value_list.f:
                result.append(convert_number(value))
        if len(attr_value_list.type) > 0:
            for value in attr_value_list.type:
                result.append(convert_type(value))
        if len(result) == 0:
            # Guard: fail loudly if the list carried a payload kind (e.g.
            # shapes or tensors) that the branches above do not handle.
            for _, value in attr_value_list.ListFields():
                if len(value) > 0:
                    raise Exception()
        return result
    if attr_value.HasField('s'):
        value = attr_value.s.decode('utf8')
    elif attr_value.HasField('i'):
        value = attr_value.i
    elif attr_value.HasField('f'):
        value = convert_number(attr_value.f)
    elif attr_value.HasField('b'):
        value = attr_value.b
    elif attr_value.HasField('type'):
        value = convert_type(attr_value.type)
    elif attr_value.HasField('tensor'):
        value = convert_tensor(attr_value.tensor)
    elif attr_value.HasField('shape'):
        value = convert_shape(attr_value.shape)
    else:
        # Unset or unsupported scalar payload.
        raise Exception()
    return value
  189. type_to_string_map = {
  190. types_pb2.DataType.DT_HALF: "float16",
  191. types_pb2.DataType.DT_FLOAT: "float32",
  192. types_pb2.DataType.DT_DOUBLE: "float64",
  193. types_pb2.DataType.DT_INT32: "int32",
  194. types_pb2.DataType.DT_UINT8: "uint8",
  195. types_pb2.DataType.DT_UINT16: "uint16",
  196. types_pb2.DataType.DT_UINT32: "uint32",
  197. types_pb2.DataType.DT_UINT64: "uint64",
  198. types_pb2.DataType.DT_INT16: "int16",
  199. types_pb2.DataType.DT_INT8: "int8",
  200. types_pb2.DataType.DT_STRING: "string",
  201. types_pb2.DataType.DT_COMPLEX64: "complex64",
  202. types_pb2.DataType.DT_COMPLEX128: "complex128",
  203. types_pb2.DataType.DT_INT64: "int64",
  204. types_pb2.DataType.DT_BOOL: "bool",
  205. types_pb2.DataType.DT_QINT8: "qint8",
  206. types_pb2.DataType.DT_QUINT8: "quint8",
  207. types_pb2.DataType.DT_QINT16: "qint16",
  208. types_pb2.DataType.DT_QUINT16: "quint16",
  209. types_pb2.DataType.DT_QINT32: "qint32",
  210. types_pb2.DataType.DT_BFLOAT16: "bfloat16",
  211. types_pb2.DataType.DT_RESOURCE: "resource",
  212. types_pb2.DataType.DT_VARIANT: "variant",
  213. types_pb2.DataType.DT_HALF_REF: "float16_ref",
  214. types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
  215. types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
  216. types_pb2.DataType.DT_INT32_REF: "int32_ref",
  217. types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
  218. types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
  219. types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
  220. types_pb2.DataType.DT_INT16_REF: "int16_ref",
  221. types_pb2.DataType.DT_INT8_REF: "int8_ref",
  222. types_pb2.DataType.DT_STRING_REF: "string_ref",
  223. types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
  224. types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
  225. types_pb2.DataType.DT_INT64_REF: "int64_ref",
  226. types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
  227. types_pb2.DataType.DT_BOOL_REF: "bool_ref",
  228. types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
  229. types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
  230. types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
  231. types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
  232. types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
  233. types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
  234. types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
  235. types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
  236. }
  237. def format_data_type(data_type):
  238. if data_type in type_to_string_map:
  239. return type_to_string_map[data_type]
  240. raise Exception()
  241. def format_attribute_value(value):
  242. if isinstance(value, dict) and 'type' in value and 'value' in value and value['type'] == 'type':
  243. return format_data_type(value['value'])
  244. if isinstance(value, str):
  245. return value
  246. if value is True:
  247. return 'true'
  248. if value is False:
  249. return 'false'
  250. raise Exception()
# --- Main pipeline: parse the TF op registry + api_defs, emit JSON. ---
# Expects a tensorflow checkout under third_party/source/tensorflow.
repo_dir = os.path.join(os.path.dirname(__file__), '../third_party/source/tensorflow')
api_def_map = read_api_def_map(os.path.join(repo_dir, 'tensorflow/core/api_def/base_api'))
input_file = os.path.join(repo_dir, 'tensorflow/core/ops/ops.pbtxt')
ops_list = op_def_pb2.OpList()
with open(input_file, 'r', encoding='utf-8') as file:
    google.protobuf.text_format.Merge(file.read(), ops_list)
json_root = []
for operator in ops_list.op:
    # One schema object per registered op.
    json_schema = {}
    json_schema['name'] = operator.name
    if operator.name in categories:
        json_schema['category'] = categories[operator.name]
    # Fall back to an empty ApiDef when no api_def file documents this op.
    api_def = api_def_pb2.ApiDef()
    if operator.name in api_def_map:
        api_def = api_def_map[operator.name]
    # Index the api_def documentation by name for O(1) lookup below.
    api_def_attr_map = {}
    for attr in api_def.attr:
        api_def_attr_map[attr.name] = attr
    api_def_in_arg_map = {}
    for in_arg in api_def.in_arg:
        api_def_in_arg_map[in_arg.name] = in_arg
    api_def_out_arg_map = {}
    for out_arg in api_def.out_arg:
        api_def_out_arg_map[out_arg.name] = out_arg
    if api_def.summary:
        json_schema['summary'] = api_def.summary
    if api_def.description:
        json_schema['description'] = api_def.description
    # Attributes: type, documentation, constraints, default.
    for attr in operator.attr:
        if 'attributes' not in json_schema:
            json_schema['attributes'] = []
        json_attribute = {}
        json_attribute['name'] = attr.name
        attr_type = convert_attr_type(attr.type)
        if attr_type:
            json_attribute['type'] = attr_type
        else:
            del json_attribute['type']
        if attr.name in api_def_attr_map:
            api_def_attr = api_def_attr_map[attr.name]
            if api_def_attr.description:
                json_attribute['description'] = api_def_attr.description
        if attr.has_minimum:
            json_attribute['minimum'] = attr.minimum
        if attr.HasField('allowed_values'):
            # Append the allowed-values list to the description text.
            allowed_values = convert_attr_value(attr.allowed_values)
            description = json_attribute['description'] + ' ' if 'description' in json_attribute else ''
            description = description + 'Must be one of the following: ' + ', '.join(list(map(lambda x: "`" + format_attribute_value(x) + "`", allowed_values))) + '.'
            json_attribute['description'] = description
        if attr.HasField('default_value'):
            default_value = convert_attr_value(attr.default_value)
            json_attribute['default'] = default_value
        json_schema['attributes'].append(json_attribute)
    # Inputs: name, docs, and whichever type/number fields the OpDef set.
    for input_arg in operator.input_arg:
        if 'inputs' not in json_schema:
            json_schema['inputs'] = []
        json_input = {}
        json_input['name'] = input_arg.name
        if input_arg.name in api_def_in_arg_map:
            api_def_in_arg = api_def_in_arg_map[input_arg.name]
            if api_def_in_arg.description:
                json_input['description'] = api_def_in_arg.description
        if input_arg.number_attr:
            json_input['numberAttr'] = input_arg.number_attr
        if input_arg.type:
            json_input['type'] = input_arg.type
        if input_arg.type_attr:
            json_input['typeAttr'] = input_arg.type_attr
        if input_arg.type_list_attr:
            json_input['typeListAttr'] = input_arg.type_list_attr
        if input_arg.is_ref:
            json_input['isRef'] = True
        json_schema['inputs'].append(json_input)
    # Outputs: same shape as inputs, but type fields are mutually exclusive
    # (if/elif) rather than independent — NOTE(review): this asymmetry with
    # the input loop looks deliberate; confirm before unifying.
    for output_arg in operator.output_arg:
        if 'outputs' not in json_schema:
            json_schema['outputs'] = []
        json_output = {}
        json_output['name'] = output_arg.name
        if output_arg.name in api_def_out_arg_map:
            api_def_out_arg = api_def_out_arg_map[output_arg.name]
            if api_def_out_arg.description:
                json_output['description'] = api_def_out_arg.description
        if output_arg.number_attr:
            json_output['numberAttr'] = output_arg.number_attr
        if output_arg.type:
            json_output['type'] = output_arg.type
        elif output_arg.type_attr:
            json_output['typeAttr'] = output_arg.type_attr
        elif output_arg.type_list_attr:
            json_output['typeListAttr'] = output_arg.type_list_attr
        if output_arg.is_ref:
            json_output['isRef'] = True
        json_schema['outputs'].append(json_output)
    json_root.append(json_schema)
# Write with trailing whitespace stripped and '\n' newlines on every
# platform (newline='' disables universal-newline translation).
json_file = os.path.join(os.path.dirname(__file__), '../source/tf-metadata.json')
with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
    json_data = json.dumps(json_root, sort_keys=False, indent=2)
    for line in json_data.splitlines():
        line = line.rstrip()
        fout.write(line)
        fout.write('\n')
  357. def main():
  358. command_table = { 'metadata': _metadata }
  359. command = sys.argv[1] if len(sys.argv) > 1 else 'metadata'
  360. command_table[command]()
  361. if __name__ == '__main__':
  362. main()