Lutz Roeder 1 anno fa
parent
commit
c23cc3251c
5 ha cambiato i file con 391 aggiunte e 175 eliminazioni
  1. 210 66
      source/python.js
  2. 33 0
      source/pytorch-metadata.json
  3. 137 104
      source/pytorch.js
  4. 4 3
      source/view.js
  5. 7 2
      test/models.json

+ 210 - 66
source/python.js

@@ -4967,6 +4967,35 @@ python.Execution = class {
                 }
                 return this[name];
             }
+            __delattr__(name) {
+                if (this._modules.has(name)) {
+                    this._modules.delete(name);
+                }
+            }
+            children() {
+                return this._modules.values();
+            }
+            named_children() {
+                return this._modules;
+            }
+            parameters() {
+                return this._parameters.values();
+            }
+            named_parameters(recurse) {
+                if (recurse) {
+                    throw new python.Error('Named parameters with recurse not implemented.');
+                }
+                return this._parameters;
+            }
+            buffers() {
+                return this._buffers.values();
+            }
+            named_buffers(recurse) {
+                if (recurse) {
+                    throw new python.Error('Named parameters with recurse not implemented.');
+                }
+                return this._buffers;
+            }
         });
         torch.nn.Module = torch.nn.modules.module.Module;
         torch.nn.modules.Module = torch.nn.modules.module.Module;
