- """ TensorFlow Metadata Script """
- import json
- import logging
- import os
- import re
- import sys
- import google.protobuf
- logging.getLogger("tensorflow").setLevel(logging.ERROR)
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- dup_stderr = os.dup(sys.stderr.fileno())
- null = os.open(os.devnull, os.O_WRONLY)
- os.dup2(null, sys.stderr.fileno())
- os.close(null)
- from tensorflow.core.framework import ( # noqa: E402 # type: ignore
- api_def_pb2,
- op_def_pb2,
- types_pb2,
- )
- os.dup2(dup_stderr, sys.stderr.fileno())
- os.close(dup_stderr)
- def _read(path):
- with open(path, encoding="utf-8") as file:
- return file.read()
- def _write(path, content):
- with open(path, "w", encoding="utf-8") as file:
- file.write(content)
- def _find_multiline(line, colon):
- if colon == -1:
- return None
- line = line[colon+1:]
- while line.startswith(" "):
- line = line[1:]
- if line.startswith("<<"):
- line = line[2:]
- return line
- return None
- def _str_escape(text):
- result = ""
- for value in text:
- if value == "\n":
- result += "\\n"
- elif value == "\r":
- result += "\\r"
- elif value == "\t":
- result += "\\t"
- elif value == '"':
- result += '\\"'
- elif value == "'":
- result += "\\'"
- elif value == "\\":
- result += "\\\\"
- else:
- result += value
- return result
def _pbtxt_from_multiline(multiline_pbtxt):
    # Convert TensorFlow's "multiline" pbtxt dialect into standard pbtxt:
    # a field written heredoc-style as `key: <<END ... END` is rewritten as
    # `key: "escaped body"` so google.protobuf.text_format can parse it.
    pbtxt = ""
    while len(multiline_pbtxt) > 0:
        index = multiline_pbtxt.find("\n")
        if index == -1:
            # Final line without a trailing newline: copy verbatim and stop.
            pbtxt = pbtxt + multiline_pbtxt
            multiline_pbtxt = ""
            break
        line = multiline_pbtxt[0:index]
        multiline_pbtxt = multiline_pbtxt[index+1:]
        colon = line.find(":")
        # `end` is the heredoc terminator token (the text after "<<"), or
        # None when this line does not open a multiline value.
        end = _find_multiline(line, colon)
        if end is None:
            pbtxt = pbtxt + line + "\n"
            continue
        # Keep everything up to and including the colon; the heredoc body
        # is accumulated raw until a line starting with the terminator.
        pbtxt = pbtxt + line[0:colon+1]
        unescaped = ""
        newline = False
        line = ""
        while len(multiline_pbtxt) > 0:
            # NOTE(review): if a block is unterminated and the remaining text
            # has no newline, find() returns -1, the slice drops the last
            # character, and the remainder is never consumed — presumably
            # the api_def inputs are always well-formed; confirm.
            index = multiline_pbtxt.find("\n")
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index+1:]
            if line.startswith(end):
                # Terminator reached; any text after it on the same line is
                # preserved and re-emitted after the closing quote.
                line = line[len(end):]
                break
            # Join body lines with "\n" (inserted before every line but the
            # first) so the collected body has no trailing newline.
            if newline:
                unescaped = unescaped + "\n"
            newline = True
            unescaped = unescaped + line
            line = ""
        # Emit the collected body as a quoted, escaped pbtxt string literal.
        pbtxt = pbtxt + '"' + _str_escape(unescaped) + '"' + line + "\n"
    return pbtxt
def _read_op_list(file):
    """Parse an ops.pbtxt file into an OpList protobuf message."""
    op_list = op_def_pb2.OpList()
    # Drop a leading "go/<word>" marker line some ops.pbtxt files carry.
    text = re.sub(r"^go/[a-z]+\s*", "", _read(file))
    google.protobuf.text_format.Merge(text, op_list)
    return op_list
def _read_api_def_map(folder):
    """Parse every .pbtxt file in *folder* and index ApiDefs by graph op name."""
    api_def_map = {}
    for entry in sorted(os.listdir(folder)):
        if not entry.endswith(".pbtxt"):
            continue
        api_defs = api_def_pb2.ApiDefs()
        path = folder + "/" + entry
        with open(path, encoding="utf-8") as stream:
            multiline_pbtxt = stream.read()
        # Normalize the heredoc dialect before handing it to the proto parser.
        google.protobuf.text_format.Merge(
            _pbtxt_from_multiline(multiline_pbtxt), api_defs)
        for api_def in api_defs.op:
            api_def_map[api_def.graph_op_name] = api_def
    return api_def_map
