Jelajahi Sumber

Update backend test (#990)

Lutz Roeder 1 tahun lalu
induk
komit
e97c8a1082
5 berkas diubah dengan 466 penambahan dan 390 penghapusan
  1. 2 2
      source/python.js
  2. 1 1
      source/pytorch-metadata.json
  3. 37 23
      source/pytorch.js
  4. 61 358
      source/pytorch.py
  5. 365 6
      tools/pytorch_script.py

+ 2 - 2
source/python.js

@@ -4364,8 +4364,8 @@ python.Execution = class {
             }
         });
         this.registerFunction('torch._C._get_registry', () => {
-            torch._C._registry = torch._C._registry || new torch._C.OperatorRegistry();
-            return torch._C._registry;
+            this._operators = this._operators || new torch._C.OperatorRegistry();
+            return this._operators;
         });
         this.registerFunction('torch._C._get_schema', (op_name, overload_name) => {
             const registry = torch._C._get_registry();

+ 1 - 1
source/pytorch-metadata.json

@@ -20382,7 +20382,7 @@
     "category": "Data"
   },
   {
-    "name": "torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int64)",
+    "name": "torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int)",
     "inputs": [
       { "name": "tensor", "type": "Tensor" },
       { "name": "sample_rate", "type": "int64" },

+ 37 - 23
source/pytorch.js

@@ -176,6 +176,9 @@ pytorch.Graph = class {
                     node.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
                     continue;
                 }
+                if (node.kind() === 'prim::Constant' && node.outputs().length === 1 && node.outputs()[0].uses().length === 1) {
+                    continue;
+                }
                 this.nodes.push(new pytorch.Node(metadata, null, null, node, initializers, values));
             }
             if (module) {
@@ -398,9 +401,18 @@ pytorch.Node = class {
         let module = null;
         if (pytorch.Utility.isInstance(obj, 'torch.Node')) {
             const node = obj;
-            // const schema = node.schema();
-            this.type = createType(metadata, node.kind());
-            for (const name of node.attributeNames()) {
+            const kind = node.kind();
+            this.type = {
+                identifier: kind,
+                name: kind.indexOf('::') === -1 ? kind : kind.split('::').pop().split('.')[0]
+            };
+            const schema = node.schema();
+            if (schema && schema.category) {
+                this.type.category = schema.category;
+            }
+            const inputs = node.inputs();
+            const outputs = node.outputs();
+            const getAttribute = (node, name) => {
                 const kind = node.kindOf(name);
                 let value = null;
                 let type = null;
@@ -412,12 +424,16 @@ pytorch.Node = class {
                     case 'ival': value = node.ival(name); break;
                     default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
                 }
+                return [type, value];
+            };
+            for (const name of node.attributeNames()) {
+                const [type, value] = getAttribute(node, name);
                 const attribute = new pytorch.Argument(name, value, type);
                 this.attributes.push(attribute);
             }
             let match = true;
             let count = 0;
-            for (const input of node.inputs()) {
+            for (const input of inputs) {
                 const value = input.value;
                 let values = [];
                 if (pytorch.Utility.isObject(value)) {
@@ -458,22 +474,21 @@ pytorch.Node = class {
                     module = null;
                 }
             }
-            const inputs = node.inputs();
             for (let i = 0; i < inputs.length; i++) {
                 const input = inputs[i];
-                const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
-                const name = schema && schema.name ? schema.name : i.toString();
-                let type = schema && schema.type ? schema.type : null;
+                const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
+                const name = arg && arg.name ? arg.name : i.toString();
+                let type = arg ? arg.real_type : null;
                 let array = false;
-                if (type && type.endsWith('[]')) {
+                if (pytorch.Utility.isInstance(type, 'torch.ListType')) {
                     array = true;
-                    type = type.slice(0, -2);
+                    type = type.getElementType();
                 }
                 let argument = null;
-                if (pytorch.Utility.isObjectType(type)) {
+                if (arg && pytorch.Utility.isInstance(arg.real_type, 'torch.ClassType')) {
                     const obj = input.value;
                     if (!array && initializers.has(obj)) {
-                        const node = new pytorch.Node(metadata, name, type, obj, initializers, values);
+                        const node = new pytorch.Node(metadata, name, type.qualified_name(), obj, initializers, values);
                         argument = new pytorch.Argument(name, node, 'object');
                     } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
                         const node = obj.map((obj) => new pytorch.Node(metadata, name, type, obj, initializers, values));
@@ -507,6 +522,9 @@ pytorch.Node = class {
                         }
                     } else if (pytorch.Utility.isInstance(input.type(), 'torch.StringType') && typeof input.value === 'string') {
                         argument = new pytorch.Argument(name, input.value, 'string');
+                    } else if (input.node() && input.uses().length === 1 && input.node().kind() === 'prim::Constant') {
+                        const [type, value] = getAttribute(input.node(), 'value');
+                        argument = new pytorch.Argument(name, value, type || 'attribute');
                     } else {
                         const identifier = input.unique().toString();
                         const value = values.map(identifier);
@@ -543,21 +561,19 @@ pytorch.Node = class {
                         }
                         return value;
                     });
-                    argument = new pytorch.Argument(name, args, schema.type);
+                    argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type));
                 } else {
-                    argument = createAttribute(schema, schema.name, input.value);
+                    throw new pytorch.Error('Unsupported input value');
                 }
                 this.inputs.push(argument);
             }
-            const outputs = node.outputs();
             for (let i = 0; i < outputs.length; i++) {
                 const output = outputs[i];
-                const metadata = this.type && this.type.outputs && i < this.type.outputs.length ? this.type.outputs[i] : null;
-                let name = '';
-                if (metadata && metadata.name) {
-                    name = metadata.name;
+                const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null;
+                if (ret && ret.name) {
+                    name = ret.name;
                 } else {
-                    name = i === 0 ? 'output' : `output${i}`;
+                    name = i === 0 && outputs.length === 1 ? 'output' : `${i}`;
                 }
                 let list = [output];
                 if (output.uses().length === 1 &&
@@ -2573,9 +2589,7 @@ pytorch.jit.Execution = class extends pytorch.Execution {
                         const node = this._graph.create('aten::__getitem__.t');
                         node.addInput(target);
                         if (Number.isInteger(index)) {
-                            const value = this.invoke('torch.Value', [node]);
-                            value.value = index;
-                            index = value;
+                            index = this.constant(index);
                         }
                         node.addInput(index);
                         const value = node.addOutput();

+ 61 - 358
source/pytorch.py

@@ -16,7 +16,7 @@ class ModelFactory: # pylint: disable=too-few-public-methods
             file = os.path.join(path, entry[0])
             with open(file, 'r', encoding='utf-8') as handle:
                 for item in json.load(handle):
-                    name = entry[1] + item['name']
+                    name = entry[1] + item['name'].split('(', 1)[0]
                     metadata[name] = item
         metadata = Metadata(metadata)
         return _Model(metadata, model)
@@ -143,12 +143,12 @@ class _Graph: # pylint: disable=too-few-public-methods
                     lists[node] = 0
 
         def create_node(node):
-            schema = node.schema() if hasattr(node, 'schema') else None
-            schema = self.metadata.type(schema) if schema and schema != '(no schema)' else None
+            identifier = node.schema()
+            schema, category = self.metadata.type(identifier)
             json_node = {
                 'type': {
                     'name': node.kind(),
-                    'category': schema['category'] if schema and 'category' in schema else ''
+                    'category': category
                 },
                 'inputs': [],
                 'outputs': [],
@@ -171,12 +171,13 @@ class _Graph: # pylint: disable=too-few-public-methods
                     json_node['attributes'].append(json_attribute)
 
             for i, value in enumerate(node.inputs()):
-                parameter = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
-                parameter_name = parameter['name'] if parameter and 'name' in parameter else 'input'
-                parameter_type = parameter['type'] if parameter and 'type' in parameter else None
+                arg = schema.arguments[i] if schema and i < len(schema.arguments) else None
+                parameter_name = arg.name if arg else 'input'
+                real_type = arg.real_type if arg else None
                 input_node = value.node()
                 if input_node in constants:
-                    if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
+                    if (real_type and real_type.kind() == 'TensorType') or \
+                        value.type().kind() == 'TensorType':
                         json_node['inputs'].append({
                             'name': parameter_name,
                             'value': [ argument(value) ]
@@ -186,8 +187,8 @@ class _Graph: # pylint: disable=too-few-public-methods
                             'name': parameter_name,
                             'value': constant_value(input_node)
                         }
-                        if parameter and 'type' in parameter:
-                            json_attribute['type'] = parameter['type']
+                        if real_type:
+                            json_attribute['type'] = self._argument_type(real_type)
                         json_node['attributes'].append(json_attribute)
                     constants[input_node] = constants[input_node] + 1
                     continue
@@ -209,8 +210,8 @@ class _Graph: # pylint: disable=too-few-public-methods
                 })
 
             for i, value in enumerate(node.outputs()):
-                parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
-                name = parameter['name'] if parameter and 'name' in parameter else 'output'
+                ret = schema.returns[i] if schema and i < len(schema.returns) else None
+                name = ret.name if ret else 'output'
                 json_node['outputs'].append({
                     'name': name,
                     'value': [ argument(value) ]
@@ -235,361 +236,63 @@ class _Graph: # pylint: disable=too-few-public-methods
 
         return json_graph
 
-class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
-
-    def __init__(self, metadata):
-        self.types = metadata
-        self.cache = set()
-        self._primitives = {
-            'int': 'int64', 'float': 'float32', 'bool': 'boolean', 'str': 'string'
-        }
-
-    def type(self, schema): # pylint: disable=missing-function-docstring
-        key = schema.name if isinstance(schema, Schema) else schema.split('(', 1)[0].strip()
-        if key not in self.cache:
-            self.cache.add(key)
-            schema = schema if isinstance(schema, Schema) else Schema(schema)
-            arguments = list(filter(lambda _: \
-                not(_.kwarg_only and hasattr(_, 'alias')), schema.arguments))
-            returns = schema.returns
-            value = self.types.setdefault(schema.name, { 'name': schema.name, })
-            inputs = value.get('inputs', [])
-            outputs = value.get('outputs', [])
-            inputs = [ inputs[i] if i < len(inputs) else {} for i in range(len(arguments)) ]
-            outputs = [ outputs[i] if i < len(outputs) else {} for i in range(len(returns)) ]
-            value['inputs'] = inputs
-            value['outputs'] = outputs
-            for i, _ in enumerate(arguments):
-                argument = inputs[i]
-                argument['name'] = _.name
-                self._argument(argument, getattr(_, 'type'))
-                if hasattr(_, 'default'):
-                    argument['default'] = _.default
-                if hasattr(_, 'kwarg_only') and _.kwarg_only is True:
-                    argument['kwarg_only'] = True
-            for i, _ in enumerate(returns):
-                argument = outputs[i]
-                if hasattr(_, 'name'):
-                    argument['name'] = _.name
-                self._argument(argument, getattr(_, 'type'))
-        return self.types[key]
-
-    def _argument_type(self, value):
-        if isinstance(value, Schema.OptionalType):
-            element_type = self._argument_type(value.element_type)
+    def _argument_type(self, value): # pylint: disable=too-many-branches,too-many-return-statements
+        if value.kind() == 'TensorType':
+            return 'Tensor'
+        if value.kind() == 'OptionalType':
+            element_type = self._argument_type(value.getElementType())
             return f'{element_type}?'
-        if isinstance(value, Schema.ListType):
-            element_type = self._argument_type(value.element_type)
+        if value.kind() == 'ListType':
+            element_type = self._argument_type(value.getElementType())
             size = str(value.size) if hasattr(value, 'size') else ''
             return f'{element_type}[{size}]'
-        if isinstance(value, Schema.DictType):
+        if value.kind() == 'DictType':
             key_type = self._argument_type(value.getKeyType())
             value_type = self._argument_type(value.getValueType())
             return f'Dict({key_type}, {value_type})'
-        if isinstance(value, Schema.TupleType):
+        if value.kind() == 'TupleType':
             elements = []
             for element in value.elements():
                 elements.append(self._argument_type(element))
             return f'({', '.join(elements)})'
-        name = value.name
-        return self._primitives[name] if name in self._primitives else name
+        if value.kind() == 'IntType':
+            return 'int64'
+        if value.kind() == 'SymIntType':
+            return 'SymInt'
+        if value.kind() == 'FloatType':
+            return 'float32'
+        if value.kind() == 'BoolType':
+            return 'boolean'
+        if value.kind() == 'StringType':
+            return 'string'
+        if value.kind() == 'NumberType':
+            return 'Scalar'
+        if value.kind() == 'ScalarTypeType':
+            return 'ScalarType'
+        if value.kind() == 'LayoutType':
+            return 'Layout'
+        if value.kind() == 'MemoryFormatType':
+            return 'MemoryFormat'
+        if value.kind() == 'DeviceObjType':
+            return 'Device'
+        if value.kind() == 'GeneratorType':
+            return 'Generator'
+        if value.kind() == 'VarType':
+            return value.annotation_str
+        raise NotImplementedError()
 
-    def _argument(self, argument, value):
-        argument_type = self._argument_type(value)
-        if argument_type:
-            argument['type'] = argument_type
-        else:
-            argument.pop('type', None)
-        if 'optional' in argument:
-            del argument['optional']
+class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
 
-class Schema: # pylint: disable=too-few-public-methods,missing-class-docstring
-    def __init__(self, value):
-        self.value = value
-        lexer = Schema.Lexer(value)
-        lexer.whitespace(0)
-        self._parse_name(lexer)
-        lexer.whitespace(0)
-        if lexer.kind == '(':
-            self._parse_arguments(lexer)
-            lexer.whitespace(0)
-            lexer.expect('->')
-            lexer.whitespace(0)
-            self._parse_returns(lexer)
-    def __str__(self):
-        arguments = []
-        kwarg_only = False
-        for _ in self.arguments:
-            if not kwarg_only and _.kwarg_only:
-                kwarg_only = True
-                arguments.append('*')
-            arguments.append(_.__str__())
-        if self.is_vararg:
-            arguments.append('...')
-        returns = ', '.join(map(lambda _: _.__str__(), self.returns))
-        returns = returns if len(self.returns) == 1 else '(' + returns + ')'
-        return self.name + '(' + ', '.join(arguments) + ') -> ' + returns
-    def _parse_name(self, lexer):
-        self.name = lexer.expect('id')
-        if lexer.eat(':'):
-            lexer.expect(':')
-            self.name = self.name + '::' + lexer.expect('id')
-        if lexer.eat('.'):
-            self.name = self.name + '.' + lexer.expect('id')
-    def _parse_arguments(self, lexer):
-        self.arguments = []
-        self.is_vararg = False
-        self.kwarg_only = False
-        lexer.expect('(')
-        if not lexer.eat(')'):
-            while True:
-                lexer.whitespace(0)
-                if self.is_vararg:
-                    raise NotImplementedError()
-                if lexer.eat('*'):
-                    self.kwarg_only = True
-                elif lexer.eat('...'):
-                    self.is_vararg = True
-                else:
-                    self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
-                lexer.whitespace(0)
-                if not lexer.eat(','):
-                    break
-            lexer.expect(')')
-    def _parse_returns(self, lexer):
-        self.returns = []
-        self.is_varret = False
-        if lexer.eat('...'):
-            self.is_varret = True
-        elif lexer.eat('('):
-            lexer.whitespace(0)
-            if not lexer.eat(')'):
-                while True:
-                    lexer.whitespace(0)
-                    if self.is_varret:
-                        raise NotImplementedError()
-                    if lexer.eat('...'):
-                        self.is_varret = True
-                    else:
-                        self.returns.append(Schema.Argument(lexer, True, False))
-                    lexer.whitespace(0)
-                    if not lexer.eat(','):
-                        break
-                lexer.expect(')')
-            lexer.whitespace(0)
-        else:
-            self.returns.append(Schema.Argument(lexer, True, False))
-    class Argument: # pylint: disable=too-few-public-methods
-        def __init__(self, lexer, is_return, kwarg_only):
-            value = Schema.Type.parse(lexer)
-            lexer.whitespace(0)
-            while True:
-                if lexer.eat('['):
-                    size = None
-                    if lexer.kind == '#':
-                        size = int(lexer.value)
-                        lexer.next()
-                    lexer.expect(']')
-                    value = Schema.ListType(value, size)
-                elif lexer.eat('?'):
-                    value = Schema.OptionalType(value)
-                elif lexer.kind == '(' and not hasattr(self, 'alias'):
-                    self.alias = self._parse_alias(lexer)
-                else:
-                    break
-            self.type = value
-            if is_return:
-                lexer.whitespace(0)
-                self.kwarg_only = False
-                if lexer.kind == 'id':
-                    self.name = lexer.expect('id')
-            else:
-                lexer.whitespace(1)
-                self.kwarg_only = kwarg_only
-                self.name = lexer.expect('id')
-                lexer.whitespace(0)
-                if lexer.eat('='):
-                    lexer.whitespace(0)
-                    self.default = self._parse_value(lexer)
-        def __str__(self):
-            alias = '(' + self.alias + ')' if hasattr(self, 'alias') else ''
-            name = ' ' + self.name if hasattr(self, 'name') else ''
-            default = '=' + self.default.__str__() if hasattr(self, 'default') else ''
-            return self.type.__str__() + alias + name + default
-        def _parse_value(self, lexer):
-            if lexer.kind == 'id':
-                if lexer.value in ('True', 'False'):
-                    value = bool(lexer.value == 'True')
-                elif lexer.value == 'None':
-                    value = None
-                elif lexer.value in ('Mean', 'contiguous_format', 'long'):
-                    value = lexer.value
-                else:
-                    raise NotImplementedError()
-            elif lexer.kind == '#':
-                value = float(lexer.value) if \
-                    lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
-                    int(lexer.value)
-            elif lexer.kind == 'string':
-                value = lexer.value[1:-1]
-            elif lexer.eat('['):
-                value = []
-                if not lexer.eat(']'):
-                    while True:
-                        lexer.whitespace(0)
-                        value.append(self._parse_value(lexer))
-                        lexer.whitespace(0)
-                        if not lexer.eat(','):
-                            break
-                    lexer.expect(']')
-                return value
-            else:
-                raise NotImplementedError()
-            lexer.next()
-            return value
-        def _parse_alias(self, lexer):
-            value = ''
-            lexer.expect('(')
-            while not lexer.eat(')'):
-                value += lexer.value
-                lexer.next()
-            return value
-    class Type: # pylint: disable=too-few-public-methods,missing-class-docstring
-        def __init__(self, name):
-            self.name = name
-        def __str__(self):
-            return self.name
-        @staticmethod
-        def parse(lexer): # pylint: disable=missing-function-docstring
-            if lexer.eat('('):
-                lexer.whitespace(0)
-                elements = []
-                while not lexer.eat(')'):
-                    elements.append(Schema.Type.parse(lexer))
-                    lexer.whitespace(0)
-                    lexer.eat(',')
-                    lexer.whitespace(0)
-                return Schema.TupleType(elements)
-            name = lexer.expect('id')
-            while lexer.eat('.'):
-                name = name + '.' + lexer.expect('id')
-            if name == 'Dict':
-                lexer.expect('(')
-                lexer.whitespace(0)
-                key_type = Schema.Type.parse(lexer)
-                lexer.whitespace(0)
-                lexer.expect(',')
-                lexer.whitespace(0)
-                value_type = Schema.Type.parse(lexer)
-                lexer.whitespace(0)
-                lexer.expect(')')
-                return Schema.DictType(key_type, value_type)
-            if name == 'Future':
-                lexer.expect('(')
-                lexer.whitespace(0)
-                elem_type = Schema.Type.parse(lexer)
-                lexer.whitespace(0)
-                lexer.expect(')')
-                return Schema.Type(f'Future({elem_type})')
-            return Schema.Type(name)
-    class OptionalType: # pylint: disable=too-few-public-methods,missing-class-docstring
-        def __init__(self, element_type):
-            self.element_type = element_type
-        def __str__(self):
-            return self.element_type.__str__() + '?'
-    class ListType: # pylint: disable=too-few-public-methods,missing-class-docstring
-        def __init__(self, element_type, size):
-            self.element_type = element_type
-            if size:
-                self.size = size
-        def __str__(self):
-            size = self.size.__str__() if hasattr(self, 'size') else ''
-            return self.element_type.__str__() + '[' + size + ']'
-    class DictType:
-        def __init__(self, key_type, value_type):
-            self._key_type = key_type
-            self._value_type = value_type
-        def __str__(self):
-            return 'Dict(' + str(self._key_type) + ', ' + str(self._value_type) + ')'
-        def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
-            return self._key_type
-        def getValueType(self): # pylint: disable=invalid-name,,missing-function-docstring
-            return self._value_type
-    class TupleType:
-        def __init__(self, elements):
-            self._elements = elements
-        def elements(self): # pylint: disable=invalid-name,,missing-function-docstring
-            return self._elements
-    class Lexer: # pylint: disable=too-few-public-methods,missing-class-docstring
-        def __init__(self, buffer):
-            self.buffer = buffer
-            self.position = 0
-            self.value = ''
-            self.next()
-        def eat(self, kind): # pylint: disable=missing-function-docstring
-            if self.kind != kind:
-                return None
-            value = self.value
-            self.next()
-            return value
-        def expect(self, kind): # pylint: disable=missing-function-docstring
-            if self.kind != kind:
-                raise SyntaxError("Unexpected '" + self.kind + "' instead of '" + kind + "'.")
-            value = self.value
-            self.next()
-            return value
-        def whitespace(self, count): # pylint: disable=missing-function-docstring
-            if self.kind != ' ':
-                if count > len(self.value):
-                    raise IndexError()
-                return False
-            self.next()
-            return True
-        def next(self): # pylint: disable=missing-function-docstring,too-many-branches
-            self.position += len(self.value)
-            i = self.position
-            if i >= len(self.buffer):
-                self.kind = '\0'
-                self.value = ''
-            elif self.buffer[i] == ' ':
-                while self.buffer[i] == ' ':
-                    i += 1
-                self.kind = ' '
-                self.value = self.buffer[self.position:i]
-            elif self.buffer[i] == '.' and self.buffer[i+1] == '.' and self.buffer[i+2] == '.':
-                self.kind = '...'
-                self.value = '...'
-            elif self.buffer[i] in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*', '|'):
-                self.kind = self.buffer[i]
-                self.value = self.buffer[i]
-            elif (self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
-                 (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or self.buffer[i] == '_':
-                i += 1
-                while i < len(self.buffer) and \
-                    ((self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
-                     (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or \
-                     (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '_'):
-                    i += 1
-                self.kind = 'id'
-                self.value = self.buffer[self.position:i]
-            elif self.buffer[i] == '-' and self.buffer[i+1] == '>':
-                self.kind = '->'
-                self.value = '->'
-            elif (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '-':
-                i += 1
-                while i < len(self.buffer) and \
-                    ((self.buffer[i] >= '0' and self.buffer[i] <= '9') or \
-                    self.buffer[i] == '.' or self.buffer[i] == 'e' or self.buffer[i] == '-'):
-                    i += 1
-                self.kind = '#'
-                self.value = self.buffer[self.position:i]
-            elif self.buffer[i] in ("'", '"'):
-                quote = self.buffer[i]
-                i += 1
-                while i < len(self.buffer) and self.buffer[i] != quote:
-                    i += 2 if self.buffer[i] == '\\' and self.buffer[i+1] in ("'", '"', '\\') else 1
-                i += 1
-                self.kind = 'string'
-                self.value = self.buffer[self.position:i]
-            else:
-                raise NotImplementedError("Unsupported token at " + self.position)
+    def __init__(self, metadata):
+        self.types = metadata
+
+    def type(self, identifier): # pylint: disable=missing-function-docstring
+        if identifier == '(no schema)':
+            return (None, '')
+        key = identifier.split('(', 1)[0]
+        value = self.types.get(key)
+        category = value['category'] if value and 'category' in value else ''
+        name, overload_name = key.split('.', 1) if key.find('.') > 0 else (key, '')
+        import torch # pylint: disable=import-outside-toplevel,import-error
+        schema = torch._C._get_schema(name, overload_name) # pylint: disable=protected-access
+        return (schema, category)

+ 365 - 6
tools/pytorch_script.py

@@ -1,4 +1,5 @@
 ''' TorchScript metadata script '''
+# pylint: disable=too-many-lines
 
 import collections
 import json
@@ -9,13 +10,371 @@ import sys
 root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(root_dir)
 sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
-pytorch = __import__('source.pytorch').pytorch
 
 source_dir = os.path.join(root_dir, 'source')
 third_party_dir = os.path.join(root_dir, 'third_party')
 metadata_file = os.path.join(source_dir, 'pytorch-metadata.json')
 pytorch_source_dir = os.path.join(third_party_dir, 'source', 'pytorch')
 
+class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
+    """Builds metadata entries (named, typed inputs and outputs) from operator schemas.
+
+    The mapping passed to the constructor is keyed by schema name and is
+    updated in place by type().
+    """
+
+    def __init__(self, metadata):
+        # Mapping of schema name -> metadata entry; mutated in place by type().
+        self.types = metadata
+        # Schema names already processed, so each entry is rebuilt at most once.
+        self.cache = set()
+        # Schema primitive type names and the names written into the metadata.
+        self._primitives = {
+            'int': 'int64', 'float': 'float32', 'bool': 'boolean', 'str': 'string'
+        }
+
+    def type(self, schema): # pylint: disable=missing-function-docstring
+        """Return the metadata entry for a schema, creating or refreshing it on first use.
+
+        schema is either a Schema instance or a schema definition string. On the
+        first call for a given schema name the entry's 'inputs' and 'outputs'
+        lists are resized to match the schema, and each element's 'name', 'type',
+        'default' and 'kwarg_only' fields are updated. Other keys already present
+        on reused elements are kept, except 'optional', which is removed.
+        """
+        key = schema.name if isinstance(schema, Schema) else schema.split('(', 1)[0].strip()
+        if key not in self.cache:
+            self.cache.add(key)
+            schema = schema if isinstance(schema, Schema) else Schema(schema)
+            # Keyword-only arguments that carry an alias annotation are left out
+            # of the inputs (presumably 'out'-style arguments - TODO confirm).
+            arguments = [
+                _ for _ in schema.arguments if not (_.kwarg_only and hasattr(_, 'alias'))
+            ]
+            returns = schema.returns
+            value = self.types.setdefault(schema.name, { 'name': schema.name, })
+            inputs = value.get('inputs', [])
+            outputs = value.get('outputs', [])
+            # Reuse existing entries by position; pad with empty dicts or truncate.
+            inputs = [ inputs[i] if i < len(inputs) else {} for i in range(len(arguments)) ]
+            outputs = [ outputs[i] if i < len(outputs) else {} for i in range(len(returns)) ]
+            value['inputs'] = inputs
+            value['outputs'] = outputs
+            for i, _ in enumerate(arguments):
+                argument = inputs[i]
+                argument['name'] = _.name
+                self._argument(argument, _.type)
+                if hasattr(_, 'default'):
+                    argument['default'] = _.default
+                if hasattr(_, 'kwarg_only') and _.kwarg_only is True:
+                    argument['kwarg_only'] = True
+            for i, _ in enumerate(returns):
+                argument = outputs[i]
+                # Return values are only named when the schema names them.
+                if hasattr(_, 'name'):
+                    argument['name'] = _.name
+                self._argument(argument, _.type)
+        return self.types[key]
+
+    def _argument_type(self, value):
+        """Render a parsed schema type as the type string stored in the metadata."""
+        if isinstance(value, Schema.OptionalType):
+            element_type = self._argument_type(value.element_type)
+            return f'{element_type}?'
+        if isinstance(value, Schema.ListType):
+            element_type = self._argument_type(value.element_type)
+            # ListType only has a 'size' attribute for fixed-size lists.
+            size = str(value.size) if hasattr(value, 'size') else ''
+            return f'{element_type}[{size}]'
+        if isinstance(value, Schema.DictType):
+            key_type = self._argument_type(value.getKeyType())
+            value_type = self._argument_type(value.getValueType())
+            return f'Dict({key_type}, {value_type})'
+        if isinstance(value, Schema.TupleType):
+            elements = [self._argument_type(element) for element in value.elements()]
+            # Plain concatenation: reusing the enclosing quote character inside an
+            # f-string replacement field is a syntax error before Python 3.12.
+            return '(' + ', '.join(elements) + ')'
+        name = value.name
+        return self._primitives.get(name, name)
+
+    def _argument(self, argument, value):
+        """Set or clear argument['type'] from the schema type and drop any 'optional' key."""
+        argument_type = self._argument_type(value)
+        if argument_type:
+            argument['type'] = argument_type
+        else:
+            argument.pop('type', None)
+        argument.pop('optional', None)
+class Schema: # pylint: disable=too-few-public-methods,missing-class-docstring
+    """Parsed operator schema definition such as 'ns::name.overload(Type arg=1) -> Type'.
+
+    'value' holds the original text and 'name' the qualified name. 'arguments',
+    'is_vararg', 'kwarg_only', 'returns' and 'is_varret' are only set when an
+    argument list follows the name.
+    """
+    def __init__(self, value):
+        self.value = value
+        lexer = Schema.Lexer(value)
+        lexer.whitespace(0)
+        self._parse_name(lexer)
+        lexer.whitespace(0)
+        if lexer.kind == '(':
+            self._parse_arguments(lexer)
+            lexer.whitespace(0)
+            lexer.expect('->')
+            lexer.whitespace(0)
+            self._parse_returns(lexer)
+    def __str__(self):
+        arguments = []
+        kwarg_only = False
+        for _ in self.arguments:
+            # Emit a single '*' separator before the first keyword-only argument.
+            if not kwarg_only and _.kwarg_only:
+                kwarg_only = True
+                arguments.append('*')
+            arguments.append(str(_))
+        if self.is_vararg:
+            arguments.append('...')
+        returns = ', '.join(str(_) for _ in self.returns)
+        # A single return is written bare; zero or several are parenthesized.
+        returns = returns if len(self.returns) == 1 else '(' + returns + ')'
+        return self.name + '(' + ', '.join(arguments) + ') -> ' + returns
+    def _parse_name(self, lexer):
+        """Parse 'id', optionally followed by '::id' and then '.id', into self.name."""
+        self.name = lexer.expect('id')
+        if lexer.eat(':'):
+            lexer.expect(':')
+            self.name = self.name + '::' + lexer.expect('id')
+        if lexer.eat('.'):
+            self.name = self.name + '.' + lexer.expect('id')
+    def _parse_arguments(self, lexer):
+        """Parse the parenthesized argument list into self.arguments."""
+        self.arguments = []
+        self.is_vararg = False
+        self.kwarg_only = False
+        lexer.expect('(')
+        if not lexer.eat(')'):
+            while True:
+                lexer.whitespace(0)
+                # Nothing may follow a '...' entry.
+                if self.is_vararg:
+                    raise NotImplementedError()
+                if lexer.eat('*'):
+                    # Every argument after a bare '*' is keyword-only.
+                    self.kwarg_only = True
+                elif lexer.eat('...'):
+                    self.is_vararg = True
+                else:
+                    self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
+                lexer.whitespace(0)
+                if not lexer.eat(','):
+                    break
+            lexer.expect(')')
+    def _parse_returns(self, lexer):
+        """Parse the return declaration ('...', a parenthesized list, or one type)."""
+        self.returns = []
+        self.is_varret = False
+        if lexer.eat('...'):
+            self.is_varret = True
+        elif lexer.eat('('):
+            lexer.whitespace(0)
+            if not lexer.eat(')'):
+                while True:
+                    lexer.whitespace(0)
+                    # Nothing may follow a '...' entry.
+                    if self.is_varret:
+                        raise NotImplementedError()
+                    if lexer.eat('...'):
+                        self.is_varret = True
+                    else:
+                        self.returns.append(Schema.Argument(lexer, True, False))
+                    lexer.whitespace(0)
+                    if not lexer.eat(','):
+                        break
+                lexer.expect(')')
+            lexer.whitespace(0)
+        else:
+            self.returns.append(Schema.Argument(lexer, True, False))
+    class Argument: # pylint: disable=too-few-public-methods
+        """One argument or return value: a type plus optional alias, name and default.
+
+        'alias', 'name' (for returns) and 'default' are only set as attributes
+        when the schema text provides them.
+        """
+        def __init__(self, lexer, is_return, kwarg_only):
+            value = Schema.Type.parse(lexer)
+            lexer.whitespace(0)
+            # Apply type suffixes in order: '[n]' wraps in a list, '?' wraps in
+            # an optional, and the first '(...)' is recorded as the alias.
+            while True:
+                if lexer.eat('['):
+                    size = None
+                    if lexer.kind == '#':
+                        size = int(lexer.value)
+                        lexer.next()
+                    lexer.expect(']')
+                    value = Schema.ListType(value, size)
+                elif lexer.eat('?'):
+                    value = Schema.OptionalType(value)
+                elif lexer.kind == '(' and not hasattr(self, 'alias'):
+                    self.alias = self._parse_alias(lexer)
+                else:
+                    break
+            self.type = value
+            if is_return:
+                lexer.whitespace(0)
+                self.kwarg_only = False
+                # Return values may be unnamed.
+                if lexer.kind == 'id':
+                    self.name = lexer.expect('id')
+            else:
+                lexer.whitespace(1)
+                self.kwarg_only = kwarg_only
+                self.name = lexer.expect('id')
+                lexer.whitespace(0)
+                if lexer.eat('='):
+                    lexer.whitespace(0)
+                    self.default = self._parse_value(lexer)
+        def __str__(self):
+            alias = '(' + self.alias + ')' if hasattr(self, 'alias') else ''
+            name = ' ' + self.name if hasattr(self, 'name') else ''
+            default = '=' + str(self.default) if hasattr(self, 'default') else ''
+            return str(self.type) + alias + name + default
+        def _parse_value(self, lexer):
+            """Parse a default value: bool, None, known identifier, number, string or list."""
+            if lexer.kind == 'id':
+                if lexer.value in ('True', 'False'):
+                    value = lexer.value == 'True'
+                elif lexer.value == 'None':
+                    value = None
+                elif lexer.value in ('Mean', 'contiguous_format', 'long'):
+                    # Only these bare identifiers are accepted; kept as strings.
+                    value = lexer.value
+                else:
+                    raise NotImplementedError()
+            elif lexer.kind == '#':
+                # A '.' or 'e' in the literal makes it a float, otherwise an int.
+                value = float(lexer.value) if \
+                    lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
+                    int(lexer.value)
+            elif lexer.kind == 'string':
+                # Strip the surrounding quotes; escape sequences are kept as written.
+                value = lexer.value[1:-1]
+            elif lexer.eat('['):
+                value = []
+                if not lexer.eat(']'):
+                    while True:
+                        lexer.whitespace(0)
+                        value.append(self._parse_value(lexer))
+                        lexer.whitespace(0)
+                        if not lexer.eat(','):
+                            break
+                    lexer.expect(']')
+                # The closing ']' is already consumed; skip the shared next() below.
+                return value
+            else:
+                raise NotImplementedError()
+            lexer.next()
+            return value
+        def _parse_alias(self, lexer):
+            """Return the raw text between '(' and ')' of an alias annotation."""
+            value = ''
+            lexer.expect('(')
+            while not lexer.eat(')'):
+                # next() makes no progress at end of input, so an unterminated
+                # alias would otherwise loop forever.
+                if lexer.kind == '\0':
+                    raise SyntaxError("Unexpected end of input instead of ')'.")
+                value += lexer.value
+                lexer.next()
+            return value
+    class Type: # pylint: disable=too-few-public-methods,missing-class-docstring
+        """Named type; parse() may also return a TupleType or DictType."""
+        def __init__(self, name):
+            self.name = name
+        def __str__(self):
+            return self.name
+        @staticmethod
+        def parse(lexer): # pylint: disable=missing-function-docstring
+            """Parse a tuple '(T, ...)', 'Dict(K, V)', 'Future(T)' or a dotted type name."""
+            if lexer.eat('('):
+                lexer.whitespace(0)
+                elements = []
+                while not lexer.eat(')'):
+                    elements.append(Schema.Type.parse(lexer))
+                    lexer.whitespace(0)
+                    lexer.eat(',')
+                    lexer.whitespace(0)
+                return Schema.TupleType(elements)
+            name = lexer.expect('id')
+            while lexer.eat('.'):
+                name = name + '.' + lexer.expect('id')
+            if name == 'Dict':
+                lexer.expect('(')
+                lexer.whitespace(0)
+                key_type = Schema.Type.parse(lexer)
+                lexer.whitespace(0)
+                lexer.expect(',')
+                lexer.whitespace(0)
+                value_type = Schema.Type.parse(lexer)
+                lexer.whitespace(0)
+                lexer.expect(')')
+                return Schema.DictType(key_type, value_type)
+            if name == 'Future':
+                lexer.expect('(')
+                lexer.whitespace(0)
+                elem_type = Schema.Type.parse(lexer)
+                lexer.whitespace(0)
+                lexer.expect(')')
+                # Future types are kept as a plain named type.
+                return Schema.Type(f'Future({elem_type})')
+            return Schema.Type(name)
+    class OptionalType: # pylint: disable=too-few-public-methods,missing-class-docstring
+        """Optional wrapper around an element type, written 'T?'."""
+        def __init__(self, element_type):
+            self.element_type = element_type
+        def __str__(self):
+            return str(self.element_type) + '?'
+    class ListType: # pylint: disable=too-few-public-methods,missing-class-docstring
+        """List of an element type, written 'T[]' or 'T[n]'."""
+        def __init__(self, element_type, size):
+            self.element_type = element_type
+            # 'size' is only set for a truthy size; None or 0 leaves it unset.
+            if size:
+                self.size = size
+        def __str__(self):
+            size = str(self.size) if hasattr(self, 'size') else ''
+            return str(self.element_type) + '[' + size + ']'
+    class DictType:
+        """Dictionary type, written 'Dict(K, V)'."""
+        def __init__(self, key_type, value_type):
+            self._key_type = key_type
+            self._value_type = value_type
+        def __str__(self):
+            return 'Dict(' + str(self._key_type) + ', ' + str(self._value_type) + ')'
+        def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
+            return self._key_type
+        def getValueType(self): # pylint: disable=invalid-name,missing-function-docstring
+            return self._value_type
+    class TupleType:
+        """Tuple type, written '(T1, T2, ...)'."""
+        def __init__(self, elements):
+            self._elements = elements
+        def __str__(self):
+            return '(' + ', '.join(str(_) for _ in self._elements) + ')'
+        def elements(self): # pylint: disable=invalid-name,missing-function-docstring
+            return self._elements
+    class Lexer: # pylint: disable=too-few-public-methods,missing-class-docstring
+        """Tokenizer over a schema string; 'kind' and 'value' describe the current token.
+
+        Token kinds are 'id', '#' (number), 'string', ' ' (a run of spaces),
+        '->', '...', a single punctuation character, or '\\0' at end of input.
+        """
+        def __init__(self, buffer):
+            self.buffer = buffer
+            self.position = 0
+            self.value = ''
+            self.next()
+        def eat(self, kind): # pylint: disable=missing-function-docstring
+            """Consume and return the current token text if it has this kind, else None."""
+            if self.kind != kind:
+                return None
+            value = self.value
+            self.next()
+            return value
+        def expect(self, kind): # pylint: disable=missing-function-docstring
+            """Consume and return the current token text; raise SyntaxError on another kind."""
+            if self.kind != kind:
+                raise SyntaxError("Unexpected '" + self.kind + "' instead of '" + kind + "'.")
+            value = self.value
+            self.next()
+            return value
+        def whitespace(self, count): # pylint: disable=missing-function-docstring
+            """Consume a whitespace token if present and return whether one was consumed.
+
+            When the current token is not whitespace, raises IndexError if count
+            exceeds the length of the current token text.
+            """
+            if self.kind != ' ':
+                if count > len(self.value):
+                    raise IndexError()
+                return False
+            self.next()
+            return True
+        def next(self): # pylint: disable=missing-function-docstring,too-many-branches
+            """Advance past the current token and classify the next one."""
+            self.position += len(self.value)
+            i = self.position
+            size = len(self.buffer)
+            if i >= size:
+                self.kind = '\0'
+                self.value = ''
+                return
+            char = self.buffer[i]
+            if char == ' ':
+                # Bounds-checked so trailing spaces do not run off the buffer.
+                while i < size and self.buffer[i] == ' ':
+                    i += 1
+                self.kind = ' '
+                self.value = self.buffer[self.position:i]
+            elif self.buffer.startswith('...', i):
+                self.kind = '...'
+                self.value = '...'
+            elif char in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*', '|'):
+                self.kind = char
+                self.value = char
+            elif 'a' <= char <= 'z' or 'A' <= char <= 'Z' or char == '_':
+                i += 1
+                while i < size and \
+                    ('a' <= self.buffer[i] <= 'z' or 'A' <= self.buffer[i] <= 'Z' or \
+                     '0' <= self.buffer[i] <= '9' or self.buffer[i] == '_'):
+                    i += 1
+                self.kind = 'id'
+                self.value = self.buffer[self.position:i]
+            elif self.buffer.startswith('->', i):
+                self.kind = '->'
+                self.value = '->'
+            elif '0' <= char <= '9' or char == '-':
+                i += 1
+                while i < size and \
+                    ('0' <= self.buffer[i] <= '9' or self.buffer[i] in ('.', 'e', '-')):
+                    i += 1
+                self.kind = '#'
+                self.value = self.buffer[self.position:i]
+            elif char in ("'", '"'):
+                quote = char
+                i += 1
+                while i < size and self.buffer[i] != quote:
+                    # Skip a backslash together with the quote or backslash it escapes.
+                    escaped = self.buffer[i] == '\\' and i + 1 < size and \
+                        self.buffer[i+1] in ("'", '"', '\\')
+                    i += 2 if escaped else 1
+                i += 1
+                self.kind = 'string'
+                self.value = self.buffer[self.position:i]
+            else:
+                raise NotImplementedError("Unsupported token at " + str(self.position))
+
 def _read(path):
     with open(path, 'r', encoding='utf-8') as file:
         return file.read()
@@ -597,7 +956,7 @@ known_legacy_schema_definitions = [
     'aten::grid_sampler.legacy(Tensor input, Tensor grid, int interpolation_mode, int padding_mode) -> Tensor',
     'neuron::forward_v2_1(Tensor[] _0, __torch__.torch.classes.neuron.Model _1) -> (Tensor _0)',
     'prim::shape(Tensor self) -> int[]',
-    'torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int64)',
+    'torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int)',
     'torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor',
     'torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor',
 ]
@@ -618,15 +977,15 @@ def _parse_schemas():
             definition = entry[2] + value if len(entry) > 2 else value
             if not definition in definitions:
                 definitions.add(definition)
-                schema = pytorch.Schema(definition)
+                schema = Schema(definition)
                 if schema.name in schemas:
                     raise KeyError(schema.name)
                 schemas[schema.name] = schema
     for value in known_legacy_schema_definitions:
-        schema = pytorch.Schema(value)
+        schema = Schema(value)
         schemas[schema.name] = schema
     for value in known_schema_definitions:
-        schema = pytorch.Schema(value)
+        schema = Schema(value)
         schemas[schema.name] = schema
     return schemas
 
@@ -699,7 +1058,7 @@ def _metadata():
     _check_types(types, schemas)
     _check_schemas(schemas)
     filtered_schemas = _filter_schemas(schemas, types)
-    metadata = pytorch.Metadata(types)
+    metadata = Metadata(types)
     for schema in filtered_schemas.values():
         value = metadata.type(schema)
         value['name'] = schema.value