tf_metadata.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
''' TensorFlow Metadata Script '''
import json
import logging
import os
import google.protobuf # pylint: disable=import-error
from google.protobuf import text_format # pylint: disable=import-error
# NOTE: the logging/env-var configuration must run before the tensorflow
# imports below, so these lines intentionally sit between import groups.
logging.getLogger('tensorflow').setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.core.framework import api_def_pb2 # pylint: disable=import-error,no-name-in-module,wrong-import-position
from tensorflow.core.framework import op_def_pb2 # pylint: disable=import-error,no-name-in-module,wrong-import-position
from tensorflow.core.framework import types_pb2 # pylint: disable=import-error,no-name-in-module,wrong-import-position
  11. def _read(path):
  12. with open(path, 'r', encoding='utf-8') as file:
  13. return file.read()
  14. def _write(path, content):
  15. with open(path, 'w', encoding='utf-8') as file:
  16. file.write(content)
  17. def _find_multiline(line, colon):
  18. if colon == -1:
  19. return None
  20. line = line[colon+1:]
  21. while line.startswith(' '):
  22. line = line[1:]
  23. if line.startswith('<<'):
  24. line = line[2:]
  25. return line
  26. return None
  27. def _str_escape(text):
  28. result = ''
  29. for value in text:
  30. if value == '\n':
  31. result += '\\n'
  32. elif value == '\r':
  33. result += "\\r"
  34. elif value == '\t':
  35. result += "\\t"
  36. elif value == '\"':
  37. result += "\\\""
  38. elif value == '\'':
  39. result += "\\'"
  40. elif value == '\\':
  41. result += "\\\\"
  42. else:
  43. result += value
  44. return result
def _pbtxt_from_multiline(multiline_pbtxt):
    ''' Rewrite multiline pbtxt fields (`field: <<TOKEN ... TOKEN`) into
    standard single-line, escaped, double-quoted string fields so the
    result can be parsed by protobuf's text_format. '''
    pbtxt = ''
    while len(multiline_pbtxt) > 0:
        index = multiline_pbtxt.find('\n')
        if index == -1:
            # Final line with no trailing newline: emit verbatim and stop.
            pbtxt = pbtxt + multiline_pbtxt
            multiline_pbtxt = ''
            break
        line = multiline_pbtxt[0:index]
        multiline_pbtxt = multiline_pbtxt[index+1:]
        colon = line.find(':')
        end = _find_multiline(line, colon)
        if end is None:
            # Ordinary line: pass through unchanged.
            pbtxt = pbtxt + line + '\n'
            continue
        # Multiline opener: keep the text up to and including the colon,
        # then collect body lines until the terminator token is seen.
        pbtxt = pbtxt + line[0:colon+1]
        unescaped = ''
        newline = False  # suppresses the '\n' before the first body line
        line = ''
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find('\n')
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index+1:]
            if line.startswith(end):
                # Terminator reached; any text after the token stays on the line.
                line = line[len(end):]
                break
            if newline:
                unescaped = unescaped + '\n'
            newline = True
            unescaped = unescaped + line
            line = ''
        # Emit the collected body as an escaped double-quoted literal.
        pbtxt = pbtxt + '\"' + _str_escape(unescaped) + '\"' + line + '\n'
    return pbtxt
  78. def _read_op_list(file):
  79. op_list = op_def_pb2.OpList() # pylint: disable=no-member
  80. content = _read(file)
  81. google.protobuf.text_format.Merge(content, op_list)
  82. return op_list
  83. def _read_api_def_map(folder):
  84. api_def_map = {}
  85. for filename in sorted(os.listdir(folder)):
  86. if filename.endswith('.pbtxt'):
  87. api_defs = api_def_pb2.ApiDefs() # pylint: disable=no-member
  88. filename = folder + '/' + filename
  89. with open(filename, 'r', encoding='utf-8') as file:
  90. multiline_pbtxt = file.read()
  91. pbtxt = _pbtxt_from_multiline(multiline_pbtxt)
  92. google.protobuf.text_format.Merge(pbtxt, api_defs)
  93. for api_def in api_defs.op:
  94. api_def_map[api_def.graph_op_name] = api_def
  95. return api_def_map
  96. def _convert_type(value):
  97. return { 'type': 'type', 'value': value }
  98. def _convert_tensor(tensor): # pylint: disable=unused-argument
  99. return { 'type': 'tensor', 'value': '?' }
  100. def _convert_shape(shape): # pylint: disable=unused-argument
  101. return { 'type': 'shape', 'value': '?' }
  102. def _convert_number(number):
  103. if number == float('inf'):
  104. return 'NaN'
  105. if number == float('-inf'):
  106. return '-NaN'
  107. return number
