  1. """ TensorFlow Metadata Script """
  2. import json
  3. import logging
  4. import os
  5. import re
  6. import sys
  7. import google.protobuf
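
# TensorFlow's C++ runtime writes startup warnings directly to stderr (fd 2),
# which the logging configuration below does not intercept, so stderr is
# temporarily redirected to /dev/null around the TensorFlow proto imports.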
logging.getLogger("tensorflow").setLevel(logging.ERROR)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
dup_stderr = os.dup(sys.stderr.fileno())
null = os.open(os.devnull, os.O_WRONLY)
os.dup2(null, sys.stderr.fileno())
os.close(null)
from tensorflow.core.framework import (  # noqa: E402 # type: ignore
    api_def_pb2,
    op_def_pb2,
    types_pb2,
)
os.dup2(dup_stderr, sys.stderr.fileno())
os.close(dup_stderr)


def _read(path):
    with open(path, encoding="utf-8") as file:
        return file.read()


def _write(path, content):
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
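

# api_def .pbtxt files use a heredoc-style "multiline" syntax for long strings
# (`field: <<END ... END`) that protobuf text_format cannot parse directly.
# _find_multiline detects such a block and returns its end marker (e.g. "END"),
# or None when the line is an ordinary field.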
def _find_multiline(line, colon):
    if colon == -1:
        return None
    line = line[colon+1:]
    while line.startswith(" "):
        line = line[1:]
    if line.startswith("<<"):
        line = line[2:]
        return line
    return None


def _str_escape(text):
    result = ""
    for value in text:
        if value == "\n":
            result += "\\n"
        elif value == "\r":
            result += "\\r"
        elif value == "\t":
            result += "\\t"
        elif value == '"':
            result += '\\"'
        elif value == "'":
            result += "\\'"
        elif value == "\\":
            result += "\\\\"
        else:
            result += value
    return result
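

# Collapse each heredoc block into a single-line escaped string field so that
# the result can be parsed by google.protobuf.text_format.Merge.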
def _pbtxt_from_multiline(multiline_pbtxt):
    pbtxt = ""
    while len(multiline_pbtxt) > 0:
        index = multiline_pbtxt.find("\n")
        if index == -1:
            pbtxt = pbtxt + multiline_pbtxt
            multiline_pbtxt = ""
            break
        line = multiline_pbtxt[0:index]
        multiline_pbtxt = multiline_pbtxt[index+1:]
        colon = line.find(":")
        end = _find_multiline(line, colon)
        if end is None:
            pbtxt = pbtxt + line + "\n"
            continue
        pbtxt = pbtxt + line[0:colon+1]
        unescaped = ""
        newline = False
        line = ""
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find("\n")
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index+1:]
            if line.startswith(end):
                line = line[len(end):]
                break
            if newline:
                unescaped = unescaped + "\n"
            newline = True
            unescaped = unescaped + line
            line = ""
        pbtxt = pbtxt + '"' + _str_escape(unescaped) + '"' + line + "\n"
    return pbtxt


def _read_op_list(file):
    op_list = op_def_pb2.OpList()
    content = _read(file)
    # ops.pbtxt may start with an internal "go/..." tag line; strip it.
    content = re.sub(r"^go/[a-z]+\s*", "", content)
    google.protobuf.text_format.Merge(content, op_list)
    return op_list


def _read_api_def_map(folder):
    api_def_map = {}
    for filename in sorted(os.listdir(folder)):
        if filename.endswith(".pbtxt"):
            api_defs = api_def_pb2.ApiDefs()
            filename = folder + "/" + filename
            with open(filename, encoding="utf-8") as file:
                multiline_pbtxt = file.read()
            pbtxt = _pbtxt_from_multiline(multiline_pbtxt)
            google.protobuf.text_format.Merge(pbtxt, api_defs)
            for api_def in api_defs.op:
                api_def_map[api_def.graph_op_name] = api_def
    return api_def_map
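

# Converters from protobuf attribute values to plain, JSON-serializable values.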
def _convert_type(value):
    return { "type": "type", "value": value }


def _convert_tensor(tensor):  # placeholder: tensor contents are not expanded
    return { "type": "tensor", "value": "?" }


def _convert_shape(shape):  # placeholder: shape contents are not expanded
    return { "type": "shape", "value": "?" }


def _convert_number(number):
    # JSON has no literal for infinity, so encode infinities as strings.
    if number == float("inf"):
        return "Infinity"
    if number == float("-inf"):
        return "-Infinity"
    return number


attr_type_table = {
    "type": "type", "list(type)": "type[]",
    "bool": "boolean",
    "int": "int64", "list(int)": "int64[]",
    "float": "float32", "list(float)": "float32[]",
    "string": "string", "list(string)": "string[]",
    "shape": "shape", "list(shape)": "shape[]",
    "tensor": "tensor",
    "func": "function", "list(func)": "function[]"
}


def _convert_attr_type(attr_type):
    if attr_type not in attr_type_table:
        raise ValueError(f"Unknown attribute type '{attr_type}'")
    return attr_type_table[attr_type]


def _convert_attr_list(attr_value):
    result = []
    attr_value_list = attr_value.list
    if len(attr_value_list.s) > 0:
        for value in attr_value_list.s:
            result.append(value.decode("utf8"))
    if len(attr_value_list.i) > 0:
        for i in attr_value_list.i:
            result.append(i)
    if len(attr_value_list.f) > 0:
        for value in attr_value_list.f:
            result.append(_convert_number(value))
    if len(attr_value_list.type) > 0:
        for value in attr_value_list.type:
            result.append(_convert_type(value))
    if len(result) == 0:
        for _, value in attr_value_list.ListFields():
            if len(value) > 0:
                raise NotImplementedError()
    return result


def _convert_attr_value(attr_value):
    if attr_value.HasField("list"):
        value = _convert_attr_list(attr_value)
    elif attr_value.HasField("s"):
        value = attr_value.s.decode("utf8")
    elif attr_value.HasField("i"):
        value = attr_value.i
    elif attr_value.HasField("f"):
        value = _convert_number(attr_value.f)
    elif attr_value.HasField("b"):
        value = attr_value.b
    elif attr_value.HasField("type"):
        value = _convert_type(attr_value.type)
    elif attr_value.HasField("tensor"):
        value = _convert_tensor(attr_value.tensor)
    elif attr_value.HasField("shape"):
        value = _convert_shape(attr_value.shape)
    else:
        raise NotImplementedError()
    return value
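

# Names used in the generated metadata for each tensorflow DataType enum value.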
DataType = types_pb2.DataType
type_to_string_map = {
    DataType.DT_HALF: "float16",
    DataType.DT_FLOAT: "float32",
    DataType.DT_DOUBLE: "float64",
    DataType.DT_INT32: "int32",
    DataType.DT_UINT8: "uint8",
    DataType.DT_UINT16: "uint16",
    DataType.DT_UINT32: "uint32",
    DataType.DT_UINT64: "uint64",
    DataType.DT_INT16: "int16",
    DataType.DT_INT8: "int8",
    DataType.DT_STRING: "string",
    DataType.DT_COMPLEX64: "complex64",
    DataType.DT_COMPLEX128: "complex128",
    DataType.DT_INT64: "int64",
    DataType.DT_BOOL: "bool",
    DataType.DT_QINT8: "qint8",
    DataType.DT_QUINT8: "quint8",
    DataType.DT_QINT16: "qint16",
    DataType.DT_QUINT16: "quint16",
    DataType.DT_QINT32: "qint32",
    DataType.DT_BFLOAT16: "bfloat16",
    DataType.DT_RESOURCE: "resource",
    DataType.DT_VARIANT: "variant",
    DataType.DT_HALF_REF: "float16_ref",
    DataType.DT_FLOAT_REF: "float32_ref",
    DataType.DT_DOUBLE_REF: "float64_ref",
    DataType.DT_INT32_REF: "int32_ref",
    DataType.DT_UINT32_REF: "uint32_ref",
    DataType.DT_UINT8_REF: "uint8_ref",
    DataType.DT_UINT16_REF: "uint16_ref",
    DataType.DT_INT16_REF: "int16_ref",
    DataType.DT_INT8_REF: "int8_ref",
    DataType.DT_STRING_REF: "string_ref",
    DataType.DT_COMPLEX64_REF: "complex64_ref",
    DataType.DT_COMPLEX128_REF: "complex128_ref",
    DataType.DT_INT64_REF: "int64_ref",
    DataType.DT_UINT64_REF: "uint64_ref",
    DataType.DT_BOOL_REF: "bool_ref",
    DataType.DT_QINT8_REF: "qint8_ref",
    DataType.DT_QUINT8_REF: "quint8_ref",
    DataType.DT_QINT16_REF: "qint16_ref",
    DataType.DT_QUINT16_REF: "quint16_ref",
    DataType.DT_QINT32_REF: "qint32_ref",
    DataType.DT_BFLOAT16_REF: "bfloat16_ref",
    DataType.DT_RESOURCE_REF: "resource_ref",
    DataType.DT_VARIANT_REF: "variant_ref",
}


def _format_data_type(data_type):
    if data_type in type_to_string_map:
        return type_to_string_map[data_type]
    raise KeyError(f"Unknown data type '{data_type}'")


def _format_attribute_value(value):
    if isinstance(value, dict) and \
            "type" in value and "value" in value and value["type"] == "type":
        return _format_data_type(value["value"])
    if isinstance(value, str):
        return value
    if value is True:
        return "true"
    if value is False:
        return "false"
    raise NotImplementedError()
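

# The _update_* helpers below merge each operator's OpDef (from ops.pbtxt)
# with its ApiDef (from api_def/base_api) into one JSON schema entry.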
def _update_attributes(json_schema, operator, api_def):
    api_def_attr_map = {}
    for attr in api_def.attr:
        api_def_attr_map[attr.name] = attr
    for attr in operator.attr:
        if "attributes" not in json_schema:
            json_schema["attributes"] = []
        json_attribute = {}
        json_attribute["name"] = attr.name
        # _convert_attr_type raises on unknown types, so the result is always set.
        json_attribute["type"] = _convert_attr_type(attr.type)
        if attr.name in api_def_attr_map:
            api_def_attr = api_def_attr_map[attr.name]
            if api_def_attr.description:
                json_attribute["description"] = api_def_attr.description
        if attr.has_minimum:
            json_attribute["minimum"] = attr.minimum
        if attr.HasField("allowed_values"):
            allowed_values = _convert_attr_value(attr.allowed_values)
            description = json_attribute["description"] + " " if "description" in json_attribute else ""
            allowed_values = ["`" + _format_attribute_value(value) + "`" for value in allowed_values]
            description += "Must be one of the following: " + ", ".join(allowed_values) + "."
            json_attribute["description"] = description
        if attr.HasField("default_value"):
            default_value = _convert_attr_value(attr.default_value)
            json_attribute["default"] = default_value
        json_schema["attributes"].append(json_attribute)


def _update_inputs(json_schema, operator, api_def):
    api_def_in_arg_map = {}
    for in_arg in api_def.in_arg:
        api_def_in_arg_map[in_arg.name] = in_arg
    for input_arg in operator.input_arg:
        if "inputs" not in json_schema:
            json_schema["inputs"] = []
        json_input = {}
        json_input["name"] = input_arg.name
        if input_arg.name in api_def_in_arg_map:
            api_def_in_arg = api_def_in_arg_map[input_arg.name]
            if api_def_in_arg.description:
                json_input["description"] = api_def_in_arg.description
        if input_arg.number_attr:
            json_input["numberAttr"] = input_arg.number_attr
        if input_arg.type:
            json_input["type"] = input_arg.type
        if input_arg.type_attr:
            json_input["typeAttr"] = input_arg.type_attr
        if input_arg.type_list_attr:
            json_input["typeListAttr"] = input_arg.type_list_attr
        if input_arg.is_ref:
            json_input["isRef"] = True
        json_schema["inputs"].append(json_input)


def _update_outputs(json_schema, operator, api_def):
    api_def_out_arg_map = {}
    for out_arg in api_def.out_arg:
        api_def_out_arg_map[out_arg.name] = out_arg
    for output_arg in operator.output_arg:
        if "outputs" not in json_schema:
            json_schema["outputs"] = []
        json_output = {}
        json_output["name"] = output_arg.name
        if output_arg.name in api_def_out_arg_map:
            api_def_out_arg = api_def_out_arg_map[output_arg.name]
            if api_def_out_arg.description:
                json_output["description"] = api_def_out_arg.description
        if output_arg.number_attr:
            json_output["numberAttr"] = output_arg.number_attr
        if output_arg.type:
            json_output["type"] = output_arg.type
        elif output_arg.type_attr:
            json_output["typeAttr"] = output_arg.type_attr
        elif output_arg.type_list_attr:
            json_output["typeListAttr"] = output_arg.type_list_attr
        if output_arg.is_ref:
            json_output["isRef"] = True
        json_schema["outputs"].append(json_output)
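

# Display categories for well-known operators; operators not listed here are
# emitted without a "category" field.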
categories = {
    "Assign": "Control",
    "AvgPool": "Pool",
    "BatchNormWithGlobalNormalization": "Normalization",
    "BiasAdd": "Layer",
    "Concat": "Tensor",
    "ConcatV2": "Tensor",
    "Const": "Constant",
    "Conv2D": "Layer",
    "DepthwiseConv2dNative": "Layer",
    "Dequantize": "Quantization",
    "Elu": "Activation",
    "FusedBatchNorm": "Normalization",
    "FusedBatchNormV2": "Normalization",
    "FusedBatchNormV3": "Normalization",
    "Gather": "Transform",
    "Identity": "Control",
    "LeakyRelu": "Activation",
    "LRN": "Normalization",
    "LSTMBlockCell": "Layer",
    "MaxPool": "Pool",
    "MaxPoolV2": "Pool",
    "MaxPoolWithArgmax": "Pool",
    "Pad": "Tensor",
    "QuantizeAndDequantize": "Quantization",
    "QuantizeAndDequantizeV2": "Quantization",
    "QuantizeAndDequantizeV3": "Quantization",
    "QuantizeAndDequantizeV4": "Quantization",
    "QuantizeAndDequantizeV4Grad": "Quantization",
    "QuantizeDownAndShrinkRange": "Quantization",
    "QuantizeV2": "Quantization",
    "Relu": "Activation",
    "Relu6": "Activation",
    "Reshape": "Shape",
    "Sigmoid": "Activation",
    "Slice": "Tensor",
    "Softmax": "Activation",
    "Split": "Tensor",
    "Squeeze": "Transform",
    "StridedSlice": "Tensor",
    "swish_f32": "Activation",
    "Transpose": "Transform",
    "Variable": "Control",
    "VariableV2": "Control",
}


def _metadata():
    root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    tensorflow_dir = os.path.join(root_dir, "third_party", "source", "tensorflow")
    core_dir = os.path.join(tensorflow_dir, "tensorflow", "core")
    api_def_map = _read_api_def_map(os.path.join(core_dir, "api_def", "base_api"))
    ops_list = _read_op_list(os.path.join(core_dir, "ops", "ops.pbtxt"))
    json_root = []
    for operator in ops_list.op:
        json_schema = {}
        json_schema["name"] = operator.name
        if operator.name in categories:
            json_schema["category"] = categories[operator.name]
        api_def = api_def_pb2.ApiDef()
        if operator.name in api_def_map:
            api_def = api_def_map[operator.name]
        if api_def.summary:
            json_schema["summary"] = api_def.summary
        if api_def.description:
            json_schema["description"] = api_def.description
        _update_attributes(json_schema, operator, api_def)
        _update_inputs(json_schema, operator, api_def)
        _update_outputs(json_schema, operator, api_def)
        json_root.append(json_schema)
    json_file = os.path.join(root_dir, "source", "tf-metadata.json")
    _write(json_file, json.dumps(json_root, sort_keys=False, indent=2))
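

# Entry point. Paths are resolved relative to this file: the TensorFlow
# checkout is expected under third_party/source/tensorflow, and the generated
# metadata is written to source/tf-metadata.json.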
def main():
    _metadata()


if __name__ == "__main__":
    main()