- def _convert_type(value):
- return { "type": "type", "value": value }
- def _convert_tensor(tensor):
- return { "type": "tensor", "value": "?" }
- def _convert_shape(shape):
- return { "type": "shape", "value": "?" }
- def _convert_number(number):
- if number == float("inf"):
- return "NaN"
- if number == float("-inf"):
- return "-NaN"
- return number
- attr_type_table = {
- "type": "type", "list(type)": "type[]",
- "bool": "boolean",
- "int": "int64", "list(int)": "int64[]",
- "float": "float32", "list(float)": "float32[]",
- "string": "string", "list(string)": "string[]",
- "shape": "shape", "list(shape)": "shape[]",
- "tensor": "tensor",
- "func": "function", "list(func)": "function[]"
- }
- def _convert_attr_type(attr_type):
- if attr_type not in attr_type_table:
- raise ValueError(f"Unknown attribute type '{attr_type}'")
- return attr_type_table[attr_type]
- def _convert_attr_list(attr_value):
- result = []
- attr_value_list = attr_value.list
- if len(attr_value_list.s) > 0:
- for value in attr_value_list.s:
- result.append(value.decode("utf8"))
- if len(attr_value_list.i) > 0:
- for i in attr_value_list.i:
- result.append(i)
- if len(attr_value_list.f) > 0:
- for value in attr_value_list.f:
- result.append(_convert_number(value))
- if len(attr_value_list.type) > 0:
- for value in attr_value_list.type:
- result.append(_convert_type(value))
- if len(result) == 0:
- for _, value in attr_value_list.ListFields():
- if len(value) > 0:
- raise NotImplementedError()
- return result
- def _convert_attr_value(attr_value):
- if attr_value.HasField("list"):
- value = _convert_attr_list(attr_value)
- elif attr_value.HasField("s"):
- value = attr_value.s.decode("utf8")
- elif attr_value.HasField("i"):
- value = attr_value.i
- elif attr_value.HasField("f"):
- value = _convert_number(attr_value.f)
- elif attr_value.HasField("b"):
- value = attr_value.b
- elif attr_value.HasField("type"):
- value = _convert_type(attr_value.type)
- elif attr_value.HasField("tensor"):
- value = _convert_tensor(attr_value.tensor)
- elif attr_value.HasField("shape"):
- value = _convert_shape(attr_value.shape)
- else:
- raise NotImplementedError()
- return value
# Short alias for the protobuf enum used to build the lookup table below.
DataType = types_pb2.DataType
# Maps TensorFlow DataType enum values to the readable type names emitted in
# the generated JSON metadata; "_ref" entries are the reference-type variants.
type_to_string_map = {
    DataType.DT_HALF: "float16",
    DataType.DT_FLOAT: "float32",
    DataType.DT_DOUBLE: "float64",
    DataType.DT_INT32: "int32",
    DataType.DT_UINT8: "uint8",
    DataType.DT_UINT16: "uint16",
    DataType.DT_UINT32: "uint32",
    DataType.DT_UINT64: "uint64",
    DataType.DT_INT16: "int16",
    DataType.DT_INT8: "int8",
    DataType.DT_STRING: "string",
    DataType.DT_COMPLEX64: "complex64",
    DataType.DT_COMPLEX128: "complex128",
    DataType.DT_INT64: "int64",
    DataType.DT_BOOL: "bool",
    DataType.DT_QINT8: "qint8",
    DataType.DT_QUINT8: "quint8",
    DataType.DT_QINT16: "qint16",
    DataType.DT_QUINT16: "quint16",
    DataType.DT_QINT32: "qint32",
    DataType.DT_BFLOAT16: "bfloat16",
    DataType.DT_RESOURCE: "resource",
    DataType.DT_VARIANT: "variant",
    DataType.DT_HALF_REF: "float16_ref",
    DataType.DT_FLOAT_REF: "float32_ref",
    DataType.DT_DOUBLE_REF: "float64_ref",
    DataType.DT_INT32_REF: "int32_ref",
    DataType.DT_UINT32_REF: "uint32_ref",
    DataType.DT_UINT8_REF: "uint8_ref",
    DataType.DT_UINT16_REF: "uint16_ref",
    DataType.DT_INT16_REF: "int16_ref",
    DataType.DT_INT8_REF: "int8_ref",
    DataType.DT_STRING_REF: "string_ref",
    DataType.DT_COMPLEX64_REF: "complex64_ref",
    DataType.DT_COMPLEX128_REF: "complex128_ref",
    DataType.DT_INT64_REF: "int64_ref",
    DataType.DT_UINT64_REF: "uint64_ref",
    DataType.DT_BOOL_REF: "bool_ref",
    DataType.DT_QINT8_REF: "qint8_ref",
    DataType.DT_QUINT8_REF: "quint8_ref",
    DataType.DT_QINT16_REF: "qint16_ref",
    DataType.DT_QUINT16_REF: "quint16_ref",
    DataType.DT_QINT32_REF: "qint32_ref",
    DataType.DT_BFLOAT16_REF: "bfloat16_ref",
    DataType.DT_RESOURCE_REF: "resource_ref",
    DataType.DT_VARIANT_REF: "variant_ref",
}
def _format_data_type(data_type):
    """Return the readable name for a DataType enum value.

    Raises:
        KeyError: if *data_type* has no entry in type_to_string_map.
    """
    if data_type in type_to_string_map:
        return type_to_string_map[data_type]
    # Include the offending value in the error; the original bare KeyError()
    # gave no indication of which data type was missing from the table.
    raise KeyError(f"Unknown data type '{data_type}'")
