Explorar el Código

Update pytorch.js (#1061)

Lutz Roeder hace 1 año
padre
commit
a152bb027c
Se han modificado 5 ficheros con 347 adiciones y 95 borrados
  1. source/python.js (+195 additions, −57 deletions)
  2. source/pytorch-metadata.json (+69 additions, −11 deletions)
  3. source/pytorch.js (+56 additions, −10 deletions)
  4. test/models.json (+9 additions, −12 deletions)
  5. tools/pytorch_script.py (+18 additions, −5 deletions)

+195 additions, −57 deletions
source/python.js

@@ -5594,6 +5594,12 @@ python.Execution = class {
             }
             throw new python.Error('Unsupported torch.add expression type.');
         });
+        this.registerFunction('torch.all', (input) => {
+            if (Array.isArray(input) && input.length === 0) {
+                return true;
+            }
+            throw new python.Error(`Unsupported 'torch.all' expression type.`);
+        });
         this.registerFunction('torch.append', (list, value) => {
             list.push(value);
             return value;
@@ -5648,9 +5654,6 @@ python.Execution = class {
             return NaN;
         });
         this.registerFunction('torch.eq', (left, right) => {
-            const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
-            left = value(left);
-            right = value(right);
             if (typeof left === 'string' && typeof right === 'string') {
                 return left === right;
             }
@@ -5696,9 +5699,6 @@ python.Execution = class {
             return self.replace(regex, '');
         });
         this.registerFunction('torch.gt', (left, right) => {
-            const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
-            left = value(left);
-            right = value(right);
             if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
                 if (!isNaN(left) && !isNaN(right)) {
                     return left > right;
@@ -6020,9 +6020,6 @@ python.Execution = class {
             throw new python.Error("Unsupported 'torch.remainder' expression type.");
         });
         this.registerFunction('torch.ne', (left, right) => {
-            const value = (x) => x && x.__class__ && x.__class__.__module__ === 'torch' && x.__class__.__name__ === 'Value' ? x.value : x;
-            left = value(left);
-            right = value(right);
             if (typeof left === 'boolean' && typeof right === 'boolean') {
                 return left !== right;
             }
@@ -6210,15 +6207,14 @@ python.Execution = class {
         });
         this.registerType('torch.ClassType', class extends torch.Type {
             constructor(qualified_name, cu, is_module) {
-                super('ClassType');
-                this._qualified_name = qualified_name;
+                super('ClassType', qualified_name);
                 this._is_module = is_module;
                 this._attributes = new Map();
                 this._methods = new Map();
                 this._staticmethods = new Map();
             }
             qualified_name() {
-                return this._qualified_name;
+                return this.annotation_str;
             }
             name() {
                 return this._qualified_name.split('.').pop();
@@ -6349,20 +6345,39 @@ python.Execution = class {
             }
         });
         this.registerType('torch.TupleType', class extends torch.Type {
-            constructor(elements) {
-                super('TupleType');
+            constructor(elements, annotation_str, schema) {
+                super('TupleType', annotation_str);
                 this._elements = elements;
+                this._schema = schema;
             }
             static get(elements) {
                 return new torch.TupleType(elements);
             }
+            static createNamed(qualified_name, field_names, field_types /*, field_defaults */) {
+                const args = [];
+                for (let i = 0; i < field_names.length; i++) {
+                    const arg = new torch.Argument(field_names[i], field_types[i], field_types[i]);
+                    args.push(arg);
+                }
+                const schema = new torch.FunctionSchema(qualified_name, args);
+                return new torch.TupleType(field_types, qualified_name, schema);
+            }
             elements() {
                 return this._elements;
             }
+            schema() {
+                return this._schema;
+            }
             str() {
+                if (this._schema) {
+                    return `NamedTuple(...)`;
+                }
                 return `(${this.elements().map((elem) => elem.str()).join(', ')})`;
             }
             __str__() {
+                if (this.annotation_str) {
+                    return this.annotation_str;
+                }
                 return `Tuple[${this.elements().map((elem) => elem.__str__()).join(', ')}]`;
             }
         });
@@ -7074,11 +7089,11 @@ python.Execution = class {
                 const index = name.indexOf('(');
                 if (index === -1) {
                     this._name = name;
-                    this._overload_name = overload_name;
-                    this._arguments = args;
-                    this._returns = returns;
-                    this._is_vararg = is_vararg;
-                    this._is_varret = is_varret;
+                    this._overload_name = overload_name || '';
+                    this._arguments = args || [];
+                    this._returns = returns || [];
+                    this._is_vararg = is_vararg || false;
+                    this._is_varret = is_varret || false;
                 } else {
                     const value = name.substring(0, index).trim();
                     const dot = value.indexOf('.');
@@ -7689,22 +7704,32 @@ python.Execution = class {
         this.register('torch.jit._script');
         this.register('torch.jit._trace');
         this.registerType('torch.jit.Source', class {
-            constructor(text) {
-                this._text = text;
+            constructor(text_view, filename) {
+                this._text_view = text_view;
+                this._filename = filename;
+            }
+            text_str() {
+                return this._text_view;
+            }
+            filename() {
+                return this._filename;
             }
         });
-        this.registerType('torch.jit.SourceLoader', class {
-            constructor(reader, code_prefix) {
-                this._reader = reader;
-                this._code_prefix = code_prefix;
+        this.registerType('torch.jit.QualifiedName', class {
+            constructor(name) {
+                const index = name.lastIndexOf('.');
+                this._qualifiedName = name;
+                this._prefix = index === -1 ? '' : name.substring(0, index);
+                this._name = index === -1 ? name : name.substring(index + 1);
             }
-            loadSource(qualifier) {
-                const path = `${this._code_prefix}/${qualifier}.py`;
-                if (this._reader.has_record(path)) {
-                    const data = this._reader.get_record(path);
-                    return new torch.jit.Source(data);
-                }
-                return null;
+            qualifiedName() {
+                return this._qualifiedName; // "foo.bar.baz"
+            }
+            prefix() {
+                return this._prefix; // "foo.bar"
+            }
+            name() {
+                return this._name; // "baz"
             }
         });
         this.registerType('torch.jit.SourceImporter', class {
@@ -7713,17 +7738,103 @@ python.Execution = class {
                 this._constant_table = constant_table;
                 this._source_loader = source_loader;
                 this._version = version;
+                this._loaded_sources = new Set();
+                this._to_be_defined = new Map();
             }
             loadType(/* name */) {
                 //
             }
             resolveType(name) {
-                return this.findNamedType(new torch.jit.QualifiedName(name));
+                name = new torch.jit.QualifiedName(name);
+                return this.findNamedType(name);
             }
             findNamedType(name) {
+                // if (auto custom_class = getCustomClass(name.qualifiedName())) {
+                //     return custom_class;
+                // }
                 this.parseSourceIfNeeded(name.prefix());
+                const key = name.qualifiedName();
+                const it = this._to_be_defined.get(key);
+                if (it && it.type === 'class') {
+                    this._to_be_defined.delete(key);
+                    this.importNamedType(name.prefix(), it);
+                }
+                return this._cu.get_type(name);
+            }
+            importNamedType(qualifier, class_def) {
+                const qualified_name = new torch.jit.QualifiedName(`${qualifier}.${class_def.name}`);
+                if (class_def.bases.length === 0) {
+                    this.importClass(qualified_name, class_def, false);
+                    return;
+                }
+                const superclass_name = class_def.bases[0].value;
+                if (superclass_name === 'Module') {
+                    this.importClass(qualified_name, class_def, true);
+                } else if (superclass_name === 'NamedTuple') {
+                    this.importNamedTuple(qualified_name, class_def);
+                } else if (superclass_name === 'Interface') {
+                    // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=false);
+                } else if (superclass_name === 'ModuleInterface') {
+                    // cu_->define_interface(qualified_name, class_def, shared_from_this(), is_module=true);
+                } else if (superclass_name === 'Enum') {
+                    // importEnum(qualified_name, class_def);
+                } else {
+                    throw new python.Error('TorchScript does not support class inheritance.');
+                }
+            }
+            importClass(qualified_name, class_def, is_module) {
+                if (qualified_name.prefix().startsWith('__torch__.torch.classes')) {
+                    return;
+                }
+                const class_type = new torch.ClassType(qualified_name.qualifiedName(), this._cu, is_module);
+                for (const entry of class_def.body.statements) {
+                    if (entry.type === 'var') {
+                        const variableType = this._cu.execution.type(entry.variableType, null);
+                        class_type.addAttribute(entry.name, variableType);
+                    }
+                }
+                // debugger;
+                this._cu.register_type(class_type);
             }
-            parseSourceIfNeeded(/* qualifier */) {
+            importNamedTuple(qualified_name, named_tuple_def) {
+                const field_names = [];
+                const field_types = [];
+                const field_defaults = [];
+                for (const statement of named_tuple_def.body.statements) {
+                    if (statement.type !== 'var') {
+                        throw new python.Error('Unexpected statement in NamedTuple body.');
+                    }
+                    field_names.push(statement.name);
+                    field_types.push(this._cu.execution.type(statement.variableType));
+                }
+                const tt = torch.TupleType.createNamed(qualified_name.qualifiedName(), field_names, field_types, field_defaults);
+                this._cu.register_type(tt);
+            }
+            parseSourceIfNeeded(qualifier) {
+                if (!qualifier || this._loaded_sources.has(qualifier)) {
+                    return;
+                }
+                this._loaded_sources.add(qualifier);
+                const src = this._source_loader(qualifier);
+                if (!src) {
+                    return;
+                }
+                const program = this._cu.execution.parse(src.filename(), src.text_str(), null);
+                for (const statement of program.body) {
+                    switch (statement.type) {
+                        case 'def': {
+                            break;
+                        }
+                        case 'class': {
+                            const name = `${qualifier}.${statement.name}`;
+                            this._to_be_defined.set(name, statement);
+                            break;
+                        }
+                        default: {
+                            break;
+                        }
+                    }
+                }
             }
         });
         this.registerType('torch.jit.ScriptModuleDeserializer', class {
@@ -7734,9 +7845,11 @@ python.Execution = class {
                 this._code_prefix = !pickle_dir_prefix && !tensor_dir_prefix ? 'code/' : '.data/ts_code/code/';
                 this._pickle_dir_prefix = pickle_dir_prefix || '';
                 this._tensor_dir_prefix = tensor_dir_prefix || '';
+                const SourceLoader = (qualifier) => {
+                    return this.findSourceInArchiveFromQualifier(this._reader, this._code_prefix, qualifier);
+                };
                 this._source_importer = new torch.jit.SourceImporter(
-                    this._compilation_unit, this._constants_table,
-                    new torch.jit.SourceLoader(this._reader, this._code_prefix), reader.version());
+                    this._compilation_unit, this._constants_table, SourceLoader, reader.version());
             }
             deserialize() {
                 const execution = this._compilation_unit.execution;
@@ -7755,6 +7868,13 @@ python.Execution = class {
                 execution.builtins.ops = torch.ops;
                 execution.builtins.inf = torch.inf;
                 execution.builtins.CONSTANTS = {};
+                execution._resolver = this._source_importer;
+                const known_types = ['__torch__.torch.classes._nnapi.Compilation'];
+                for (const name of known_types) {
+                    const type = new torch.ClassType(name, this._compilation_unit, false);
+                    type.addMethod(new torch.FunctionSchema('init(Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> ()'));
+                    this._compilation_unit.register_type(type);
+                }
                 if (this._reader.has_record('model.json')) {
                     return this.LEGACY_deserialize();
                 }
@@ -7914,6 +8034,17 @@ python.Execution = class {
                 };
                 return unpickler.load();
             }
+            qualifierToArchivePath(qualifier, export_prefix) {
+                return `${export_prefix}${qualifier.replace(/\./g, '/')}.py`;
+            }
+            findSourceInArchiveFromQualifier(reader, export_prefix, qualifier) {
+                const path = this.qualifierToArchivePath(qualifier, export_prefix);
+                if (!reader.has_record(path)) {
+                    return null;
+                }
+                const data = reader.get_record(path);
+                return new torch.jit.Source(data.peek(), path);
+            }
         });
         this.registerType('torch.package.PackageImporter', class {
             constructor(reader) {
@@ -8215,6 +8346,9 @@ python.Execution = class {
                 this._functions = new Map();
                 this._classes = new Map();
             }
+            register_type(namedType) {
+                this._classes.set(namedType.annotation_str, namedType);
+            }
             register_function(fn) {
                 this._functions.set(fn.name, fn);
             }
@@ -8228,14 +8362,11 @@ python.Execution = class {
                 }
             }
             get_type(name) {
-                return this._classes.get(name);
+                return this._classes.get(name.qualifiedName());
             }
             get_class(name) {
                 return this.get_type(name);
             }
-            register_type(name, cls) {
-                this._classes.set(name, cls);
-            }
         });
         this.registerType('torch.jit._script.ScriptModule', class extends torch.nn.modules.module.Module {});
         this.registerType('torch.jit._trace.TracedModule', class extends torch.jit._script.ScriptModule {});
@@ -8399,7 +8530,7 @@ python.Execution = class {
                 if (!cls) {
                     const name = obj_type.type_name;
                     if (name.startsWith('__torch__') || name.startsWith('torch.jit')) {
-                        cls = this._cu.get_class(name);
+                        cls = this._cu.get_class(new torch.jit.QualifiedName(name));
                         if (!cls) {
                             const torch = this._torch;
                             cls = new torch.ClassType(name, this._cu, true);
@@ -10247,13 +10378,6 @@ python.Execution = class {
         return this._builtins;
     }
 
-    source(file) {
-        return this._sources.has(file) ? this._sources.get(file) : null;
-    }
-
-    debug(/* file */) {
-    }
-
     exec(code , context) {
         const reader = new python.Parser(code, '', null);
         const program = reader.parse();
@@ -10263,21 +10387,35 @@ python.Execution = class {
         this.block(program.body, context);
     }
 
-    parse(file) {
+    debug(/* file */) {
+    }
+
+    source(file) {
+        if (this._sources.has(file)) {
+            return this._sources.get(file);
+        }
+        return null;
+    }
+
+    read(file) {
         const buffer = this.source(file);
         if (buffer) {
             const debug = this.debug(file);
-            const code = this._utf8Decoder.decode(buffer);
-            const parser = new python.Parser(code, file, debug);
-            const program = parser.parse();
-            if (!program) {
-                throw new python.Error(`Module '${file}' parse error.`);
-            }
-            return program;
+            return this.parse(file, buffer, debug);
         }
         return null;
     }
 
+    parse(file, buffer, debug) {
+        const code = this._utf8Decoder.decode(buffer);
+        const parser = new python.Parser(code, file, debug);
+        const program = parser.parse();
+        if (!program) {
+            throw new python.Error(`Module '${file}' parse error.`);
+        }
+        return program;
+    }
+
     import(name, current, level) {
         if (level) {
             let bits = current.split('.');
@@ -10303,7 +10441,7 @@ python.Execution = class {
             const path = name.split('.').join('/');
             module.__path__ = [path];
             const file = `${path}.py`;
-            const program = this.parse(file);
+            const program = this.read(file);
             if (program) {
                 module.__file__ = file;
                 for (const [name, value] of Object.entries(this.builtins)) {

+69 additions, −11 deletions
source/pytorch-metadata.json

@@ -40,6 +40,9 @@
   {
     "name": "_caffe2::BoxWithNMSLimit(Tensor scores, Tensor boxes, Tensor batch_splits, float score_thresh, float nms, int detections_per_im, bool soft_nms_enabled, str soft_nms_method, float soft_nms_sigma, float soft_nms_min_score_thres, bool rotated, bool cls_agnostic_bbox_reg, bool input_boxes_include_bg_cls, bool output_classes_include_bg_cls, bool legacy_plus_one) -> (Tensor scores, Tensor boxes, Tensor classes, Tensor batch_splits, Tensor keeps, Tensor keeps_size)"
   },
+  {
+    "name": "_caffe2::CollectAndDistributeFpnRpnProposals(Tensor[] input_list, int roi_canonical_scale, int roi_canonical_level, int roi_max_level, int roi_min_level, int rpn_max_level, int rpn_min_level, int rpn_post_nms_topN, bool legacy_plus_one) -> (Tensor rois, Tensor rois_fpn2, Tensor rois_fpn3, Tensor rois_fpn4, Tensor rois_fpn5, Tensor rois_idx_restore_int32)"
+  },
   {
     "name": "_caffe2::CollectRpnProposals(Tensor[] input_list, int rpn_max_level, int rpn_min_level, int rpn_post_nms_topN) -> (Tensor rois)"
   },
@@ -1463,17 +1466,6 @@
   {
     "name": "aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)"
   },
-  {
-    "name": "aten::classes._nnapi.Compilation",
-    "inputs": [
-      { "name": "serialized_model", "type": "Tensor" },
-      { "name": "inputs", "type": "Tensor[]" },
-      { "name": "parameter_buffers", "type": "Tensor[]" }
-    ],
-    "outputs": [
-      { "type": "Tensor[]" }
-    ]
-  },
   {
     "name": "aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"
   },
@@ -5637,6 +5629,24 @@
   {
     "name": "aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)"
   },
+  {
+    "name": "aten::update.Tensor(Dict(Tensor, t)(a!) self, Dict(Tensor, t)(a!) to_add) -> ()"
+  },
+  {
+    "name": "aten::update.bool(Dict(bool, t)(a!) self, Dict(bool, t)(a!) to_add) -> ()"
+  },
+  {
+    "name": "aten::update.complex(Dict(complex, t)(a!) self, Dict(complex, t)(a!) to_add) -> ()"
+  },
+  {
+    "name": "aten::update.float(Dict(float, t)(a!) self, Dict(float, t)(a!) to_add) -> ()"
+  },
+  {
+    "name": "aten::update.int(Dict(int, t)(a!) self, Dict(int, t)(a!) to_add) -> ()"
+  },
+  {
+    "name": "aten::update.str(Dict(str, t)(a!) self, Dict(str, t)(a!) to_add) -> ()"
+  },
   {
     "name": "aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor",
     "category": "Layer"
@@ -6890,6 +6900,54 @@
     "name": "torch.nn.modules.upsampling.Upsample",
     "category": "Data"
   },
+  {
+    "name": "torch_scatter::cuda_version() -> int _0"
+  },
+  {
+    "name": "torch_scatter::gather_coo(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::gather_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::scatter_max(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::scatter_mean(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::scatter_min(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::scatter_mul(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::scatter_sum(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::segment_max_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::segment_max_csr(Tensor _0, Tensor _1, Tensor? _2) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::segment_mean_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::segment_mean_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::segment_min_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::segment_min_csr(Tensor _0, Tensor _1, Tensor? _2) -> (Tensor _0, Tensor _1)"
+  },
+  {
+    "name": "torch_scatter::segment_sum_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> Tensor _0"
+  },
+  {
+    "name": "torch_scatter::segment_sum_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0"
+  },
   {
     "name": "torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int)"
   },

+56 additions, −10 deletions
source/pytorch.js

@@ -1564,7 +1564,6 @@ pytorch.Execution = class extends python.Execution {
                 }
             }
         });
-        this.register('__torch__').torch.classes._nnapi.Compilation.__type__ = new torch.ClassType('__torch__.torch.classes._nnapi.Compilation');
         this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class {
             __setstate__(state) {
                 if (state[0] !== '2') {
@@ -1898,7 +1897,7 @@ pytorch.Execution = class extends python.Execution {
                     return value;
                 }
                 if (expression.target.type === 'id' && expression.target.value === 'uninitialized') {
-                    const type = this.type(expression.args[0], context);
+                    const type = this.type(expression.args[0]);
                     const node = this._graph.create('prim::Uninitialized');
                     this.graph.insertNode(node);
                     const value = node.addOutput();
@@ -1907,7 +1906,7 @@ pytorch.Execution = class extends python.Execution {
                 }
                 if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast') {
                     let value = this.expression(expression.args[1], context);
-                    const type = this.type(expression.args[0], context);
+                    const type = this.type(expression.args[0]);
                     const node = this._graph.create('prim::unchecked_cast');
                     this.graph.insertNode(node);
                     node.addInput(this.variable(value));
@@ -2440,6 +2439,11 @@ pytorch.Execution = class extends python.Execution {
                             if (pytorch.Utility.isTensor(value)) {
                                 return torch.TensorType.get();
                             }
+                            if (value && value.__class__ && value instanceof torch.Value === false) {
+                                const identifier = `${value.__class__.__module__}.${value.__class__.__name__}`;
+                                const type = this._resolver.resolveType(identifier);
+                                return type;
+                            }
                             return value.type();
                         };
                         this.variables(statement, statement);
@@ -2480,6 +2484,8 @@ pytorch.Execution = class extends python.Execution {
                                 const t2 = __type(entry.orelse);
                                 if (t1 === null && t2 === null) {
                                     type = null;
+                                } else if (t1 === t2) {
+                                    type = t1;
                                 } else if (t1.equals(t2)) {
                                     type = t2;
                                 } else if (t1 instanceof torch.NoneType && t2 instanceof torch.NoneType === false) {
@@ -2538,7 +2544,7 @@ pytorch.Execution = class extends python.Execution {
         return super.statement(statement, context);
     }
 
-    type(expression, context) {
+    type(expression) {
         const torch = this.torch;
         if (expression.type === '[]' && expression.target.type === 'id') {
             switch (expression.target.value) {
@@ -2581,9 +2587,10 @@ pytorch.Execution = class extends python.Execution {
             }
         }
         if (expression.type === '.') {
-            const target = this.expression(expression, context);
-            if (target && target.__type__ instanceof torch.ClassType) {
-                return target.__type__;
+            const identifier = pytorch.Utility.target(expression);
+            const type = this._resolver.resolveType(identifier);
+            if (type) {
+                return type;
             }
         }
         throw new pytorch.Error(`Unsupported type expression '${expression.type}'.`);
@@ -2597,16 +2604,18 @@ pytorch.Execution = class extends python.Execution {
         if (name === '__new__') {
             const identifier = pytorch.Utility.target(target);
             if (identifier) {
-                const type = this.resolve(identifier);
-                if (type && type.__type__) {
+                const type = this._resolver.resolveType(identifier);
+                if (type) {
                     const node = this.graph.create('prim::CreateObject');
+                    node.setSourceRange(location);
                     this.graph.insertNode(node);
                     const value = node.addOutput();
-                    value.setType(type.__type__);
+                    value.setType(type);
                     return value;
                 }
             }
         }
+        /*
         if (name === '__init__') {
             const obj = this.expression(target, context);
             if (args.length === 0) {
@@ -2625,6 +2634,7 @@ pytorch.Execution = class extends python.Execution {
             value.setType(obj.type());
             return value;
         }
+        */
         const overload = this._overload(target, name, args, context);
         if (!overload) {
             const moduleTarget = this.target(target, context);
@@ -2639,6 +2649,35 @@ pytorch.Execution = class extends python.Execution {
                 }
                 return node.addOutput();
             }
+            const prefix = pytorch.Utility.target(target);
+            if (prefix && prefix !== 'self' && !prefix.startsWith('self.') && prefix.indexOf('.') !== -1) {
+                const identifier = `${prefix}.${name}`;
+                const type = this._resolver.resolveType(identifier);
+                if (type instanceof torch.TupleType) {
+                    const node = this._graph.create('prim::TupleConstruct');
+                    node.setSourceRange(location);
+                    this.graph.insertNode(node);
+                    const evalArgs = args.map((expression) => this.expression(expression, context));
+                    for (const arg of evalArgs) {
+                        const value = this.variable(arg);
+                        node.addInput(value);
+                    }
+                    const output = node.addOutput();
+                    output.setType(type);
+                    return output;
+                }
+                if (type instanceof torch.ClassType) {
+                    const node = this.graph.create('prim::CallMethod');
+                    this.graph.insertNode(node);
+                    node.s_('name', name);
+                    const evalArgs = args.map((expression) => this.expression(expression, context));
+                    for (const arg of evalArgs) {
+                        const value = this.variable(arg);
+                        node.addInput(value);
+                    }
+                    return node.addOutput();
+                }
+            }
             return super.call(target, name, args, context);
         }
         const [schema, evalArgs] = overload;
@@ -2734,6 +2773,9 @@ pytorch.Execution = class extends python.Execution {
                 } else {
                     const value = this.variable(v);
                     value.value = v;
+                    if (!value.type() && v instanceof this.builtins.dict) {
+                        value.setType(type);
+                    }
                     input = value;
                     match = true;
                 }
@@ -3063,6 +3105,9 @@ pytorch.Execution = class extends python.Execution {
                             return true;
                         }
                     }
+                    if (obj instanceof this.builtins.dict) {
+                        return true;
+                    }
                     return false;
                 }
                 // throw new pytorch.Error(`Unknown type '${type}'.`);
@@ -3392,6 +3437,7 @@ pytorch.Utility = class {
             case 'ScalarTypeType': return `ScalarType`;
             case 'MemoryFormat': return `MemoryFormat`;
             case 'Layout': return `Layout`;
+            case 'VarType': return type.annotation_str;
             default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
         }
     }

+ 9 - 12
test/models.json

@@ -166,6 +166,13 @@
     "error":    "Invalid file content. File contains Python source code.",
     "link":     "https://github.com/lutzroeder/netron/issues/458"
   },
+  {
+    "type":     "_",
+    "target":   "pytorch_invalid_file.pth",
+    "source":   "https://github.com/lutzroeder/netron/files/3269093/pytorch_invalid_file.zip[pytorch_invalid_file.pth]",
+    "error":    "Could not find end of line.",
+    "link":     "https://github.com/lutzroeder/netron/issues/720"
+  },
   {
     "type":     "_",
     "target":   "random.onnx",
@@ -5348,7 +5355,7 @@
     "type":     "pytorch",
     "target":   "fasterrcnn_resnet50_fpn.pt",
     "source":   "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
-    "error":    "Unknown function 'aten::__contains__'.",
+    "format":   "TorchScript v1.7",
     "link":     "https://github.com/lutzroeder/netron/issues/689"
   },
   {
@@ -5512,7 +5519,6 @@
     "target":   "mask_model.pt",
     "source":   "https://github.com/lutzroeder/netron/files/10080302/mask_model.pt.zip[mask_model.pt]",
     "format":   "TorchScript v1.7",
-    "error":    "Unsupported type expression '.'.",
     "link":     "https://github.com/lutzroeder/netron/issues/842"
   },
   {
@@ -5866,16 +5872,9 @@
     "type":     "pytorch",
     "target":   "pyg_model.pt",
     "source":   "https://github.com/lutzroeder/netron/files/10369483/pyg_model.zip[pyg_model.pt]",
-    "error":    "Unknown function 'aten::linear'.",
+    "format":   "TorchScript v1.7",
     "link":     "https://github.com/lutzroeder/netron/issues/546"
   },
-  {
-    "type":     "pytorch",
-    "target":   "pytorch_invalid_file.pth",
-    "source":   "https://github.com/lutzroeder/netron/files/3269093/pytorch_invalid_file.zip[pytorch_invalid_file.pth]",
-    "error":    "Could not find end of line.",
-    "link":     "https://github.com/lutzroeder/netron/issues/133"
-  },
   {
     "type":     "pytorch",
     "target":   "quant_3d.pt",
@@ -5908,7 +5907,6 @@
     "type":     "pytorch",
     "target":   "rcnn.pt",
     "source":   "https://github.com/lutzroeder/netron/files/9035740/rcnn.pt.zip[rcnn.pt]",
-    "error":    "value.uses is not a function",
     "link":     "https://github.com/lutzroeder/netron/issues/842"
   },
   {
@@ -6368,7 +6366,6 @@
     "target":   "transformer.pt",
     "source":   "https://github.com/lutzroeder/netron/files/10271969/transformer.pt.zip[transformer.pt]",
     "format":   "TorchScript v1.6",
-    "error":    "value.type is not a function\nUnknown type name 'torch.all'.",
     "link":     "https://github.com/lutzroeder/netron/issues/842"
   },
   {

+ 18 - 5
tools/pytorch_script.py

@@ -84,6 +84,22 @@ known_legacy_schema_definitions = [
     'neuron::forward_v2_1(Tensor[] _0, __torch__.torch.classes.neuron.Model _1) -> (Tensor _0)',
     'prim::shape(Tensor self) -> int[]',
     'torchaudio::sox_effects_apply_effects_tensor(Tensor tensor, int sample_rate, str[][] effects, bool channels_first=True) -> (Tensor, int)',
+    'torch_scatter::gather_coo(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0',
+    'torch_scatter::segment_max_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::segment_min_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::segment_mean_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> Tensor _0',
+    'torch_scatter::segment_sum_coo(Tensor _0, Tensor _1, Tensor? _2, int? _3) -> Tensor _0',
+    'torch_scatter::gather_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0',
+    'torch_scatter::segment_max_csr(Tensor _0, Tensor _1, Tensor? _2) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::segment_min_csr(Tensor _0, Tensor _1, Tensor? _2) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::segment_mean_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0',
+    'torch_scatter::segment_sum_csr(Tensor _0, Tensor _1, Tensor? _2) -> Tensor _0',
+    'torch_scatter::scatter_max(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::scatter_min(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> (Tensor _0, Tensor _1)',
+    'torch_scatter::scatter_mean(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0',
+    'torch_scatter::scatter_mul(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0',
+    'torch_scatter::scatter_sum(Tensor _0, Tensor _1, int _2, Tensor? _3, int? _4) -> Tensor _0',
+    'torch_scatter::cuda_version() -> int _0',
     'torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor',
     'torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor',
 ]
@@ -121,6 +137,8 @@ def _parse_schemas():
 
 def _filter_schemas(schemas, types):
     names = set(map(lambda _: _.split('.')[0], types.keys()))
+    for key in known_legacy_schema_definitions:
+        names.add(re.sub(r'[\.(].*$', '', key))
     filtered_schemas = set()
     for schema in schemas.values():
         for name in names:
@@ -144,11 +162,6 @@ def _check_types(types, schemas):
             types.pop(key)
         if key.startswith('_caffe2::'):
             types.pop(key)
-    known_keys = [
-        'aten::classes._nnapi.Compilation'
-    ]
-    for key in known_keys:
-        types.pop(key)
     if len(types) > 0:
         raise Exception('\n'.join(list(types.keys()))) # pylint: disable=broad-exception-raised