# Maps OpDef attr type strings to the normalized type names emitted in the
# generated JSON metadata (list types become '<name>[]').
attr_type_table = {
    'type': 'type', 'list(type)': 'type[]',
    'bool': 'boolean',
    'int': 'int64', 'list(int)': 'int64[]',
    'float': 'float32', 'list(float)': 'float32[]',
    'string': 'string', 'list(string)': 'string[]',
    'shape': 'shape', 'list(shape)': 'shape[]',
    'tensor': 'tensor',
    'func': 'function', 'list(func)': 'function[]'
}
  118. def _convert_attr_type(attr_type):
  119. if attr_type in attr_type_table:
  120. return attr_type_table[attr_type]
  121. print(attr_type)
  122. return attr_type
  123. def _convert_attr_list(attr_value):
  124. result = []
  125. attr_value_list = attr_value.list
  126. if len(attr_value_list.s) > 0:
  127. for value in attr_value_list.s:
  128. result.append(value.decode('utf8'))
  129. if len(attr_value_list.i) > 0:
  130. for i in attr_value_list.i:
  131. result.append(i)
  132. if len(attr_value_list.f) > 0:
  133. for value in attr_value_list.f:
  134. result.append(_convert_number(value))
  135. if len(attr_value_list.type) > 0:
  136. for value in attr_value_list.type:
  137. result.append(_convert_type(value))
  138. if len(result) == 0:
  139. for _, value in attr_value_list.ListFields():
  140. if len(value) > 0:
  141. raise NotImplementedError()
  142. return result
  143. def _convert_attr_value(attr_value):
  144. if attr_value.HasField('list'):
  145. value = _convert_attr_list(attr_value)
  146. elif attr_value.HasField('s'):
  147. value = attr_value.s.decode('utf8')
  148. elif attr_value.HasField('i'):
  149. value = attr_value.i
  150. elif attr_value.HasField('f'):
  151. value = _convert_number(attr_value.f)
  152. elif attr_value.HasField('b'):
  153. value = attr_value.b
  154. elif attr_value.HasField('type'):
  155. value = _convert_type(attr_value.type)
  156. elif attr_value.HasField('tensor'):
  157. value = _convert_tensor(attr_value.tensor)
  158. elif attr_value.HasField('shape'):
  159. value = _convert_shape(attr_value.shape)
  160. else:
  161. raise NotImplementedError()
  162. return value
