|
|
@@ -16,7 +16,7 @@ class ModelFactory: # pylint: disable=too-few-public-methods
|
|
|
file = os.path.join(path, entry[0])
|
|
|
with open(file, 'r', encoding='utf-8') as handle:
|
|
|
for item in json.load(handle):
|
|
|
- name = entry[1] + item['name']
|
|
|
+ name = entry[1] + item['name'].split('(', 1)[0]
|
|
|
metadata[name] = item
|
|
|
metadata = Metadata(metadata)
|
|
|
return _Model(metadata, model)
|
|
|
@@ -143,12 +143,12 @@ class _Graph: # pylint: disable=too-few-public-methods
|
|
|
lists[node] = 0
|
|
|
|
|
|
def create_node(node):
|
|
|
- schema = node.schema() if hasattr(node, 'schema') else None
|
|
|
- schema = self.metadata.type(schema) if schema and schema != '(no schema)' else None
|
|
|
+ identifier = node.schema()
|
|
|
+ schema, category = self.metadata.type(identifier)
|
|
|
json_node = {
|
|
|
'type': {
|
|
|
'name': node.kind(),
|
|
|
- 'category': schema['category'] if schema and 'category' in schema else ''
|
|
|
+ 'category': category
|
|
|
},
|
|
|
'inputs': [],
|
|
|
'outputs': [],
|
|
|
@@ -171,12 +171,13 @@ class _Graph: # pylint: disable=too-few-public-methods
|
|
|
json_node['attributes'].append(json_attribute)
|
|
|
|
|
|
for i, value in enumerate(node.inputs()):
|
|
|
- parameter = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
|
|
|
- parameter_name = parameter['name'] if parameter and 'name' in parameter else 'input'
|
|
|
- parameter_type = parameter['type'] if parameter and 'type' in parameter else None
|
|
|
+ arg = schema.arguments[i] if schema and i < len(schema.arguments) else None
|
|
|
+ parameter_name = arg.name if arg else 'input'
|
|
|
+ real_type = arg.real_type if arg else None
|
|
|
input_node = value.node()
|
|
|
if input_node in constants:
|
|
|
- if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
|
|
|
+ if (real_type and real_type.kind() == 'TensorType') or \
|
|
|
+ value.type().kind() == 'TensorType':
|
|
|
json_node['inputs'].append({
|
|
|
'name': parameter_name,
|
|
|
'value': [ argument(value) ]
|
|
|
@@ -186,8 +187,8 @@ class _Graph: # pylint: disable=too-few-public-methods
|
|
|
'name': parameter_name,
|
|
|
'value': constant_value(input_node)
|
|
|
}
|
|
|
- if parameter and 'type' in parameter:
|
|
|
- json_attribute['type'] = parameter['type']
|
|
|
+ if real_type:
|
|
|
+ json_attribute['type'] = self._argument_type(real_type)
|
|
|
json_node['attributes'].append(json_attribute)
|
|
|
constants[input_node] = constants[input_node] + 1
|
|
|
continue
|
|
|
@@ -209,8 +210,8 @@ class _Graph: # pylint: disable=too-few-public-methods
|
|
|
})
|
|
|
|
|
|
for i, value in enumerate(node.outputs()):
|
|
|
- parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
|
|
|
- name = parameter['name'] if parameter and 'name' in parameter else 'output'
|
|
|
+ ret = schema.returns[i] if schema and i < len(schema.returns) else None
|
|
|
+ name = ret.name if ret else 'output'
|
|
|
json_node['outputs'].append({
|
|
|
'name': name,
|
|
|
'value': [ argument(value) ]
|
|
|
@@ -235,361 +236,63 @@ class _Graph: # pylint: disable=too-few-public-methods
|
|
|
|
|
|
return json_graph
|
|
|
|
|
|
-class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
-
|
|
|
- def __init__(self, metadata):
|
|
|
- self.types = metadata
|
|
|
- self.cache = set()
|
|
|
- self._primitives = {
|
|
|
- 'int': 'int64', 'float': 'float32', 'bool': 'boolean', 'str': 'string'
|
|
|
- }
|
|
|
-
|
|
|
- def type(self, schema): # pylint: disable=missing-function-docstring
|
|
|
- key = schema.name if isinstance(schema, Schema) else schema.split('(', 1)[0].strip()
|
|
|
- if key not in self.cache:
|
|
|
- self.cache.add(key)
|
|
|
- schema = schema if isinstance(schema, Schema) else Schema(schema)
|
|
|
- arguments = list(filter(lambda _: \
|
|
|
- not(_.kwarg_only and hasattr(_, 'alias')), schema.arguments))
|
|
|
- returns = schema.returns
|
|
|
- value = self.types.setdefault(schema.name, { 'name': schema.name, })
|
|
|
- inputs = value.get('inputs', [])
|
|
|
- outputs = value.get('outputs', [])
|
|
|
- inputs = [ inputs[i] if i < len(inputs) else {} for i in range(len(arguments)) ]
|
|
|
- outputs = [ outputs[i] if i < len(outputs) else {} for i in range(len(returns)) ]
|
|
|
- value['inputs'] = inputs
|
|
|
- value['outputs'] = outputs
|
|
|
- for i, _ in enumerate(arguments):
|
|
|
- argument = inputs[i]
|
|
|
- argument['name'] = _.name
|
|
|
- self._argument(argument, getattr(_, 'type'))
|
|
|
- if hasattr(_, 'default'):
|
|
|
- argument['default'] = _.default
|
|
|
- if hasattr(_, 'kwarg_only') and _.kwarg_only is True:
|
|
|
- argument['kwarg_only'] = True
|
|
|
- for i, _ in enumerate(returns):
|
|
|
- argument = outputs[i]
|
|
|
- if hasattr(_, 'name'):
|
|
|
- argument['name'] = _.name
|
|
|
- self._argument(argument, getattr(_, 'type'))
|
|
|
- return self.types[key]
|
|
|
-
|
|
|
- def _argument_type(self, value):
|
|
|
- if isinstance(value, Schema.OptionalType):
|
|
|
- element_type = self._argument_type(value.element_type)
|
|
|
+ def _argument_type(self, value): # pylint: disable=too-many-branches,too-many-return-statements
|
|
|
+ if value.kind() == 'TensorType':
|
|
|
+ return 'Tensor'
|
|
|
+ if value.kind() == 'OptionalType':
|
|
|
+ element_type = self._argument_type(value.getElementType())
|
|
|
return f'{element_type}?'
|
|
|
- if isinstance(value, Schema.ListType):
|
|
|
- element_type = self._argument_type(value.element_type)
|
|
|
+ if value.kind() == 'ListType':
|
|
|
+ element_type = self._argument_type(value.getElementType())
|
|
|
size = str(value.size) if hasattr(value, 'size') else ''
|
|
|
return f'{element_type}[{size}]'
|
|
|
- if isinstance(value, Schema.DictType):
|
|
|
+ if value.kind() == 'DictType':
|
|
|
key_type = self._argument_type(value.getKeyType())
|
|
|
value_type = self._argument_type(value.getValueType())
|
|
|
return f'Dict({key_type}, {value_type})'
|
|
|
- if isinstance(value, Schema.TupleType):
|
|
|
+ if value.kind() == 'TupleType':
|
|
|
elements = []
|
|
|
for element in value.elements():
|
|
|
elements.append(self._argument_type(element))
|
|
|
return f'({', '.join(elements)})'
|
|
|
- name = value.name
|
|
|
- return self._primitives[name] if name in self._primitives else name
|
|
|
+ if value.kind() == 'IntType':
|
|
|
+ return 'int64'
|
|
|
+ if value.kind() == 'SymIntType':
|
|
|
+ return 'SymInt'
|
|
|
+ if value.kind() == 'FloatType':
|
|
|
+ return 'float32'
|
|
|
+ if value.kind() == 'BoolType':
|
|
|
+ return 'boolean'
|
|
|
+ if value.kind() == 'StringType':
|
|
|
+ return 'string'
|
|
|
+ if value.kind() == 'NumberType':
|
|
|
+ return 'Scalar'
|
|
|
+ if value.kind() == 'ScalarTypeType':
|
|
|
+ return 'ScalarType'
|
|
|
+ if value.kind() == 'LayoutType':
|
|
|
+ return 'Layout'
|
|
|
+ if value.kind() == 'MemoryFormatType':
|
|
|
+ return 'MemoryFormat'
|
|
|
+ if value.kind() == 'DeviceObjType':
|
|
|
+ return 'Device'
|
|
|
+ if value.kind() == 'GeneratorType':
|
|
|
+ return 'Generator'
|
|
|
+ if value.kind() == 'VarType':
|
|
|
+ return value.annotation_str
|
|
|
+ raise NotImplementedError()
|
|
|
|
|
|
- def _argument(self, argument, value):
|
|
|
- argument_type = self._argument_type(value)
|
|
|
- if argument_type:
|
|
|
- argument['type'] = argument_type
|
|
|
- else:
|
|
|
- argument.pop('type', None)
|
|
|
- if 'optional' in argument:
|
|
|
- del argument['optional']
|
|
|
+class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
|
|
|
-class Schema: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
- def __init__(self, value):
|
|
|
- self.value = value
|
|
|
- lexer = Schema.Lexer(value)
|
|
|
- lexer.whitespace(0)
|
|
|
- self._parse_name(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- if lexer.kind == '(':
|
|
|
- self._parse_arguments(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- lexer.expect('->')
|
|
|
- lexer.whitespace(0)
|
|
|
- self._parse_returns(lexer)
|
|
|
- def __str__(self):
|
|
|
- arguments = []
|
|
|
- kwarg_only = False
|
|
|
- for _ in self.arguments:
|
|
|
- if not kwarg_only and _.kwarg_only:
|
|
|
- kwarg_only = True
|
|
|
- arguments.append('*')
|
|
|
- arguments.append(_.__str__())
|
|
|
- if self.is_vararg:
|
|
|
- arguments.append('...')
|
|
|
- returns = ', '.join(map(lambda _: _.__str__(), self.returns))
|
|
|
- returns = returns if len(self.returns) == 1 else '(' + returns + ')'
|
|
|
- return self.name + '(' + ', '.join(arguments) + ') -> ' + returns
|
|
|
- def _parse_name(self, lexer):
|
|
|
- self.name = lexer.expect('id')
|
|
|
- if lexer.eat(':'):
|
|
|
- lexer.expect(':')
|
|
|
- self.name = self.name + '::' + lexer.expect('id')
|
|
|
- if lexer.eat('.'):
|
|
|
- self.name = self.name + '.' + lexer.expect('id')
|
|
|
- def _parse_arguments(self, lexer):
|
|
|
- self.arguments = []
|
|
|
- self.is_vararg = False
|
|
|
- self.kwarg_only = False
|
|
|
- lexer.expect('(')
|
|
|
- if not lexer.eat(')'):
|
|
|
- while True:
|
|
|
- lexer.whitespace(0)
|
|
|
- if self.is_vararg:
|
|
|
- raise NotImplementedError()
|
|
|
- if lexer.eat('*'):
|
|
|
- self.kwarg_only = True
|
|
|
- elif lexer.eat('...'):
|
|
|
- self.is_vararg = True
|
|
|
- else:
|
|
|
- self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
|
|
|
- lexer.whitespace(0)
|
|
|
- if not lexer.eat(','):
|
|
|
- break
|
|
|
- lexer.expect(')')
|
|
|
- def _parse_returns(self, lexer):
|
|
|
- self.returns = []
|
|
|
- self.is_varret = False
|
|
|
- if lexer.eat('...'):
|
|
|
- self.is_varret = True
|
|
|
- elif lexer.eat('('):
|
|
|
- lexer.whitespace(0)
|
|
|
- if not lexer.eat(')'):
|
|
|
- while True:
|
|
|
- lexer.whitespace(0)
|
|
|
- if self.is_varret:
|
|
|
- raise NotImplementedError()
|
|
|
- if lexer.eat('...'):
|
|
|
- self.is_varret = True
|
|
|
- else:
|
|
|
- self.returns.append(Schema.Argument(lexer, True, False))
|
|
|
- lexer.whitespace(0)
|
|
|
- if not lexer.eat(','):
|
|
|
- break
|
|
|
- lexer.expect(')')
|
|
|
- lexer.whitespace(0)
|
|
|
- else:
|
|
|
- self.returns.append(Schema.Argument(lexer, True, False))
|
|
|
- class Argument: # pylint: disable=too-few-public-methods
|
|
|
- def __init__(self, lexer, is_return, kwarg_only):
|
|
|
- value = Schema.Type.parse(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- while True:
|
|
|
- if lexer.eat('['):
|
|
|
- size = None
|
|
|
- if lexer.kind == '#':
|
|
|
- size = int(lexer.value)
|
|
|
- lexer.next()
|
|
|
- lexer.expect(']')
|
|
|
- value = Schema.ListType(value, size)
|
|
|
- elif lexer.eat('?'):
|
|
|
- value = Schema.OptionalType(value)
|
|
|
- elif lexer.kind == '(' and not hasattr(self, 'alias'):
|
|
|
- self.alias = self._parse_alias(lexer)
|
|
|
- else:
|
|
|
- break
|
|
|
- self.type = value
|
|
|
- if is_return:
|
|
|
- lexer.whitespace(0)
|
|
|
- self.kwarg_only = False
|
|
|
- if lexer.kind == 'id':
|
|
|
- self.name = lexer.expect('id')
|
|
|
- else:
|
|
|
- lexer.whitespace(1)
|
|
|
- self.kwarg_only = kwarg_only
|
|
|
- self.name = lexer.expect('id')
|
|
|
- lexer.whitespace(0)
|
|
|
- if lexer.eat('='):
|
|
|
- lexer.whitespace(0)
|
|
|
- self.default = self._parse_value(lexer)
|
|
|
- def __str__(self):
|
|
|
- alias = '(' + self.alias + ')' if hasattr(self, 'alias') else ''
|
|
|
- name = ' ' + self.name if hasattr(self, 'name') else ''
|
|
|
- default = '=' + self.default.__str__() if hasattr(self, 'default') else ''
|
|
|
- return self.type.__str__() + alias + name + default
|
|
|
- def _parse_value(self, lexer):
|
|
|
- if lexer.kind == 'id':
|
|
|
- if lexer.value in ('True', 'False'):
|
|
|
- value = bool(lexer.value == 'True')
|
|
|
- elif lexer.value == 'None':
|
|
|
- value = None
|
|
|
- elif lexer.value in ('Mean', 'contiguous_format', 'long'):
|
|
|
- value = lexer.value
|
|
|
- else:
|
|
|
- raise NotImplementedError()
|
|
|
- elif lexer.kind == '#':
|
|
|
- value = float(lexer.value) if \
|
|
|
- lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
|
|
|
- int(lexer.value)
|
|
|
- elif lexer.kind == 'string':
|
|
|
- value = lexer.value[1:-1]
|
|
|
- elif lexer.eat('['):
|
|
|
- value = []
|
|
|
- if not lexer.eat(']'):
|
|
|
- while True:
|
|
|
- lexer.whitespace(0)
|
|
|
- value.append(self._parse_value(lexer))
|
|
|
- lexer.whitespace(0)
|
|
|
- if not lexer.eat(','):
|
|
|
- break
|
|
|
- lexer.expect(']')
|
|
|
- return value
|
|
|
- else:
|
|
|
- raise NotImplementedError()
|
|
|
- lexer.next()
|
|
|
- return value
|
|
|
- def _parse_alias(self, lexer):
|
|
|
- value = ''
|
|
|
- lexer.expect('(')
|
|
|
- while not lexer.eat(')'):
|
|
|
- value += lexer.value
|
|
|
- lexer.next()
|
|
|
- return value
|
|
|
- class Type: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
- def __init__(self, name):
|
|
|
- self.name = name
|
|
|
- def __str__(self):
|
|
|
- return self.name
|
|
|
- @staticmethod
|
|
|
- def parse(lexer): # pylint: disable=missing-function-docstring
|
|
|
- if lexer.eat('('):
|
|
|
- lexer.whitespace(0)
|
|
|
- elements = []
|
|
|
- while not lexer.eat(')'):
|
|
|
- elements.append(Schema.Type.parse(lexer))
|
|
|
- lexer.whitespace(0)
|
|
|
- lexer.eat(',')
|
|
|
- lexer.whitespace(0)
|
|
|
- return Schema.TupleType(elements)
|
|
|
- name = lexer.expect('id')
|
|
|
- while lexer.eat('.'):
|
|
|
- name = name + '.' + lexer.expect('id')
|
|
|
- if name == 'Dict':
|
|
|
- lexer.expect('(')
|
|
|
- lexer.whitespace(0)
|
|
|
- key_type = Schema.Type.parse(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- lexer.expect(',')
|
|
|
- lexer.whitespace(0)
|
|
|
- value_type = Schema.Type.parse(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- lexer.expect(')')
|
|
|
- return Schema.DictType(key_type, value_type)
|
|
|
- if name == 'Future':
|
|
|
- lexer.expect('(')
|
|
|
- lexer.whitespace(0)
|
|
|
- elem_type = Schema.Type.parse(lexer)
|
|
|
- lexer.whitespace(0)
|
|
|
- lexer.expect(')')
|
|
|
- return Schema.Type(f'Future({elem_type})')
|
|
|
- return Schema.Type(name)
|
|
|
- class OptionalType: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
- def __init__(self, element_type):
|
|
|
- self.element_type = element_type
|
|
|
- def __str__(self):
|
|
|
- return self.element_type.__str__() + '?'
|
|
|
- class ListType: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
- def __init__(self, element_type, size):
|
|
|
- self.element_type = element_type
|
|
|
- if size:
|
|
|
- self.size = size
|
|
|
- def __str__(self):
|
|
|
- size = self.size.__str__() if hasattr(self, 'size') else ''
|
|
|
- return self.element_type.__str__() + '[' + size + ']'
|
|
|
- class DictType:
|
|
|
- def __init__(self, key_type, value_type):
|
|
|
- self._key_type = key_type
|
|
|
- self._value_type = value_type
|
|
|
- def __str__(self):
|
|
|
- return 'Dict(' + str(self._key_type) + ', ' + str(self._value_type) + ')'
|
|
|
- def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
|
|
|
- return self._key_type
|
|
|
- def getValueType(self): # pylint: disable=invalid-name,,missing-function-docstring
|
|
|
- return self._value_type
|
|
|
- class TupleType:
|
|
|
- def __init__(self, elements):
|
|
|
- self._elements = elements
|
|
|
- def elements(self): # pylint: disable=invalid-name,,missing-function-docstring
|
|
|
- return self._elements
|
|
|
- class Lexer: # pylint: disable=too-few-public-methods,missing-class-docstring
|
|
|
- def __init__(self, buffer):
|
|
|
- self.buffer = buffer
|
|
|
- self.position = 0
|
|
|
- self.value = ''
|
|
|
- self.next()
|
|
|
- def eat(self, kind): # pylint: disable=missing-function-docstring
|
|
|
- if self.kind != kind:
|
|
|
- return None
|
|
|
- value = self.value
|
|
|
- self.next()
|
|
|
- return value
|
|
|
- def expect(self, kind): # pylint: disable=missing-function-docstring
|
|
|
- if self.kind != kind:
|
|
|
- raise SyntaxError("Unexpected '" + self.kind + "' instead of '" + kind + "'.")
|
|
|
- value = self.value
|
|
|
- self.next()
|
|
|
- return value
|
|
|
- def whitespace(self, count): # pylint: disable=missing-function-docstring
|
|
|
- if self.kind != ' ':
|
|
|
- if count > len(self.value):
|
|
|
- raise IndexError()
|
|
|
- return False
|
|
|
- self.next()
|
|
|
- return True
|
|
|
- def next(self): # pylint: disable=missing-function-docstring,too-many-branches
|
|
|
- self.position += len(self.value)
|
|
|
- i = self.position
|
|
|
- if i >= len(self.buffer):
|
|
|
- self.kind = '\0'
|
|
|
- self.value = ''
|
|
|
- elif self.buffer[i] == ' ':
|
|
|
- while self.buffer[i] == ' ':
|
|
|
- i += 1
|
|
|
- self.kind = ' '
|
|
|
- self.value = self.buffer[self.position:i]
|
|
|
- elif self.buffer[i] == '.' and self.buffer[i+1] == '.' and self.buffer[i+2] == '.':
|
|
|
- self.kind = '...'
|
|
|
- self.value = '...'
|
|
|
- elif self.buffer[i] in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*', '|'):
|
|
|
- self.kind = self.buffer[i]
|
|
|
- self.value = self.buffer[i]
|
|
|
- elif (self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
|
|
|
- (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or self.buffer[i] == '_':
|
|
|
- i += 1
|
|
|
- while i < len(self.buffer) and \
|
|
|
- ((self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
|
|
|
- (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or \
|
|
|
- (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '_'):
|
|
|
- i += 1
|
|
|
- self.kind = 'id'
|
|
|
- self.value = self.buffer[self.position:i]
|
|
|
- elif self.buffer[i] == '-' and self.buffer[i+1] == '>':
|
|
|
- self.kind = '->'
|
|
|
- self.value = '->'
|
|
|
- elif (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '-':
|
|
|
- i += 1
|
|
|
- while i < len(self.buffer) and \
|
|
|
- ((self.buffer[i] >= '0' and self.buffer[i] <= '9') or \
|
|
|
- self.buffer[i] == '.' or self.buffer[i] == 'e' or self.buffer[i] == '-'):
|
|
|
- i += 1
|
|
|
- self.kind = '#'
|
|
|
- self.value = self.buffer[self.position:i]
|
|
|
- elif self.buffer[i] in ("'", '"'):
|
|
|
- quote = self.buffer[i]
|
|
|
- i += 1
|
|
|
- while i < len(self.buffer) and self.buffer[i] != quote:
|
|
|
- i += 2 if self.buffer[i] == '\\' and self.buffer[i+1] in ("'", '"', '\\') else 1
|
|
|
- i += 1
|
|
|
- self.kind = 'string'
|
|
|
- self.value = self.buffer[self.position:i]
|
|
|
- else:
|
|
|
- raise NotImplementedError("Unsupported token at " + self.position)
|
|
|
+ def __init__(self, metadata):
|
|
|
+ self.types = metadata
|
|
|
+
|
|
|
+ def type(self, identifier): # pylint: disable=missing-function-docstring
|
|
|
+ if identifier == '(no schema)':
|
|
|
+ return (None, '')
|
|
|
+ key = identifier.split('(', 1)[0]
|
|
|
+ value = self.types.get(key)
|
|
|
+ category = value['category'] if value and 'category' in value else ''
|
|
|
+ name, overload_name = key.split('.', 1) if key.find('.') > 0 else (key, '')
|
|
|
+ import torch # pylint: disable=import-outside-toplevel,import-error
|
|
|
+ schema = torch._C._get_schema(name, overload_name) # pylint: disable=protected-access
|
|
|
+ return (schema, category)
|