- def _format_attribute_value(value):
- if isinstance(value, dict) and \
- "type" in value and "value" in value and value["type"] == "type":
- return _format_data_type(value["value"])
- if isinstance(value, str):
- return value
- if value is True:
- return "true"
- if value is False:
- return "false"
- raise NotImplementedError()
def _update_attributes(json_schema, operator, api_def):
    """Populate json_schema["attributes"] from an OpDef's attrs.

    Merges per-attribute descriptions from *api_def*, appends an
    "allowed values" sentence when the attr constrains its values, and
    records minimum and default where present. The "attributes" key is only
    created when the operator has at least one attr.
    """
    api_def_attr_map = {attr.name: attr for attr in api_def.attr}
    for attr in operator.attr:
        json_attribute = {"name": attr.name}
        # _convert_attr_type raises ValueError for unknown types and never
        # returns an empty value, so no falsy check is needed here. (The
        # previous `else: del json_attribute["type"]` branch was dead code
        # and would itself have raised KeyError, as "type" was never set.)
        json_attribute["type"] = _convert_attr_type(attr.type)
        api_def_attr = api_def_attr_map.get(attr.name)
        if api_def_attr is not None and api_def_attr.description:
            json_attribute["description"] = api_def_attr.description
        if attr.has_minimum:
            json_attribute["minimum"] = attr.minimum
        if attr.HasField("allowed_values"):
            allowed_values = _convert_attr_value(attr.allowed_values)
            # Append to an existing description (separated by a space),
            # or start a fresh one.
            description = json_attribute["description"] + " " \
                if "description" in json_attribute else ""
            rendered = ["`" + _format_attribute_value(item) + "`"
                        for item in allowed_values]
            json_attribute["description"] = description + \
                "Must be one of the following: " + ", ".join(rendered) + "."
        if attr.HasField("default_value"):
            json_attribute["default"] = _convert_attr_value(attr.default_value)
        json_schema.setdefault("attributes", []).append(json_attribute)