# Shorthand for the TensorFlow DataType enum.
DataType = types_pb2.DataType # pylint: disable=no-member
# Maps DataType enum values to the string names emitted in the JSON metadata,
# including the *_ref variants.
type_to_string_map = {
    DataType.DT_HALF: "float16",
    DataType.DT_FLOAT: "float32",
    DataType.DT_DOUBLE: "float64",
    DataType.DT_INT32: "int32",
    DataType.DT_UINT8: "uint8",
    DataType.DT_UINT16: "uint16",
    DataType.DT_UINT32: "uint32",
    DataType.DT_UINT64: "uint64",
    DataType.DT_INT16: "int16",
    DataType.DT_INT8: "int8",
    DataType.DT_STRING: "string",
    DataType.DT_COMPLEX64: "complex64",
    DataType.DT_COMPLEX128: "complex128",
    DataType.DT_INT64: "int64",
    DataType.DT_BOOL: "bool",
    DataType.DT_QINT8: "qint8",
    DataType.DT_QUINT8: "quint8",
    DataType.DT_QINT16: "qint16",
    DataType.DT_QUINT16: "quint16",
    DataType.DT_QINT32: "qint32",
    DataType.DT_BFLOAT16: "bfloat16",
    DataType.DT_RESOURCE: "resource",
    DataType.DT_VARIANT: "variant",
    DataType.DT_HALF_REF: "float16_ref",
    DataType.DT_FLOAT_REF: "float32_ref",
    DataType.DT_DOUBLE_REF: "float64_ref",
    DataType.DT_INT32_REF: "int32_ref",
    DataType.DT_UINT32_REF: "uint32_ref",
    DataType.DT_UINT8_REF: "uint8_ref",
    DataType.DT_UINT16_REF: "uint16_ref",
    DataType.DT_INT16_REF: "int16_ref",
    DataType.DT_INT8_REF: "int8_ref",
    DataType.DT_STRING_REF: "string_ref",
    DataType.DT_COMPLEX64_REF: "complex64_ref",
    DataType.DT_COMPLEX128_REF: "complex128_ref",
    DataType.DT_INT64_REF: "int64_ref",
    DataType.DT_UINT64_REF: "uint64_ref",
    DataType.DT_BOOL_REF: "bool_ref",
    DataType.DT_QINT8_REF: "qint8_ref",
    DataType.DT_QUINT8_REF: "quint8_ref",
    DataType.DT_QINT16_REF: "qint16_ref",
    DataType.DT_QUINT16_REF: "quint16_ref",
    DataType.DT_QINT32_REF: "qint32_ref",
    DataType.DT_BFLOAT16_REF: "bfloat16_ref",
    DataType.DT_RESOURCE_REF: "resource_ref",
    DataType.DT_VARIANT_REF: "variant_ref",
}
  212. def _format_data_type(data_type):
  213. if data_type in type_to_string_map:
  214. return type_to_string_map[data_type]
  215. raise KeyError()
  216. def _format_attribute_value(value):
  217. if isinstance(value, dict) and \
  218. 'type' in value and 'value' in value and value['type'] == 'type':
  219. return _format_data_type(value['value'])
  220. if isinstance(value, str):
  221. return value
  222. if value is True:
  223. return 'true'
  224. if value is False:
  225. return 'false'
  226. raise NotImplementedError()
  227. def _update_attributes(json_schema, operator, api_def):
  228. api_def_attr_map = {}
  229. for attr in api_def.attr:
  230. api_def_attr_map[attr.name] = attr
  231. for attr in operator.attr:
  232. if 'attributes' not in json_schema:
  233. json_schema['attributes'] = []
  234. json_attribute = {}
  235. json_attribute['name'] = attr.name
  236. attr_type = _convert_attr_type(attr.type)
  237. if attr_type:
  238. json_attribute['type'] = attr_type
  239. else:
  240. del json_attribute['type']
  241. if attr.name in api_def_attr_map:
  242. api_def_attr = api_def_attr_map[attr.name]
  243. if api_def_attr.description:
  244. json_attribute['description'] = api_def_attr.description
  245. if attr.has_minimum:
  246. json_attribute['minimum'] = attr.minimum
  247. if attr.HasField('allowed_values'):
  248. allowed_values = _convert_attr_value(attr.allowed_values)
  249. description = json_attribute['description'] + \
  250. ' ' if 'description' in json_attribute else ''
  251. allowed_values = list( \
  252. map(lambda x: "`" + _format_attribute_value(x) + "`", \
  253. allowed_values))
  254. description = description + \
  255. 'Must be one of the following: ' + ', '.join(allowed_values) + '.'
  256. json_attribute['description'] = description
  257. if attr.HasField('default_value'):
  258. default_value = _convert_attr_value(attr.default_value)
  259. json_attribute['default'] = default_value
  260. json_schema['attributes'].append(json_attribute)
  261. def _update_inputs(json_schema, operator, api_def):
  262. api_def_in_arg_map = {}
  263. for in_arg in api_def.in_arg:
  264. api_def_in_arg_map[in_arg.name] = in_arg
  265. for input_arg in operator.input_arg:
  266. if 'inputs' not in json_schema:
  267. json_schema['inputs'] = []
  268. json_input = {}
  269. json_input['name'] = input_arg.name
  270. if input_arg.name in api_def_in_arg_map:
  271. api_def_in_arg = api_def_in_arg_map[input_arg.name]
  272. if api_def_in_arg.description:
  273. json_input['description'] = api_def_in_arg.description
  274. if input_arg.number_attr:
  275. json_input['numberAttr'] = input_arg.number_attr
  276. if input_arg.type:
  277. json_input['type'] = input_arg.type
  278. if input_arg.type_attr:
  279. json_input['typeAttr'] = input_arg.type_attr
  280. if input_arg.type_list_attr:
  281. json_input['typeListAttr'] = input_arg.type_list_attr
  282. if input_arg.is_ref:
  283. json_input['isRef'] = True
  284. json_schema['inputs'].append(json_input)
  285. def _update_outputs(json_schema, operator, api_def):
  286. api_def_out_arg_map = {}
  287. for out_arg in api_def.out_arg:
  288. api_def_out_arg_map[out_arg.name] = out_arg
  289. for output_arg in operator.output_arg:
  290. if 'outputs' not in json_schema:
  291. json_schema['outputs'] = []
  292. json_output = {}
  293. json_output['name'] = output_arg.name
  294. if output_arg.name in api_def_out_arg_map:
  295. api_def_out_arg = api_def_out_arg_map[output_arg.name]
  296. if api_def_out_arg.description:
  297. json_output['description'] = api_def_out_arg.description
  298. if output_arg.number_attr:
  299. json_output['numberAttr'] = output_arg.number_attr
  300. if output_arg.type:
  301. json_output['type'] = output_arg.type
  302. elif output_arg.type_attr:
  303. json_output['typeAttr'] = output_arg.type_attr
  304. elif output_arg.type_list_attr:
  305. json_output['typeListAttr'] = output_arg.type_list_attr
  306. if output_arg.is_ref:
  307. json_output['isRef'] = True
  308. json_schema['outputs'].append(json_output)
