|
|
@@ -8,6 +8,7 @@ import os
|
|
|
|
|
|
from tensorflow.core.framework import api_def_pb2
|
|
|
from tensorflow.core.framework import op_def_pb2
|
|
|
+from tensorflow.core.framework import types_pb2
|
|
|
from google.protobuf import text_format
|
|
|
|
|
|
def metadata():
|
|
|
@@ -190,6 +191,71 @@ def metadata():
|
|
|
return convert_shape(attr_value.shape)
|
|
|
raise Exception()
|
|
|
|
|
|
+ _TYPE_TO_STRING = {
|
|
|
+ types_pb2.DataType.DT_HALF: "float16",
|
|
|
+ types_pb2.DataType.DT_FLOAT: "float32",
|
|
|
+ types_pb2.DataType.DT_DOUBLE: "float64",
|
|
|
+ types_pb2.DataType.DT_INT32: "int32",
|
|
|
+ types_pb2.DataType.DT_UINT8: "uint8",
|
|
|
+ types_pb2.DataType.DT_UINT16: "uint16",
|
|
|
+ types_pb2.DataType.DT_UINT32: "uint32",
|
|
|
+ types_pb2.DataType.DT_UINT64: "uint64",
|
|
|
+ types_pb2.DataType.DT_INT16: "int16",
|
|
|
+ types_pb2.DataType.DT_INT8: "int8",
|
|
|
+ types_pb2.DataType.DT_STRING: "string",
|
|
|
+ types_pb2.DataType.DT_COMPLEX64: "complex64",
|
|
|
+ types_pb2.DataType.DT_COMPLEX128: "complex128",
|
|
|
+ types_pb2.DataType.DT_INT64: "int64",
|
|
|
+ types_pb2.DataType.DT_BOOL: "bool",
|
|
|
+ types_pb2.DataType.DT_QINT8: "qint8",
|
|
|
+ types_pb2.DataType.DT_QUINT8: "quint8",
|
|
|
+ types_pb2.DataType.DT_QINT16: "qint16",
|
|
|
+ types_pb2.DataType.DT_QUINT16: "quint16",
|
|
|
+ types_pb2.DataType.DT_QINT32: "qint32",
|
|
|
+ types_pb2.DataType.DT_BFLOAT16: "bfloat16",
|
|
|
+ types_pb2.DataType.DT_RESOURCE: "resource",
|
|
|
+ types_pb2.DataType.DT_VARIANT: "variant",
|
|
|
+ types_pb2.DataType.DT_HALF_REF: "float16_ref",
|
|
|
+ types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
|
|
|
+ types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
|
|
|
+ types_pb2.DataType.DT_INT32_REF: "int32_ref",
|
|
|
+ types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
|
|
|
+ types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
|
|
|
+ types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
|
|
|
+ types_pb2.DataType.DT_INT16_REF: "int16_ref",
|
|
|
+ types_pb2.DataType.DT_INT8_REF: "int8_ref",
|
|
|
+ types_pb2.DataType.DT_STRING_REF: "string_ref",
|
|
|
+ types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
|
|
|
+ types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
|
|
|
+ types_pb2.DataType.DT_INT64_REF: "int64_ref",
|
|
|
+ types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
|
|
|
+ types_pb2.DataType.DT_BOOL_REF: "bool_ref",
|
|
|
+ types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
|
|
|
+ types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
|
|
|
+ types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
|
|
|
+ types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
|
|
|
+ types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
|
|
|
+ types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
|
|
|
+ types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
|
|
|
+ types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
|
|
|
+ }
|
|
|
+
|
|
|
+ def format_data_type(data_type):
|
|
|
+ if data_type in _TYPE_TO_STRING:
|
|
|
+ return _TYPE_TO_STRING[data_type]
|
|
|
+ raise Exception()
|
|
|
+
|
|
|
+ def format_attribute_value(value):
|
|
|
+ if type(value) is dict and 'type' in value and 'value' in value and value['type'] == 'type':
|
|
|
+ return format_data_type(value['value'])
|
|
|
+ if type(value) is str:
|
|
|
+ return value
|
|
|
+ if value == True:
|
|
|
+ return 'true'
|
|
|
+ if value == False:
|
|
|
+ return 'false'
|
|
|
+ raise Exception()
|
|
|
+
|
|
|
tensorflow_repo_dir = os.path.join(os.path.dirname(__file__), '../third_party/src/tensorflow')
|
|
|
api_def_map = read_api_def_map(os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
|
|
|
input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
|
|
|
@@ -241,9 +307,13 @@ def metadata():
|
|
|
if attr.has_minimum:
|
|
|
json_attribute['minimum'] = attr.minimum
|
|
|
if attr.HasField('allowed_values'):
|
|
|
- json_attribute['allowedValues'] = convert_attr_value(attr.allowed_values)
|
|
|
+ allowed_values = convert_attr_value(attr.allowed_values)
|
|
|
+ description = json_attribute['description'] + ' ' if 'description' in json_attribute else ''
|
|
|
+ description = description + 'Must be one of the following: ' + ', '.join(list(map(lambda x: "`" + format_attribute_value(x) + "`", allowed_values))) + '.'
|
|
|
+ json_attribute['description'] = description
|
|
|
if attr.HasField('default_value'):
|
|
|
- json_attribute['default'] = convert_attr_value(attr.default_value)
|
|
|
+ default_value = convert_attr_value(attr.default_value)
|
|
|
+ json_attribute['default'] = default_value
|
|
|
json_schema['attributes'].append(json_attribute)
|
|
|
for input_arg in op.input_arg:
|
|
|
if not 'inputs' in json_schema:
|