- def _update_inputs(json_schema, operator, api_def):
- api_def_in_arg_map = {}
- for in_arg in api_def.in_arg:
- api_def_in_arg_map[in_arg.name] = in_arg
- for input_arg in operator.input_arg:
- if "inputs" not in json_schema:
- json_schema["inputs"] = []
- json_input = {}
- json_input["name"] = input_arg.name
- if input_arg.name in api_def_in_arg_map:
- api_def_in_arg = api_def_in_arg_map[input_arg.name]
- if api_def_in_arg.description:
- json_input["description"] = api_def_in_arg.description
- if input_arg.number_attr:
- json_input["numberAttr"] = input_arg.number_attr
- if input_arg.type:
- json_input["type"] = input_arg.type
- if input_arg.type_attr:
- json_input["typeAttr"] = input_arg.type_attr
- if input_arg.type_list_attr:
- json_input["typeListAttr"] = input_arg.type_list_attr
- if input_arg.is_ref:
- json_input["isRef"] = True
- json_schema["inputs"].append(json_input)
- def _update_outputs(json_schema, operator, api_def):
- api_def_out_arg_map = {}
- for out_arg in api_def.out_arg:
- api_def_out_arg_map[out_arg.name] = out_arg
- for output_arg in operator.output_arg:
- if "outputs" not in json_schema:
- json_schema["outputs"] = []
- json_output = {}
- json_output["name"] = output_arg.name
- if output_arg.name in api_def_out_arg_map:
- api_def_out_arg = api_def_out_arg_map[output_arg.name]
- if api_def_out_arg.description:
- json_output["description"] = api_def_out_arg.description
- if output_arg.number_attr:
- json_output["numberAttr"] = output_arg.number_attr
- if output_arg.type:
- json_output["type"] = output_arg.type
- elif output_arg.type_attr:
- json_output["typeAttr"] = output_arg.type_attr
- elif output_arg.type_list_attr:
- json_output["typeListAttr"] = output_arg.type_list_attr
- if output_arg.is_ref:
- json_output["isRef"] = True
- json_schema["outputs"].append(json_output)
- categories = {
- "Assign": "Control",
- "AvgPool": "Pool",
- "BatchNormWithGlobalNormalization": "Normalization",
- "BiasAdd": "Layer",
- "Concat": "Tensor",
- "ConcatV2": "Tensor",
- "Const": "Constant",
- "Conv2D": "Layer",
- "DepthwiseConv2dNative": "Layer",
- "Dequantize": "Quantization",
- "Elu": "Activation",
- "FusedBatchNorm": "Normalization",
- "FusedBatchNormV2": "Normalization",
- "FusedBatchNormV3": "Normalization",
- "Gather": "Transform",
- "Identity": "Control",
- "LeakyRelu": "Activation",
- "LRN": "Normalization",
- "LSTMBlockCell": "Layer",
- "MaxPool": "Pool",
- "MaxPoolV2": "Pool",
- "MaxPoolWithArgmax": "Pool",
- "Pad": "Tensor",
- "QuantizeAndDequantize": "Quantization",
- "QuantizeAndDequantizeV2": "Quantization",
- "QuantizeAndDequantizeV3": "Quantization",
- "QuantizeAndDequantizeV4": "Quantization",
- "QuantizeAndDequantizeV4Grad": "Quantization",
- "QuantizeDownAndShrinkRange": "Quantization",
- "QuantizeV2": "Quantization",
- "Relu": "Activation",
- "Relu6": "Activation",
- "Reshape": "Shape",
- "Sigmoid": "Activation",
- "Slice": "Tensor",
- "Softmax": "Activation",
- "Split": "Tensor",
- "Squeeze": "Transform",
- "StridedSlice": "Tensor",
- "swish_f32": "Activation",
- "Transpose": "Transform",
- "Variable": "Control",
- "VariableV2": "Control",
- }
def _metadata():
    """Build source/tf-metadata.json from TensorFlow's ops.pbtxt and api_defs.

    Reads the vendored TensorFlow sources under third_party/source/tensorflow,
    assembles one JSON schema per operator (name, category, summary,
    description, attributes, inputs, outputs), and writes the result as a
    pretty-printed JSON array.
    """
    root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    tensorflow_dir = os.path.join(root_dir, "third_party", "source", "tensorflow")
    core_dir = os.path.join(tensorflow_dir, "tensorflow", "core")
    api_def_map = _read_api_def_map(os.path.join(core_dir, "api_def", "base_api"))
    ops_list = _read_op_list(os.path.join(core_dir, "ops", "ops.pbtxt"))
    json_root = []
    for operator in ops_list.op:
        json_schema = {"name": operator.name}
        if operator.name in categories:
            json_schema["category"] = categories[operator.name]
        # Fall back to an empty ApiDef when the op has no api_def entry.
        api_def = api_def_map.get(operator.name, api_def_pb2.ApiDef())
        if api_def.summary:
            json_schema["summary"] = api_def.summary
        if api_def.description:
            json_schema["description"] = api_def.description
        _update_attributes(json_schema, operator, api_def)
        _update_inputs(json_schema, operator, api_def)
        _update_outputs(json_schema, operator, api_def)
        json_root.append(json_schema)
    json_file = os.path.join(root_dir, "source", "tf-metadata.json")
    _write(json_file, json.dumps(json_root, sort_keys=False, indent=2))
def main():
    """Script entry point: regenerate the TensorFlow metadata JSON."""
    _metadata()

if __name__ == "__main__":
    main()