@@ -6210,14 +6239,14 @@ python.Execution = class {
         });
         this.registerFunction('torch._utils._rebuild_tensor_v3');
         this.registerFunction('torch._utils._rebuild_parameter', (data, requires_grad, backward_hooks) => {
-            const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
+            const param = new torch.nn.parameter.Parameter(data, requires_grad);
             param.backward_hooks = backward_hooks;
             return param;
         });
         this.registerFunction('torch._utils._rebuild_parameter_v2', (data, requires_grad, backward_hooks, state) => {
-            const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
+            const param = new torch.nn.parameter.Parameter(data, requires_grad);
             param.backward_hooks = backward_hooks;
-            execution.invoke('torch._utils._set_obj_state', [param, state]);
+            torch._utils._set_obj_state(param, state);
             return param;
         });
         this.registerFunction('torch._utils._rebuild_parameter_with_state', (data, requires_grad, backward_hooks, state) => {
@@ -6225,16 +6254,16 @@ python.Execution = class {
                 const [dict_state, slots_state] = Array.isArray(state) ? state : [state, null];
                 if (dict_state) {
                     for (const [k, v] of Object.entries(dict_state)) {
-                        self.invoke('builtins.setattr', [obj, k, v]);
+                        builtins.setattr(obj, k, v);
                     }
                 }
                 if (slots_state) {
                     for (const [k, v] of Object.entries(slots_state)) {
-                        self.invoke('builtins.setattr', [obj, k, v]);
+                        builtins.setattr(obj, k, v);
                     }
                 }
             };
-            const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
+            const param = new torch.nn.parameter.Parameter(data, requires_grad);
             param._backward_hooks = backward_hooks;
             _set_obj_state(param, state);
             return param;
@@ -6255,12 +6284,12 @@ python.Execution = class {
             }
             if (dict_state) {
                 for (const [name, value] of Object.entries(dict_state)) {
-                    execution.invoke('builtins.setattr', [obj, name, value]);
+                    builtins.setattr(obj, name, value);
                 }
             }
             if (slots_state) {
                 for (const [name, value] of Object.entries(slots_state)) {
-                    execution.invoke('builtins.setattr', [obj, name, value]);
+                    builtins.setattr(obj, name, value);
                 }
             }
             return obj;
@@ -6374,6 +6403,9 @@ python.Execution = class {
             return value;
         });
         this.registerFunction('torch.clear', (value) => {
+            if (value instanceof torch.Value) {
+                throw new python.Error('Invalid value.');
+            }
             if (Object(value) === value) {
                 for (const key of Object.keys(value)) {
                     delete value[key];
@@ -6958,6 +6990,9 @@ python.Execution = class {
                 }
                 return this.equals(rhs);
             }
+            is_module() {
+                return false;
+            }
             expect(type) {
                 if (this instanceof type === false) {
                     throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`);
@@ -7004,6 +7039,12 @@ python.Execution = class {
             is_module() {
                 return this._is_module;
             }
+            is_parameter(slot) {
+                return this._attributes[slot].is_parameter === true;
+            }
+            is_buffer(slot) {
+                return this._attributes[slot].is_buffer === true;
+            }
             addMethod(func) {
                 this._methods.set(func.name(), func);
             }
@@ -7023,6 +7064,9 @@ python.Execution = class {
             findStaticMethod(name) {
                 return this._staticmethods.get(name);
             }
+            numAttributes() {
+                return this._attributes.length;
+            }
             addAttribute(name, type, is_parameter, is_buffer) {
                 is_parameter = is_parameter || false;
                 is_buffer = is_buffer || false;
@@ -7057,10 +7101,16 @@ python.Execution = class {
                 }
                 return null;
             }
-            getAttribute(name) {
-                const slot = this.findAttributeSlot(name);
+            hasAttribute(name) {
+                return this._attributes.find((attr) => attr.name === name);
+            }
+            getAttribute(arg) {
+                const slot = Number.isInteger(arg) ? arg : this.findAttributeSlot(arg);
                 return this._attributes[slot].type;
             }
+            getAttributeName(slot) {
+                return this._attributes[slot].name;
+            }
             hasConstant(/* name */) {
             }
             methods() {
@@ -8131,7 +8181,10 @@ python.Execution = class {
             parseBroadcastList(/* expr */) {
                 return null;
             }
-
+            parseType(str) {
+                const expr = ast.parse(str);
+                return this.parseTypeFromExpr(expr.body[0]);
+            }
         });
         this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase {
             constructor(overloadpacket, op, op_dk, schema, tags) {
@@ -8874,8 +8927,9 @@ python.Execution = class {
                 this._loaded_sources = new Set();
                 this._to_be_defined = new Map();
             }
-            loadType(/* name */) {
-                //
+            loadType(name) {
+                const type_parser = new torch.jit.ScriptTypeParser(this);
+                return type_parser.parseType(name.qualifiedName());
             }
             resolveType(name) {
                 name = new torch.jit.QualifiedName(name);
@@ -9120,12 +9174,32 @@ python.Execution = class {
                 for (let i = 0; i < constants.length; i++) {
                     execution.builtins.CONSTANTS[`c${i}`] = constants[i];
                 }
-                const module = this.readArchive('data');
-                const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
-                const type = torch.ClassType.create(name, null, true);
-                const result = new torch.ScriptModule(type);
-                result.data = module;
-                return result;
+                const obj = this.readArchive('data');
+                const convertModule = (obj) => {
+                    if (obj.__class__) {
+                        const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
+                        const type = this._source_importer.loadType(new torch.jit.QualifiedName(name));
+                        const module = new torch.ScriptModule(type, this._compilation_unit);
+                        for (let i = 0; i < type.numAttributes(); i++) {
+                            const k = type.getAttributeName(i);
+                            const t = type.getAttribute(i);
+                            const v = obj[k];
+                            if (t.is_module()) {
+                                module.__setattr__(k, convertModule(v));
+                            } else {
+                                module.__setattr__(k, obj[k]);
+                            }
+                        }
+                        for (const [key, value] of Object.entries(Object.getPrototypeOf(obj))) {
+                            if (value && value.__class__ === builtins.method) {
+                                module[key] = value;
+                            }
+                        }
+                        return module;
+                    }
+                    throw new python.Error('Module class not found.');
+                };
+                return convertModule(obj);
             }
             LEGACY_deserialize() {
                 const execution = this._compilation_unit.execution;
@@ -9186,8 +9260,8 @@ python.Execution = class {
                 for (const tensor of tensor_table) {
                     this._constant_table.push(tensor);
                 }
-                const temp = this.LEGACY_convertModule(module_def);
-                const data = obj.mainModule || {};
+                return this.LEGACY_convertModule(module_def);
+                /* const data = obj.mainModule || {};
                 const queue = [data];
                 while (queue.length > 0) {
                     const module = queue.shift();
@@ -9237,6 +9311,8 @@ python.Execution = class {
                 const result = new torch.ScriptModule(temp.type());
                 result.data = data;
                 return result;
+                return module;
+                */
             }
             LEGACY_convertModule(module_def) {
                 const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
@@ -9245,13 +9321,14 @@ python.Execution = class {
                     const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
                     this._LEGACY_moduleStack.push(sanitized);
                 }
-                const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
+                const qn = new torch.jit.QualifiedName(this._LEGACY_moduleStack);
+                const module = new torch.ScriptModule(qn, this._compilation_unit);
                 for (const sub_def of module_def.submodules || []) {
                     const submodule = this.LEGACY_convertModule(sub_def);
                     module.register_module(sub_def.name, submodule);
                 }
                 for (const param_def of module_def.parameters || []) {
-                    const tensor = this._constant_table[Number(param_def.tensorId)];
+                    const tensor = this._constant_table[Number(param_def.tensor_id)];
                     if (param_def.isBuffer) {
                         module.register_buffer(param_def.name, tensor);
                     } else {
@@ -9263,11 +9340,21 @@ python.Execution = class {
                     if (module.hasattr(attr_def.name)) {
                         continue;
                     }
+                    throw new python.Error('Not implemented.');
                     // IValue ivalue;
                     // if (attr_def.id() >= 0) {
                     //    ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
                     // }
-                    // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
+                    // module.register_attribute(attr_def.name, typeParser.parseType(attr_def.type), ivalue);
+                }
+                if (module_def.torchscript_arena) {
+                    const key = module_def.torchscript_arena.key;
+                    const file = key.substring('code/'.length);
+                    const name = file.replace(/\.py$/, '').split('/').join('.');
+                    const code = execution.import(name);
+                    if (code.forward.__class__ === execution.builtins.function) {
+                        module.forward = code.forward;
+                    }
                 }
                 /*
                 std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
@@ -9299,9 +9386,13 @@ python.Execution = class {
                 return module;
             }
             readArchive(archive_name) {
-                const type_resolver = null;
-                const obj_loader = null;
-                return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, obj_loader, this._device, this._reader, null, this._storage_context);
+                const type_resolver = (qn) => {
+                    const cls = this._source_importer.loadType(qn);
+                    return cls;
+                };
+                const ObjLoaderFunc = (/* type, ivalue */) => {
+                };
+                return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, ObjLoaderFunc, this._device, this._reader, null, this._storage_context);
             }
             readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) {
                 const picklename = `${pickle_prefix + archive_name}.pkl`;
@@ -9405,7 +9496,9 @@ python.Execution = class {
             const cu = new torch.jit.CompilationUnit();
             cu.execution = execution;
             const cpp_module = torch._C.import_ir_module(cu, file, map_location, extra_files);
-            return new torch.jit._script.RecursiveScriptModule(cpp_module);
+            const module = torch.jit._script.wrap_cpp_module(cpp_module);
+            module.forward = cpp_module.forward; // remove
+            return module;
         });
         this.registerFunction('torch._C.import_ir_module', function(cu, reader, ...args) {
             switch (arguments.length) {
@@ -9495,7 +9588,7 @@ python.Execution = class {
             const cu = new torch.jit.CompilationUnit();
             cu.execution = execution;
             const cpp_module = torch._C._import_ir_module_from_package(cu, importer.zip_reader, importer.storage_context, importer.last_map_location, script_module_id);
-            return new torch.jit._script.RecursiveScriptModule(cpp_module);
+            return torch.jit._script.wrap_cpp_module(cpp_module);
         });
         this.registerFunction('torch.jit._script.wrap_cpp_module', (cpp_module) => {
             const init_fn = (script_module) => {
@@ -9552,7 +9645,8 @@ python.Execution = class {
         });
         this.registerType('torch.ScriptObject', class {
             constructor(type) {
-                this._type = type;
+                this._typ = type;
+                this._ivalue = {};
             }
             static create(type) {
                 if (type.is_module()) {
@@ -9561,10 +9655,10 @@ python.Execution = class {
                 return new torch.ScriptObject(type);
             }
             type() {
-                return this._type;
+                return this._typ;
             }
             _type() {
-                return this._type; // torch.ClassType
+                return this._typ; // torch.ClassType
             }
             _get_method(name) {
                 for (const fn of this._type.methods()) {
@@ -9579,13 +9673,16 @@ python.Execution = class {
             }
             __setattr__(name, value) {
                 // if (this._type.hasContant(name))
-                this[name] = value;
+                this._ivalue[name] = value;
             }
             __getattr__(name) {
-                return this[name];
+                return this._ivalue[name];
             }
             hasattr(name) {
-                return this._type.hasAttribute(name) || this._type.hasConstant(name);
+                return this._typ.hasAttribute(name) || this._typ.hasConstant(name);
+            }
+            getattr(name) {
+                return this.__getattr__(name);
             }
             _properties() {
                 throw new python.Error();
@@ -9601,7 +9698,7 @@ python.Execution = class {
                 }
             }
             get qualified_name() {
-                return this._type.qualified_name();
+                return this.type().qualified_name();
             }
             get code_with_constants() {
                 const const_map = {};
@@ -9622,25 +9719,22 @@ python.Execution = class {
                                 return false;
                         }
                     };
-                    if (!this.data) {
+                    if (!this.forward) {
                         return null;
                     }
-                    if (!this.data.forward) {
-                        throw new python.Error("Module 'forward' not implemented.");
-                    }
                     execution.traceAttr = false;
                     const args = [];
                     if (!execution.traceAttr) {
-                        args.push(this.data); // self
+                        args.push(this); // self
                     }
-                    if (this.data.forward.__code__ && this.data.forward.__code__.args) {
-                        const params = this.data.forward.__code__.args.args;
+                    if (this.forward.__code__ && this.forward.__code__.args) {
+                        const params = this.forward.__code__.args.args;
                         for (let i = 0; i < params.length; i++) {
                             const arg = params[i];
                             if (execution.traceAttr || arg.arg !== 'self') {
                                 const value = execution.graph.addInput(arg.arg);
                                 if (i === 0 && arg.arg === 'self' && !arg.annotation) {
-                                    value.setType(this._type);
+                                    value.setType(this.type());
                                 } else {
                                     value.setType(execution.type(arg.annotation));
                                 }
@@ -9653,7 +9747,7 @@ python.Execution = class {
                         }
                     }
                     execution.purge = new Set();
-                    const result = this.data.forward.__call__(args);
+                    const result = this.forward.__call__(args);
                     const queue = Array.from(execution.purge);
                     const visited = new Set();
                     while (queue.length > 0) {
@@ -9715,15 +9809,15 @@ python.Execution = class {
             }
             register_module(name, module) {
                 this.type().addOrCheckAttribute(name, module.type());
-                // _ivalue()->setAttr(name, module._ivalue());
+                this.__setattr__(name, module); // _ivalue()->setAttr(name, module._ivalue());
             }
-            register_buffer(name /* , v */) {
+            register_buffer(name, v) {
                 this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true);
-                // _ivalue()->setAttr(name, std::move(v));
+                this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v));
             }
             register_parameter(name, v, is_buffer) {
                 this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer);
-                // _ivalue()->setAttr(name, std::move(v));
+                this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v));
             }
             register_attribute(name, t, v, is_param, is_buffer) {
                 this.type().addOrCheckAttribute(name, t, is_param, is_buffer);
@@ -9731,11 +9825,59 @@ python.Execution = class {
             }
         });
         this.registerType('torch.ModuleDict', class {
-            constructor(module) {
-                this._items = Object.entries(module).filter(([, value]) => value instanceof torch.ScriptModule);
+            constructor(mod) {
+                this._module = mod;
+            }
+            items() {
+                const result = new Map();
+                const type = this._module.type();
+                for (let i = 0; i < type.numAttributes(); i++) {
+                    const k = type.getAttributeName(i);
+                    const t = type.getAttribute(i);
+                    if (t && t.is_module()) {
+                        result.set(k, this._module.__getattr__(k));
+                    }
+                }
+                return result;
+            }
+        });
+        this.registerType('torch.ParameterDict', class {
+            constructor(mod) {
+                this._module = mod;
+            }
+            items() {
+                const result = new Map();
+                const type = this._module.type();
+                for (let i = 0; i < type.numAttributes(); i++) {
+                    if (type.is_parameter(i)) {
+                        const k = type.getAttributeName(i);
+                        const v = this._module.__getattr__(k);
+                        if (v instanceof torch.Tensor) {
+                            result.set(k, v);
+                        }
+                    }
+                }
+                return result;
+            }
+        });
+        this.registerType('torch.BufferDict', class {
+            constructor(mod) {
+                this._module = mod;
             }
             items() {
-                return this._items;
+                const result = new Map();
+                const type = this._module.type();
+                for (let i = 0; i < type.numAttributes(); i++) {
+                    if (type.is_buffer(i)) {
+                        const t = type.getAttribute(i);
+                        if (t.isSubtypeOf(torch.TensorType.get())) {
+                            const k = type.getAttributeName(i);
+                            const v = this._module.__getattr__(k);
+                            result.set(k, v);
+                        }
+                    }
+                }
+                return result;
             }
         });
         this.registerType('torch.jit.to_ir', class {
@@ -9846,11 +9988,11 @@ python.Execution = class {
                 torch.jit._script.RecursiveScriptModule._finalize_scriptmodule(script_module);
                 return script_module;
             }
-            static _finalize_scriptmodule() {
-                this._initializing = false;
-            }
-            get data() {
-                return this._c.data;
+            static _finalize_scriptmodule(script_module) {
+                script_module._parameters = new torch.ParameterDict(script_module._c).items();
+                script_module._buffers = new torch.BufferDict(script_module._c).items();
+                // script_module._modules = OrderedModuleDict(script_module._c, script_module._modules)
+                script_module._initializing = false;
             }
             get graph() {
                 // return this._c._get_method("forward").graph;
@@ -9863,8 +10005,8 @@ python.Execution = class {
             __setattr__(name, value) {
                 if (this._initializing) {
                     super.__setattr__(name, value);
-                } else if (this.modules.has(name)) {
-                    this.modules.set(name, value);
+                } else if (this._modules.has(name)) {
+                    this._modules.set(name, value);
                 } else if (this._c.hasattr(name)) {
                     this._c.setattr(name, value);
                 } else {
@@ -9875,8 +10017,8 @@ python.Execution = class {
                 if (this._initializing) {
                     return super.__getattr__(name);
                 }
-                if (this.modules.has(name)) {
-                    return this.modules.get(name);
+                if (this._modules.has(name)) {
+                    return this._modules.get(name);
                 }
                 if (this._c.hasattr(name)) {
                     return this._c.getattr(name);
@@ -11646,10 +11788,7 @@ python.Execution = class {
         this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor {
             constructor(data, requires_grad) {
                 super();
-                if (!data) {
-                    data = self.invoke('torch.Tensor', [[]]);
-                }
-                this.data = data;
+                this.data = data || new torch.Tensor([]);
                 this.requires_grad = requires_grad === undefined ? true : requires_grad;
             }
         });
@@ -12544,7 +12683,12 @@ python.Execution = class {
         if (path) {
             let target = null;
             for (let i = path.length - 1; i >= 0; i--) {
-                target = target ? target[path[i]] : context.get(path[i]);
+                const name = path[i];
+                if (target) {
+                    target = target.__getattr__ ? target.__getattr__(name) : target[name];
+                } else {
+                    target = context.get(name);
+                }
                 if (!target) {
                     break;
                 }

+ 33 - 0
source/pytorch-metadata.json

@@ -1514,6 +1514,27 @@
   {
     "name": "aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)"
   },
+  {
+    "name": "aten::clear.Tensor(Dict(Tensor, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.bool(Dict(bool, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.complex(Dict(complex, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.float(Dict(float, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.int(Dict(int, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.str(Dict(str, t)(a!) self) -> ()"
+  },
+  {
+    "name": "aten::clear.t(t[](a!) self) -> ()"
+  },
   {
     "name": "aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"
   },
@@ -4091,6 +4112,18 @@
   {
     "name": "aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
   },
+  {
+    "name": "aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=1, SymInt ignore_index=-100) -> Tensor"
+  },
+  {
+    "name": "aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=1, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)"
+  },
+  {
+    "name": "aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=1, SymInt ignore_index=-100) -> Tensor"
+  },
+  {
+    "name": "aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=1, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)"
+  },
   {
     "name": "aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=1, SymInt ignore_index=-100) -> Tensor"
   },

+ 137 - 104
source/pytorch.js

@@ -89,8 +89,7 @@ pytorch.Graph = class {
             return values.get(name);
         };
         const torch = execution ? execution.torch : null;
-        // type = module && module.__class__ && module.__class__.__module__ && module.__class__.__name__ ? `${module.__class__.__module__}.${module.__class__.__name__}` : null;
-        if (torch && (module instanceof torch.ScriptModule || module instanceof torch.jit._script.ScriptModule || module instanceof torch.jit._script.RecursiveScriptModule) && module.graph) {
+        if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module.graph) {
             const initializers = new Map();
             const graph = module.graph;
             const constants = module.code_with_constants[1].const_mapping;
@@ -106,10 +105,36 @@ pytorch.Graph = class {
                     }
                 }
             }
-            const queue = [module.data];
+            const queue = [module];
             while (queue.length > 0) {
                 const module = queue.shift();
-                for (const [key, obj] of Object.entries(module)) {
+                const children = module.named_children();
+                for (const [key, obj] of children) {
+                    obj.__parent__ = module;
+                    obj.__name__ = obj.__name__ || key;
+                    queue.push(obj);
+                    const type = obj._c._type();
+                    for (let i = 0; i < type.numAttributes(); i++) {
+                        const k = type.getAttributeName(i);
+                        const v = obj.__getattr__(k);
+                        if (pytorch.Utility.isObject(v)) {
+                            initializers.set(v, v);
+                        }
+                    }
+                }
+                for (const buffer of module.buffers()) {
+                    buffer.__parent__ = module;
+                    if (buffer.storage() && !buffer.__origin__ && (buffer.__count__ === undefined || buffer.__count__ === 1)) {
+                        initializers.set(buffer, new pytorch.Tensor(buffer.name, buffer));
+                    }
+                }
+                for (const parameter of module.parameters()) {
+                    parameter.__parent__ = module;
+                    if (parameter.storage() && !parameter.__origin__ && (parameter.__count__ === undefined || parameter.__count__ === 1)) {
+                        initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter));
+                    }
+                }
+                for (const [key, obj] of children) {
                     if (key !== '__module__' && key !== '__name__' && key !== '__class__' && key !== '__parent__') {
                         if (!Array.isArray(obj) && obj === Object(obj)) {
                             if (pytorch.Utility.isTensor(obj)) {
@@ -185,31 +210,19 @@ pytorch.Graph = class {
                 this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values));
             }
             if (module) {
-                const queue = [module.data];
+                const queue = [module];
                 while (queue.length > 0) {
                     const module = queue.pop();
-                    if (module && !pytorch.Utility.isObject(module)) {
-                        if (!module.__hide__ && pytorch.Graph._getParameters(module).size > 0) {
-                            for (const [name, obj] of Object.entries(module)) {
-                                if ((obj && obj.__hide__) || (obj !== null && !pytorch.Utility.isTensor(obj)) && typeof obj !== 'boolean' && typeof obj !== 'number' && typeof obj !== 'string') {
-                                    delete module[name];
-                                }
+                    if (module) {
+                        const modules = Array.from(module.children());
+                        queue.push(...modules.reverse());
+                        if (!module.__hide__ && module.named_parameters().size > 0) {
+                            for (const [name] of module.named_children()) {
+                                module.__delattr__(name);
                             }
                             const node = new pytorch.Node(execution, metadata, null, null, module, initializers, values);
                             this.nodes.push(node);
                         }
-                        const modules = [];
-                        if (module.__class__ && module.__class__.__module__ && module.__class__.__name__) {
-                            for (const [key, value] of Object.entries(module)) {
-                                if (!key.startsWith('__') && value && value.__class__ && value.__class__.__module__ && value.__class__.__name__ && !pytorch.Utility.isTensor(value)) {
-                                    if (value instanceof torch.Value) {
-                                        continue;
-                                    }
-                                    modules.push(value);
-                                }
-                            }
-                        }
-                        queue.push(...modules.reverse());
                     }
                 }
             }
@@ -331,18 +344,6 @@ pytorch.Graph = class {
             }
         }
     }
-
-    static _getParameters(module) {
-        const parameters = new Map();
-        if (module && module.__class__.__module__ && module.__class__.__name__) {
-            for (const [key, value] of Object.entries(module)) {
-                if (pytorch.Utility.isTensor(value)) {
-                    parameters.set(key, value);
-                }
-            }
-        }
-        return parameters;
-    }
 };
 
 pytorch.Argument = class {
@@ -425,6 +426,9 @@ pytorch.Node = class {
                     values = Object.values(value);
                 } else if (pytorch.Utility.isTensor(value)) {
                     values = [value];
+                } else if (Array.isArray(value) && value.every((value) => pytorch.Utility.isTensor(value))) {
+                    values = value;
+                } else if (input instanceof torch.Value && input.type() instanceof torch.ListType && input.type().getElementType() instanceof torch.TensorType) {
                     if (input.node() &&
                         input.node().kind() === 'prim::ListConstruct' &&
                         input.uses().length === 1 &&
@@ -451,9 +455,15 @@ pytorch.Node = class {
                 }
             }
             if (module) {
-                const parameters = pytorch.Graph._getParameters(module);
-                parameters.delete('num_batches_tracked');
-                if (parameters.size === count && match) {
+                const tensors = new Map();
+                for (const [name, value] of module.named_parameters()) {
+                    tensors.set(name, value);
+                }
+                for (const [name, value] of module.named_buffers()) {
+                    tensors.set(name, value);
+                }
+                tensors.delete('num_batches_tracked');
+                if (tensors.size === count && match) {
                     module.__hide__ = true;
                 } else {
                     module = null;
@@ -703,8 +713,24 @@ pytorch.Node = class {
                 throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
             }
         } else {
+            if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
+                type = obj._c._type();
+                const target = {
+                    _modules: obj._modules,
+                    _parameters: obj._parameters,
+                    _buffers: obj._buffers,
+                };
+                for (let i = 0; i < type.numAttributes(); i++) {
+                    if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) {
+                        const k = type.getAttributeName(i);
+                        target[k] = obj.__getattr__(k);
+                    }
+                }
+                type = obj._c.qualified_name;
+                obj = target;
+            }
             if (!type) {
-                if (torch && pytorch.Utility.isInstance(obj, 'torch.jit._script.RecursiveScriptModule') && obj._c && obj._c.qualified_name) {
+                if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
                     type = obj._c.qualified_name;
                 } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
                     type = `${obj.__module__}.${obj.__name__}`;
@@ -751,17 +777,13 @@ pytorch.Node = class {
                     } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) {
                         continue;
                     }
-                    const parameters = new Map();
-                    if ((name === '_parameters' || name === '_buffers') && value instanceof Map && value.size > 0) {
-                        for (const [name, obj] of Array.from(value)) {
-                            parameters.set(name, obj);
-                        }
-                    } else if (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor))) {
-                        parameters.set(name, value);
-                    } else if (pytorch.Utility.isTensor(value)) {
-                        parameters.set(name, value);
+                    let parameters = null;
+                    if ((name === '_parameters' || name === '_buffers') && value instanceof Map) {
+                        parameters = value;
+                    } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) {
+                        parameters = new Map([[name, value]]);
                     }
-                    if (parameters.size > 0) {
+                    if (parameters) {
                         for (const [name, value] of parameters) {
                             const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
                             const visible = inputs.has(name) ? inputs.get(name).visible || true : true;
@@ -786,7 +808,6 @@ pytorch.Node = class {
                         }
                         continue;
                     }
-                    const type = this.type.identifier;
                     if (pytorch.Utility.isTensor(value)) {
                         const tensor = new pytorch.Tensor('', value);
                         const argument = new pytorch.Argument(name, tensor, 'tensor');
@@ -811,14 +832,14 @@ pytorch.Node = class {
                         this.inputs.push(argument);
                     } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') &&
                         value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
-                        const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
+                        const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
                             stack.add(value);
                             const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
-                            const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj);
+                            const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, values, stack);
                             stack.delete(value);
                             return node;
                         });
-                        const argument = new pytorch.Argument(name, values, 'object[]');
+                        const argument = new pytorch.Argument(name, list, 'object[]');
                         this.inputs.push(argument);
                     } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
                         const argument = new pytorch.Argument(name, value, 'attribute');
@@ -841,34 +862,30 @@ pytorch.Node = class {
                         const argument = new pytorch.Argument(name, node, 'object', visible);
                         this.inputs.push(argument);
                     } else {
-                        const createAttribute = (metadata, name, value) => {
-                            let visible = true;
-                            let type = 'attribute';
-                            metadata = name === 'training' ? { type: 'boolean', visible: false } : metadata;
-                            if (metadata) {
-                                if (metadata.type) {
-                                    type = metadata.type;
-                                }
-                                if (metadata.visible === false) {
-                                    visible = false;
-                                } else if (metadata.default !== undefined) {
-                                    if (Array.isArray(value)) {
-                                        if (Array.isArray(metadata.default)) {
-                                            visible = value.length !== metadata.default || !value.every((item, index) => item === metadata.default[index]);
-                                        } else {
-                                            visible = !value.every((item) => item === metadata.default);
-                                        }
+                        let schema = metadata.attribute(this.type.identifier, name);
+                        schema = name === 'training' ? { type: 'boolean', visible: false } : schema;
+                        let visible = true;
+                        let obj = value;
+                        const type = schema && schema.type ? schema.type : 'attribute';
+                        if (schema) {
+                            if (schema.visible === false) {
+                                visible = false;
+                            } else if (schema.default !== undefined) {
+                                if (Array.isArray(obj)) {
+                                    if (Array.isArray(schema.default)) {
+                                        visible = obj.length !== schema.default || !obj.every((item, index) => item === schema.default[index]);
                                     } else {
-                                        visible = value !== metadata.default;
+                                        visible = !obj.every((item) => item === schema.default);
                                     }
+                                } else {
+                                    visible = obj !== schema.default;
                                 }
                             }
-                            if (Array.isArray(value) && value.length > 0 && value.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
-                                value = '?';
-                            }
-                            return new pytorch.Argument(name, value, type, visible);
-                        };
-                        const argument = createAttribute(metadata.attribute(type, name), name, value);
+                        }
+                        if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
+                            obj = '?';
+                        }
+                        const argument = new pytorch.Argument(name, obj, type, visible);
                         this.inputs.push(argument);
                     }
                 }
@@ -890,6 +907,7 @@ pytorch.Tensor = class {
 
     constructor(name, tensor) {
         this.name = name || '';
+        tensor = tensor.data ? tensor.data : tensor;
         const layout = tensor.layout ? tensor.layout.__str__() : null;
         const storage = tensor.storage();
         const size = tensor.size() || [];
@@ -1300,15 +1318,10 @@ pytorch.Container.Zip = class extends pytorch.Container {
         const version = reader.version();
         if (torchscript) {
             this.execution.trace = false;
-            const module = torch.jit.load(reader);
+            this.module = torch.jit.load(reader);
             this.execution.trace = true;
             metadata.register(this.execution);
-            if (module.data && module.data.forward) {
-                this.module = module;
-            } else {
-                torchscript = false;
-                this.module = module.data;
-            }
+            torchscript = this.module.forward;
         } else {
             const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]);
             const entries = new Map(records);
@@ -1353,9 +1366,15 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
             'version',
             ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key)
         ];
-        if (this._model.mainModule.torchscriptArena && this._model.mainModule.torchscriptArena.key) {
-            keys.push(this._model.mainModule.torchscriptArena.key);
-        }
+        const walk = (module) => {
+            if (module.torchscriptArena && module.torchscriptArena.key) {
+                keys.push(module.torchscriptArena.key);
+            }
+            for (const submodule of module.submodules || []) {
+                walk(submodule);
+            }
+        };
+        walk(this._model.mainModule);
         const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null)));
         for (let i = 0; i < keys.length; i++) {
             if (values[i]) {
@@ -1374,14 +1393,9 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
         }
         this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
         this.execution.trace = false;
-        const module = torch.jit.load(reader);
+        this.module = torch.jit.load(reader);
         this.execution.trace = true;
         metadata.register(this.execution);
-        if (module.data && module.data.forward) {
-            this.module = module;
-        } else {
-            this.module = module.data;
-        }
         delete this._context;
         delete this._model;
         delete this._entries;
@@ -1767,7 +1781,12 @@ pytorch.Execution = class extends python.Execution {
         if (path) {
             let target = null;
             for (let i = path.length - 1; i >= 0; i--) {
-                target = target ? target[path[i]] : context.get(path[i]);
+                const name = path[i];
+                if (target) {
+                    target = target.__getattr__ ? target.__getattr__(name) : target[name];
+                } else {
+                    target = context.get(name);
+                }
                 if (!target) {
                     break;
                 }
@@ -2009,7 +2028,7 @@ pytorch.Execution = class extends python.Execution {
                     this._graph.insertNode(node);
                     return node.output();
                 }
-                return target[attr];
+                return target.__getattr__ ? target.__getattr__(attr) : target[attr];
             }
             case 'List': {
                 const list = expr.elts.map((item) => this.expression(item, context));
@@ -2047,6 +2066,8 @@ pytorch.Execution = class extends python.Execution {
                     let value = null;
                     if (elt instanceof torch.Value) {
                         value = elt;
+                    } else if (pytorch.Utility.isTensor(elt)) {
+                        value = this.variable(elt, null);
                     } else if (elt === null || Number.isInteger(elt) || typeof elt === 'number' || typeof elt === 'boolean' || typeof elt === 'string') {
                         value = this._graph.insertConstant(elt);
                     } else {
@@ -2143,7 +2164,7 @@ pytorch.Execution = class extends python.Execution {
             }
             case 'Attribute': {
                 const target = this.target(expr.value, context);
-                return target[expr.attr];
+                return target.__getattr__ ? target.__getattr__(expr.attr) : target[expr.attr];
             }
             case 'Call': {
                 const func = expr.func;
@@ -2528,7 +2549,9 @@ pytorch.Execution = class extends python.Execution {
     statement(stmt, context) {
         if (stmt.__class__.__name__ === 'ClassDef') {
             const name = `${context.get('__name__')}.${stmt.name}`;
-            this._resolver.resolveType(name);
+            if (this._resolver) {
+                this._resolver.resolveType(name);
+            }
         }
 
         if (!this.trace) {
@@ -2754,7 +2777,7 @@ pytorch.Execution = class extends python.Execution {
                 optional = true;
             }
             if (optional === true &&
-                (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
+                (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type instanceof torch.TensorType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
                 v instanceof torch.Value && v.type() instanceof torch.NoneType) {
                 position++;
                 input = v;
@@ -2912,7 +2935,8 @@ pytorch.Execution = class extends python.Execution {
                 case 'complex':
                 case 'bool':
                 case 'bool[]':
-                case 'Device': {
+                case 'Device':
+                case 'Layout': {
                     break;
                 }
                 case 't': {
@@ -3230,7 +3254,7 @@ pytorch.Execution = class extends python.Execution {
                     optional = true;
                 }
                 if (optional === true &&
-                    (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
+                    (type instanceof torch.FloatType || type instanceof torch.BoolType || type instanceof torch.IntType || type instanceof torch.ComplexType || type instanceof torch.TensorType || type.kind() === 'ScalarTypeType' || type instanceof torch.DeviceObjType || type.kind() === 'LayoutKind') &&
                     v instanceof torch.Value && v.type() instanceof torch.NoneType) {
                     position++;
                 } else if (!this.isType(v, type, arg.N) && v !== null) {
@@ -3315,8 +3339,8 @@ pytorch.Container.Package = class extends pytorch.Container {
         this.entries = entries;
     }
 
-    async read() {
-        this.execution = new python.Execution();
+    async read(metadata) {
+        this.execution = new pytorch.Execution(null, metadata);
         for (const event of this._events) {
             this.execution.on(event[0], event[1]);
         }
@@ -3504,8 +3528,17 @@ pytorch.Utility = class {
     }
 
     static weights(obj) {
-        const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
-        if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') {
+        let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
+        if (type === 'torch.jit._script.RecursiveScriptModule') {
+            type = obj._c._type();
+            const target = {};
+            for (let i = 0; i < type.numAttributes(); i++) {
+                const k = type.getAttributeName(i);
+                target[k] = obj.__getattr__(k);
+            }
+            type = obj._c.qualified_name;
+            obj = target;
+        } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') {
             return null;
         }
         if (pytorch.Utility.isTensor(obj)) {

+ 4 - 3
source/view.js

@@ -3016,7 +3016,8 @@ view.PrimitiveView = class extends view.Expander {
                     break;
                 }
                 default: {
-                    let content = new view.Formatter(value, type).toString();
+                    const formatter = new view.Formatter(value, type);
+                    let content = formatter.toString();
                     if (content && content.length > 1000) {
                         content = `${content.substring(0, 1000)}\u2026`;
                     }
@@ -4359,7 +4360,6 @@ view.Formatter = class {
     }
 
     _format(value, type, quote) {
-
         if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'type') {
             return `${value.__module__}.${value.__name__}`;
         }
@@ -4451,7 +4451,8 @@ view.Formatter = class {
         }
         this._values.add(value);
         let list = null;
-        const entries = Object.entries(value).filter(([name]) => !name.startsWith('__') && !name.endsWith('__'));
+        const map = value instanceof Map ? Array.from(value) : Object.entries(value);
+        const entries = map.filter(([name]) => typeof name === 'string' && !name.startsWith('__') && !name.endsWith('__'));
         if (entries.length === 1) {
             list = [this._format(entries[0][1], null, true)];
         } else {

+ 7 - 2
test/models.json

@@ -5301,6 +5301,7 @@
     "target":   "alexnet_traced.pt.zip",
     "source":   "https://github.com/lutzroeder/netron/files/6096602/alexnet_traced.pt.zip",
     "format":   "TorchScript v1.6",
+    "assert":   "model.graphs[0].nodes.length == 28",
     "link":     "https://github.com/lutzroeder/netron/issues/281"
   },
   {
@@ -5425,6 +5426,7 @@
     "target":   "coco128-yolov8n-seg_output.torchscript.ptl",
     "source":   "https://github.com/user-attachments/files/16091260/coco128-yolov8n-seg_output.torchscript.ptl.zip[coco128-yolov8n-seg_output.torchscript.ptl]",
     "format":   "TorchScript v1.6",
+    "assert":   "model.graphs[0].nodes[0].inputs[1].value.type.name == '__torch__.torch.classes.xnnpack.Conv2dOpContext'",
     "link":     "https://github.com/lutzroeder/netron/issues/1067"
   },
   {
@@ -5492,6 +5494,7 @@
     "target":   "cruise_go_vehicle_model.pt",
     "source":   "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/cruise_go_vehicle_model.pt",
     "format":   "TorchScript v1.0",
+    "assert":   "model.graphs[0].nodes.length == 73",
     "link":     "https://github.com/ApolloAuto/apollo"
   },
   {
@@ -5522,6 +5525,7 @@
     "target":   "deeplabv3_scripted.pt",
     "source":   "https://github.com/lutzroeder/netron/files/5604999/deeplabv3_scripted.pt.zip[deeplabv3_scripted.pt]",
     "format":   "TorchScript v1.6",
+    "assert":   "model.graphs[0].nodes.length == 478",
     "link":     "https://github.com/lutzroeder/netron/issues/630"
   },
   {
@@ -5747,6 +5751,7 @@
     "target":   "lane_scanning_vehicle_model.pt",
     "source":   "https://raw.githubusercontent.com/ApolloAuto/apollo/master/modules/prediction/data/lane_scanning_vehicle_model.pt",
     "format":   "TorchScript v1.0",
+    "assert":   "model.graphs[0].nodes.length == 121",
     "link":     "https://github.com/ApolloAuto/apollo"
   },
   {
@@ -6243,7 +6248,7 @@
     "target":   "resnet18.ot",
     "source":   "https://github.com/lutzroeder/netron/files/7664092/resnet18.ot.zip[resnet18.ot]",
     "format":   "TorchScript v1.0",
-    "assert":   "model.graphs[0].nodes[1].inputs[0].value[0].name == 'conv1|weight'",
+    "assert":   "model.graphs[0].nodes[0].inputs[0].value[0].name == 'conv1|weight'",
     "link":     "https://github.com/lutzroeder/netron/issues/686"
   },
   {
@@ -6381,7 +6386,7 @@
     "target":   "segmentor.pt",
     "source":   "https://github.com/lutzroeder/netron/files/7663953/segmentor.pt.zip[segmentor.pt]",
     "format":   "PyTorch v1.6",
-    "assert":   "model.graphs[0].nodes[0].inputs[0].value.inputs[0].value.type.name == '__torch__.___torch_mangle_1.Module'",
+    "assert":   "model.graphs[0].nodes[0].inputs[0].value[0].inputs[0].value[0].type.name == '__torch__.___torch_mangle_1.Module'",
     "link":     "https://github.com/lutzroeder/netron/issues/686"
   },
   {