# Category labels attached to known operator names in the generated metadata
# (written to json_schema['category'] in _metadata); operators not listed
# here get no category.
categories = {
    'Assign': 'Control',
    'AvgPool': 'Pool',
    'BatchNormWithGlobalNormalization': 'Normalization',
    'BiasAdd': 'Layer',
    'Concat': 'Tensor',
    'ConcatV2': 'Tensor',
    'Const': 'Constant',
    'Conv2D': 'Layer',
    'DepthwiseConv2dNative': 'Layer',
    'Dequantize': 'Tensor',
    'Elu': 'Activation',
    'FusedBatchNorm': 'Normalization',
    'FusedBatchNormV2': 'Normalization',
    'FusedBatchNormV3': 'Normalization',
    'Gather': 'Transform',
    'Identity': 'Control',
    'LeakyRelu': 'Activation',
    'LRN': 'Normalization',
    'LSTMBlockCell': 'Layer',
    'MaxPool': 'Pool',
    'MaxPoolV2': 'Pool',
    'MaxPoolWithArgmax': 'Pool',
    'Pad': 'Tensor',
    'Relu': 'Activation',
    'Relu6': 'Activation',
    'Reshape': 'Shape',
    'Sigmoid': 'Activation',
    'Slice': 'Tensor',
    'Softmax': 'Activation',
    'Split': 'Tensor',
    'Squeeze': 'Transform',
    'StridedSlice': 'Tensor',
    'swish_f32': 'Activation',
    'Variable': 'Control',
    'VariableV2': 'Control',
}
  346. def _metadata():
  347. root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
  348. core_dir = os.path.join(root_dir, 'third_party', 'source', 'tensorflow', 'tensorflow', 'core')
  349. api_def_map = _read_api_def_map(os.path.join(core_dir, 'api_def' , 'base_api'))
  350. ops_list = _read_op_list(os.path.join(core_dir, 'ops', 'ops.pbtxt'))
  351. json_root = []
  352. for operator in ops_list.op:
  353. json_schema = {}
  354. json_schema['name'] = operator.name
  355. if operator.name in categories:
  356. json_schema['category'] = categories[operator.name]
  357. api_def = api_def_pb2.ApiDef() # pylint: disable=no-member
  358. if operator.name in api_def_map:
  359. api_def = api_def_map[operator.name]
  360. if api_def.summary:
  361. json_schema['summary'] = api_def.summary
  362. if api_def.description:
  363. json_schema['description'] = api_def.description
  364. _update_attributes(json_schema, operator, api_def)
  365. _update_inputs(json_schema, operator, api_def)
  366. _update_outputs(json_schema, operator, api_def)
  367. json_root.append(json_schema)
  368. json_file = os.path.join(root_dir, 'source', 'tf-metadata.json')
  369. _write(json_file, json.dumps(json_root, sort_keys=False, indent=2))
def main():
    ''' Entry point: regenerate the TensorFlow metadata JSON file. '''
    _metadata()
if __name__ == '__main__':
    main()