2
0
Эх сурвалжийг харах

Add torch.export test file (#1211)

Lutz Roeder 2 сар өмнө
parent
commit
8557a7e1b1
3 өөрчлөгдсөн 246 нэмэгдсэн , 137 устгасан
  1. 31 2
      source/python.js
  2. 207 135
      source/pytorch.js
  3. 8 0
      test/models.json

+ 31 - 2
source/python.js

@@ -4578,7 +4578,11 @@ python.Execution = class {
             return obj.length;
         });
         this.registerFunction('builtins.setattr', (obj, name, value) => {
-            obj[name] = value;
+            if (obj && obj.__setattr__) {
+                obj.__setattr__(name, value);
+            } else {
+                obj[name] = value;
+            }
         });
         this.registerType('builtins.set', class extends Set {
             __contains__(item) {
@@ -8505,11 +8509,37 @@ python.Execution = class {
                 }
             }
         });
+        this.registerFunction('torch.fx.graph_module._copy_attr', (from_module, to_module, target) => {
+            const parts = target.split('.');
+            const field = parts.pop();
+            for (const item of parts) {
+                const f = builtins.getattr(from_module, item);
+                let t = builtins.getattr(to_module, item, null);
+                if (f === t) {
+                    return;
+                }
+                if (t === null) {
+                    t = new torch.nn.modules.module.Module();
+                    builtins.setattr(to_module, item, t);
+                }
+                from_module = f;
+                to_module = t;
+            }
+            const orig = builtins.getattr(from_module, field);
+            builtins.setattr(to_module, field, orig);
+        });
         this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {
             constructor(root, graph, class_name) {
                 super();
                 this.__class__.__name__ = class_name || 'GraphModule';
                 this.graph = graph;
+                if (root instanceof torch.nn.modules.module.Module && graph && graph.nodes) {
+                    for (const node of graph.nodes) {
+                        if (node.op === 'get_attr' || node.op === 'call_module') {
+                            torch.fx.graph_module._copy_attr(root, this, node.target);
+                        }
+                    }
+                }
             }
         });
         torch.fx.Graph = torch.fx.graph.Graph;
@@ -19519,7 +19549,6 @@ python.Execution = class {
                 } else if (typ_ === 'as_layout') {
                     return torch._export.serde.serialize._SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout];
                 } else if (typ_ === 'as_graph') {
-                    // throw new python.Error('GraphArgument deserialization is not implemented.');
                     const context = this.save_graph_module();
                     context.__enter__();
                     this.deserialize_graph(value.graph);

+ 207 - 135
source/pytorch.js

@@ -11,9 +11,9 @@ const numpy = {};
 pytorch.ModelFactory = class {
 
     async match(context) {
-        const container = await pytorch.Container.open(context);
-        if (container) {
-            return context.set(container.type, container);
+        const reader = await pytorch.Reader.open(context);
+        if (reader) {
+            return context.set(reader.type, reader);
         }
         return null;
     }
@@ -42,7 +42,7 @@ pytorch.ModelFactory = class {
         });
         await target.read(metadata);
         if (!target.format || (!target.modules && !target.module)) {
-            throw new pytorch.Error("Container not implemented.");
+            throw new pytorch.Error("Reader not implemented.");
         }
         return new pytorch.Model(metadata, target);
     }
@@ -76,17 +76,17 @@ pytorch.Graph = class {
         this.outputs = [];
         this.name = name;
         this.type = type;
-        const values = new Map();
-        values.map = (name, type, tensor) => {
+        const context = new pytorch.Context(execution, metadata);
+        context.values.map = (name, type, tensor) => {
             if (tensor) {
                 return new pytorch.Value(name, type, null, tensor);
             }
-            if (!values.has(name)) {
-                values.set(name, new pytorch.Value(name, type, null, tensor));
+            if (!context.values.has(name)) {
+                context.values.set(name, new pytorch.Value(name, type, null, tensor));
             } else if (type || tensor) {
                 throw new pytorch.Error(`Duplicate value '${name}'.`);
             }
-            return values.get(name);
+            return context.values.get(name);
         };
         const torch = execution ? execution.torch : null;
         if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) {
@@ -188,12 +188,12 @@ pytorch.Graph = class {
                 }
                 const identifier = pytorch.Utility.unique(v);
                 const name = v.debugName() || identifier;
-                const value = values.map(identifier);
+                const value = context.values.map(identifier);
                 this.inputs.push(new pytorch.Argument(name, [value]));
             }
             for (const value of graph.outputs()) {
                 const identifier = pytorch.Utility.unique(value);
-                this.outputs.push(new pytorch.Argument(identifier, [values.map(identifier)]));
+                this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
             }
             for (const node of graph.nodes()) {
                 if (deleted.has(node)) {
@@ -210,29 +210,15 @@ pytorch.Graph = class {
                         continue;
                     }
                 }
-                this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, values));
+                this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context));
             }
         } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) {
             const exported_program = module;
             const graph = exported_program.graph;
+            const graph_module = exported_program.graph_module;
             const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters;
             const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers;
             const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants;
-            const values = new Map();
-            values.map = (obj) => {
-                if (!values.has(obj)) {
-                    let type = null;
-                    const val = obj.meta.get('val');
-                    if (val && val.dtype) {
-                        const dataType = val.dtype.__reduce__();
-                        const shape = new pytorch.TensorShape(val.shape);
-                        type = new pytorch.TensorType(dataType, shape);
-                    }
-                    const value = new pytorch.Value(obj.name, type);
-                    values.set(obj, value);
-                }
-                return values.get(obj);
-            };
             const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
             for (const obj of graph.nodes) {
                 if (obj.op === 'placeholder') {
@@ -242,60 +228,40 @@ pytorch.Graph = class {
                         const tensor = parameter && parameter.data ? parameter.data : obj.meta.get('val');
                         const initializer = new pytorch.Tensor(key, tensor);
                         const value = new pytorch.Value(key, null, null, initializer);
-                        values.set(obj, value);
+                        context.values.set(obj, value);
                     } else if (inputs_to_buffers.has(obj.name)) {
                         const key = inputs_to_buffers.get(obj.name);
                         const buffer = exported_program.state_dict.get(key);
                         const tensor = buffer || obj.meta.get('val');
                         const initializer = new pytorch.Tensor(key, tensor);
                         const value = new pytorch.Value(key, null, null, initializer);
-                        values.set(obj, value);
+                        context.values.set(obj, value);
                     } else if (inputs_to_lifted_tensor_constants.has(obj.name)) {
                         const key = inputs_to_lifted_tensor_constants.get(obj.name);
                         const constant = exported_program.constants.get(key);
                         const tensor = constant && constant.data ? constant.data : obj.meta.get('val');
                         const initializer = new pytorch.Tensor(key, tensor);
                         const value = new pytorch.Value(key, null, null, initializer);
-                        values.set(obj, value);
+                        context.values.set(obj, value);
                     }
-                    if (obj.users.size > 1 && values.has(obj)) {
-                        const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
+                    if (obj.users.size > 1 && context.values.has(obj)) {
+                        const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context);
                         this.nodes.push(node);
-                        values.set(obj, node.outputs[0].value[0]);
-                    }
-                }
-            }
-            for (const obj of graph.nodes) {
-                if (obj.op === 'placeholder') {
-                    continue;
-                }
-                if (obj.op === 'call_function') {
-                    if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
-                        continue;
-                    }
-                }
-                if (obj.op === 'output') {
-                    for (const output of obj.args) {
-                        if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
-                            continue;
-                        }
-                        const value = values.map(output);
-                        const argument = new pytorch.Argument(output.name, [value]);
-                        this.outputs.push(argument);
+                        context.values.set(obj, node.outputs[0].value[0]);
                     }
-                    continue;
                 }
-                const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, values);
-                this.nodes.push(node);
             }
+            context.graph(this, graph_module, false);
             for (const input_spec of exported_program.graph_signature.user_inputs) {
                 if (nodes.has(input_spec)) {
                     const node = nodes.get(input_spec);
-                    const value = values.map(node);
+                    const value = context.value(node);
                     const argument = new pytorch.Argument(input_spec, [value]);
                     this.inputs.push(argument);
                 }
             }
+        } else if (torch && module instanceof torch.fx.GraphModule && module.graph) {
+            context.graph(this, module, true);
         } else if (pytorch.Utility.isTensor(module)) {
             const node = new pytorch.Node(execution, metadata, null, type, { value: module });
             this.nodes.push(node);
@@ -311,7 +277,7 @@ pytorch.Graph = class {
                 const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                 for (const module of modules) {
                     const type = this.type === 'weights' ? 'Weights' : null;
-                    const node = new pytorch.Node(execution, metadata, null, type, module, null, values);
+                    const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
                     this.nodes.push(node);
                 }
             }
@@ -344,7 +310,7 @@ pytorch.Value = class Value {
 
 pytorch.Node = class {
 
-    constructor(execution, metadata, name, type, obj, initializers, values, stack) {
+    constructor(execution, metadata, name, type, obj, initializers, context, stack) {
         const torch = execution ? execution.torch : null;
         const builtins = execution ? execution.builtins : null;
         this.name = name || '';
@@ -390,11 +356,11 @@ pytorch.Node = class {
             const mapTensor = (value) => {
                 if (value.identifier && pytorch.Utility.isTensor(value.value)) {
                     const identifier = value.identifier;
-                    if (!values.has(identifier)) {
+                    if (!context.values.has(identifier)) {
                         const tensor = new pytorch.Tensor(identifier, value.value);
-                        values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
+                        context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
                     }
-                    return values.map(identifier);
+                    return context.values.map(identifier);
                 }
                 let initializer = null;
                 let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
@@ -407,7 +373,7 @@ pytorch.Node = class {
                 if (initializer) {
                     return new pytorch.Value(identifier, null, null, initializer);
                 }
-                return values.map(identifier);
+                return context.values.map(identifier);
             };
             for (let i = 0; i < inputs.length; i++) {
                 const input = inputs[i];
@@ -423,20 +389,20 @@ pytorch.Node = class {
                 if (type && type instanceof torch.ClassType) {
                     const obj = input.value;
                     if (!array && initializers.has(obj)) {
-                        const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, values);
+                        const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context);
                         argument = new pytorch.Argument(name, node, 'object');
                     } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
-                        const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, values));
+                        const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context));
                         argument = new pytorch.Argument(name, node, 'object[]');
                     } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) {
-                        const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, values));
+                        const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context));
                         argument = new pytorch.Argument(name, node, 'object[]');
                     } else if (input.value === undefined) {
                         const identifier = pytorch.Utility.unique(input);
-                        const value = values.map(identifier);
+                        const value = context.values.map(identifier);
                         argument = new pytorch.Argument(name, [value]);
                     } else {
-                        const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, values);
+                        const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context);
                         argument = new pytorch.Argument(name, node, 'object');
                     }
                 } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) {
@@ -465,13 +431,13 @@ pytorch.Node = class {
                                     return value.value;
                                 }
                                 const identifier = pytorch.Utility.unique(value);
-                                return values.map(identifier);
+                                return context.values.map(identifier);
                             });
                             const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
                             argument = new pytorch.Argument(name, args, type);
                         } else {
                             const identifier = pytorch.Utility.unique(input);
-                            argument = new pytorch.Argument(name, [values.map(identifier)]);
+                            argument = new pytorch.Argument(name, [context.values.map(identifier)]);
                         }
                     } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') {
                         argument = new pytorch.Argument(name, input.value, 'string');
@@ -485,7 +451,7 @@ pytorch.Node = class {
                         argument = new pytorch.Argument(name, null, 'attribute');
                     } else {
                         const identifier = pytorch.Utility.unique(input);
-                        const value = values.map(identifier);
+                        const value = context.values.map(identifier);
                         argument = new pytorch.Argument(name, [value]);
                     }
                 } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
@@ -508,14 +474,14 @@ pytorch.Node = class {
                         if (initializer) {
                             return new pytorch.Value(identifier, null, null, initializer);
                         }
-                        return values.map(identifier);
+                        return context.values.map(identifier);
                     });
                     argument = new pytorch.Argument(name, args);
                 } else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) {
                     const args = input.value.map((value) => {
                         if (value instanceof torch.Value) {
                             const identifier = pytorch.Utility.unique(value);
-                            return values.map(identifier);
+                            return context.values.map(identifier);
                         }
                         return value;
                     });
@@ -540,7 +506,7 @@ pytorch.Node = class {
                     output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
                     list = output.uses()[0].user.outputs();
                 }
-                const args = list.map((output) => values.map(pytorch.Utility.unique(output)));
+                const args = list.map((output) => context.values.map(pytorch.Utility.unique(output)));
                 const argument = new pytorch.Argument(name, args);
                 this.outputs.push(argument);
             }
@@ -611,15 +577,24 @@ pytorch.Node = class {
                 for (const [name, arg] of args) {
                     const type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
                     if (arg instanceof torch.fx.node.Node) {
-                        const value = values.map(arg);
-                        const argument = new pytorch.Argument(name, [value]);
+                        let argument = null;
+                        if (arg.op === 'get_attr' && arg.users.size === 1) {
+                            const subgraph = context.function(arg);
+                            if (subgraph) {
+                                argument = new pytorch.Argument(name, subgraph, 'function');
+                            }
+                        }
+                        if (!argument) {
+                            const value = context.value(arg);
+                            argument = new pytorch.Argument(name, [value]);
+                        }
                         this.inputs.push(argument);
                     } else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) {
-                        const list = arg.map((arg) => arg === null ? null : values.map(arg));
+                        const list = arg.map((arg) => arg === null ? null : context.value(arg));
                         const argument = new pytorch.Argument(name, list);
                         this.inputs.push(argument);
                     } else if (Array.isArray(arg)) {
-                        const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? values.map(arg) : arg);
+                        const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg);
                         const argument = new pytorch.Argument(name, list, type || 'attribute');
                         this.inputs.push(argument);
                     } else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) {
@@ -643,7 +618,7 @@ pytorch.Node = class {
                 }
                 for (let i = 0; i < outputs.length; i++) {
                     const node = outputs[i];
-                    const value = values.map(node);
+                    const value = context.value(node);
                     const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
                     const argument = new pytorch.Argument(name, [value]);
                     this.outputs.push(argument);
@@ -671,19 +646,26 @@ pytorch.Node = class {
             } else if (obj.op === 'placeholder') {
                 this.type = { name: obj.op };
                 {
-                    const value = values.map(obj);
+                    const value = context.value(obj);
                     const argument = new pytorch.Argument('value', [value]);
                     this.inputs.push(argument);
                 }
                 {
-                    const value = values.map({ name: obj.name, meta: obj.meta });
+                    const node = new torch.fx.node.Node(null, obj.name);
+                    node.meta = obj.meta;
+                    const value = context.value(node);
                     const argument = new pytorch.Argument('value', [value]);
                     this.outputs.push(argument);
                 }
             } else if (obj.op === 'get_attr') {
                 this.type = { name: obj.op };
-                this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
-                const value = values.map(obj);
+                const subgraph = context.function(obj);
+                if (subgraph) {
+                    this.inputs.push(new pytorch.Argument('name', subgraph, 'function'));
+                } else {
+                    this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
+                }
+                const value = context.value(obj);
                 this.outputs.push(new pytorch.Argument('value', [value]));
             } else if (obj.op === 'root') {
                 this.type = { name: obj.op };
@@ -783,7 +765,7 @@ pytorch.Node = class {
                             const argument = new pytorch.Argument(name, args, null, visible);
                             this.inputs.push(argument);
                             if (value && value.__variable__) {
-                                const argument = new pytorch.Argument(name, [values.map(value.__variable__)]);
+                                const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]);
                                 this.outputs.push(argument);
                             }
                         }
@@ -816,7 +798,7 @@ pytorch.Node = class {
                         const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
                             stack.add(value);
                             const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
-                            const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, values, stack);
+                            const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, context, stack);
                             stack.delete(value);
                             return node;
                         });
@@ -829,7 +811,7 @@ pytorch.Node = class {
                         const list = value.filter((value) => !stack.has(value));
                         const nodes = list.map((value) => {
                             stack.add(value);
-                            const node = new pytorch.Node(execution, metadata, null, null, value, initializers, values, stack);
+                            const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
                             stack.delete(value);
                             return node;
                         });
@@ -837,7 +819,7 @@ pytorch.Node = class {
                         this.inputs.push(argument);
                     } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
                         stack.add(value);
-                        const node = new pytorch.Node(execution, metadata, null, null, value, initializers, values, stack);
+                        const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
                         stack.delete(value);
                         const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value);
                         const argument = new pytorch.Argument(name, node, 'object', visible);
@@ -994,27 +976,119 @@ pytorch.TensorShape = class {
     }
 };
 
-pytorch.Container = class {
+pytorch.Context = class {
+
+    constructor(execution, metadata) {
+        this.execution = execution;
+        this.torch = execution ? execution.__import__('torch') : null;
+        this.metadata = metadata;
+        this.values = new Map();
+        this.modules = new Map();
+    }
+
+    value(obj) {
+        const torch = this.torch;
+        if (obj instanceof torch.fx.node.Node) {
+            if (!this.values.has(obj)) {
+                let type = null;
+                const val = obj.meta ? obj.meta.get('val') : null;
+                if (val && val.dtype) {
+                    const dataType = val.dtype.__reduce__();
+                    const shape = new pytorch.TensorShape(val.shape);
+                    type = new pytorch.TensorType(dataType, shape);
+                }
+                const value = new pytorch.Value(obj.name, type);
+                this.values.set(obj, value);
+            }
+            return this.values.get(obj);
+        }
+        return null;
+    }
+
+    function(obj) {
+        const torch = this.torch;
+        if (obj instanceof torch.fx.node.Node) {
+            let subgraph = this.modules.get(obj);
+            if (subgraph) {
+                if (subgraph instanceof pytorch.Graph === false) {
+                    subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph);
+                    this.modules.set(obj, subgraph);
+                }
+                return subgraph;
+            }
+        }
+        return null;
+    }
+
+    graph(target, module, inputs) {
+        const graph = module.graph;
+        if (module.named_modules) {
+            const modules = module.named_modules();
+            for (const obj of graph.nodes) {
+                if (obj.op === 'get_attr') {
+                    const submodule = modules.get(obj.target);
+                    if (submodule && submodule.graph) {
+                        this.modules.set(obj, submodule);
+                    }
+                }
+            }
+        }
+        for (const obj of graph.nodes) {
+            if (obj.op === 'placeholder') {
+                if (inputs) {
+                    const value = this.value(obj);
+                    const argument = new pytorch.Argument(obj.name, [value]);
+                    target.inputs.push(argument);
+                }
+                continue;
+            }
+            if (obj.op === 'call_function') {
+                if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
+                    continue;
+                }
+            }
+            if (obj.op === 'get_attr') {
+                if (this.modules.has(obj) && obj.users.size === 1) {
+                    continue;
+                }
+            }
+            if (obj.op === 'output') {
+                for (const output of obj.args) {
+                    if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
+                        continue;
+                    }
+                    const value = this.value(output);
+                    const argument = new pytorch.Argument(output.name, [value]);
+                    target.outputs.push(argument);
+                }
+                continue;
+            }
+            const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this);
+            target.nodes.push(node);
+        }
+    }
+};
+
+pytorch.Reader = class {
 
     static async open(context) {
         const types = [
-            pytorch.Container.Zip,
-            pytorch.Container.Pickle,
-            pytorch.Container.Tar,
-            pytorch.Container.data_pkl,
-            pytorch.Container.torch_utils,
-            pytorch.Container.Mobile,
-            pytorch.Container.ModelJson,
-            pytorch.Container.IR,
-            pytorch.Container.Index,
-            pytorch.Container.ExportedProgram
+            pytorch.Reader.Zip,
+            pytorch.Reader.Pickle,
+            pytorch.Reader.Tar,
+            pytorch.Reader.data_pkl,
+            pytorch.Reader.torch_utils,
+            pytorch.Reader.Mobile,
+            pytorch.Reader.ModelJson,
+            pytorch.Reader.IR,
+            pytorch.Reader.Index,
+            pytorch.Reader.ExportedProgram
         ];
         for (const type of types) {
-            /* eslint-disable no-await-in-loop */
-            const container = await type.open(context);
-            /* eslint-enable no-await-in-loop */
-            if (container) {
-                return container;
+            // eslint-disable-next-line no-await-in-loop
+            const reader = await type.open(context);
+            if (reader) {
+                return reader;
             }
         }
         return null;
@@ -1032,12 +1106,12 @@ pytorch.Container = class {
     }
 };
 
-pytorch.Container.Tar = class extends pytorch.Container {
+pytorch.Reader.Tar = class extends pytorch.Reader {
 
     static async open(context) {
         const entries = await context.peek('tar');
         if (entries instanceof Map && entries.has('pickle')) {
-            return new pytorch.Container.Tar(entries);
+            return new pytorch.Reader.Tar(entries);
         }
         return null;
     }
@@ -1060,13 +1134,13 @@ pytorch.Container.Tar = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.Pickle = class extends pytorch.Container {
+pytorch.Reader.Pickle = class extends pytorch.Reader {
 
     static async open(context) {
         const stream = context.stream;
         const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
         if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
-            return new pytorch.Container.Pickle(stream);
+            return new pytorch.Reader.Pickle(stream);
         }
         return null;
     }
@@ -1090,7 +1164,7 @@ pytorch.Container.Pickle = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.data_pkl = class extends pytorch.Container {
+pytorch.Reader.data_pkl = class extends pytorch.Reader {
 
     static async open(context) {
         const obj = await context.peek('pkl');
@@ -1098,30 +1172,30 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
             if (obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
                 const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
                 if (name.startsWith('__torch__.')) {
-                    return new pytorch.Container.data_pkl('', obj);
+                    return new pytorch.Reader.data_pkl('', obj);
                 }
             }
             if (pytorch.Utility.isTensor(obj)) {
-                return new pytorch.Container.data_pkl('tensor', obj);
+                return new pytorch.Reader.data_pkl('tensor', obj);
             }
             if (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
-                return new pytorch.Container.data_pkl('tensor', obj);
+                return new pytorch.Reader.data_pkl('tensor', obj);
             }
             if (obj instanceof Map) {
                 const entries = Array.from(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
                 if (entries.length > 0) {
-                    return new pytorch.Container.data_pkl('tensor', obj);
+                    return new pytorch.Reader.data_pkl('tensor', obj);
                 }
             } else if (!Array.isArray(obj)) {
                 const entries = Object.entries(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
                 if (entries.length > 0) {
-                    return new pytorch.Container.data_pkl('tensor', obj);
+                    return new pytorch.Reader.data_pkl('tensor', obj);
                 }
             }
             for (const key of ['', 'model', 'net']) {
                 const module = key === '' ? obj : obj[key];
                 if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) {
-                    return new pytorch.Container.data_pkl('module', module);
+                    return new pytorch.Reader.data_pkl('module', module);
                 }
             }
         }
@@ -1139,7 +1213,7 @@ pytorch.Container.data_pkl = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.torch_utils = class extends pytorch.Container {
+pytorch.Reader.torch_utils = class extends pytorch.Reader {
 
     static async open(context) {
         const stream = context.stream;
@@ -1150,7 +1224,7 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
                 if (content.indexOf('torch_utils') !== -1) {
                     const obj = await context.peek('pkl');
                     if (obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'))) {
-                        return new pytorch.Container.torch_utils(obj);
+                        return new pytorch.Reader.torch_utils(obj);
                     }
                 }
             }
@@ -1171,12 +1245,12 @@ pytorch.Container.torch_utils = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.Mobile = class extends pytorch.Container {
+pytorch.Reader.Mobile = class extends pytorch.Reader {
 
     static async open(context) {
         const reader = await context.peek('flatbuffers.binary');
         if (reader && reader.identifier === 'PTMF') {
-            return new pytorch.Container.Mobile(context);
+            return new pytorch.Reader.Mobile(context);
         }
         return null;
     }
@@ -1203,7 +1277,7 @@ pytorch.Container.Mobile = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.Zip = class extends pytorch.Container {
+pytorch.Reader.Zip = class extends pytorch.Reader {
 
     static async open(context) {
         const entries = await context.peek('zip');
@@ -1220,10 +1294,10 @@ pytorch.Container.Zip = class extends pytorch.Container {
                 return null;
             }
             if (records.has('data.pkl')) {
-                return new pytorch.Container.Zip(entries);
+                return new pytorch.Reader.Zip(entries);
             }
             if (records.has('.data/version') && !records.has('archive_format')) {
-                return new pytorch.Container.Package(entries);
+                return new pytorch.Reader.Package(entries);
             }
         }
         return null;
@@ -1267,7 +1341,7 @@ pytorch.Container.Zip = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.ModelJson = class extends pytorch.Container {
+pytorch.Reader.ModelJson = class extends pytorch.Reader {
 
     static async open(context) {
         const identifier = context.identifier;
@@ -1276,7 +1350,7 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
             if (model && model.mainModule) {
                 const entries = new Map();
                 entries.set('model.json', context.stream);
-                return new pytorch.Container.ModelJson(context, entries, model);
+                return new pytorch.Reader.ModelJson(context, entries, model);
             }
         }
         return null;
@@ -1336,14 +1410,14 @@ pytorch.Container.ModelJson = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.IR = class extends pytorch.Container {
+pytorch.Reader.IR = class extends pytorch.Reader {
 
     static async open(context) {
         const reader = await context.read('text', 0x100);
         if (reader && reader.length > 0) {
             const line = reader.read('\n');
             if (line.startsWith('graph(')) {
-                return new pytorch.Container.IR(context);
+                return new pytorch.Reader.IR(context);
             }
         }
         return null;
@@ -1367,14 +1441,14 @@ pytorch.Container.IR = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.Index = class extends pytorch.Container {
+pytorch.Reader.Index = class extends pytorch.Reader {
 
     static async open(context) {
         const obj = await context.peek('json');
         if (obj && obj.weight_map) {
             const entries = Object.entries(obj.weight_map);
             if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) {
-                return new pytorch.Container.Index(context, entries);
+                return new pytorch.Reader.Index(context, entries);
             }
         }
         return null;
@@ -1424,18 +1498,18 @@ pytorch.Container.Index = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.ExportedProgram = class extends pytorch.Container {
+pytorch.Reader.ExportedProgram = class extends pytorch.Reader {
 
     static async open(context) {
         const program = await context.peek('json');
         if (program && program.schema_version && program.graph_module) {
-            return new pytorch.Container.ExportedProgram(context, program);
+            return new pytorch.Reader.ExportedProgram(context, program);
         }
         if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) {
             const buffer = context.stream.peek();
             const archive_format = String.fromCharCode.apply(null, buffer);
             if (archive_format === 'pt2') {
-                return new pytorch.Container.ExportedProgram(context, null, context);
+                return new pytorch.Reader.ExportedProgram(context, null, context);
             }
         }
         return null;
@@ -1530,9 +1604,8 @@ pytorch.Container.ExportedProgram = class extends pytorch.Container {
         const torch = this.execution.__import__('torch');
         for (const exported_program of exported_programs.values()) {
             if (exported_program.graph_module.graph.constants) {
-                /* eslint-disable no-await-in-loop */
+                // eslint-disable-next-line no-await-in-loop
                 const zip = await import('./zip.js');
-                /* eslint-enable no-await-in-loop */
                 const constants = exported_program.graph_module.graph.constants;
                 for (const key of Object.keys(constants)) {
                     const value = constants[key];
@@ -1591,9 +1664,8 @@ pytorch.Execution = class extends python.Execution {
     constructor(sources, metadata) {
         super(sources);
         this._metadata = metadata;
-        /* eslint-disable consistent-this */
+        // eslint-disable-next-line consistent-this
         const execution = this;
-        /* eslint-enable consistent-this */
         const torch = this.torch;
         this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => {
             const cu = new torch.jit.CompilationUnit();
@@ -1797,7 +1869,7 @@ pytorch.Execution = class extends python.Execution {
     }
 };
 
-pytorch.Container.Package = class extends pytorch.Container {
+pytorch.Reader.Package = class extends pytorch.Reader {
 
     constructor(entries) {
         super();

+ 8 - 0
test/models.json

@@ -6389,6 +6389,14 @@
     "error":    "Unknown type name 'torch.ops.mylib.custom_relu.default'.",
     "link":     "https://github.com/lutzroeder/netron/issues/1211"
   },
+  {
+    "type":     "pytorch",
+    "target":   "nested_autocast.pt2",
+    "source":   "https://github.com/user-attachments/files/24322606/nested_autocast.pt2.zip[nested_autocast.pt2]",
+    "format":   "PyTorch Export v8.15",
+    "assert":   "model.modules[0].nodes[0].inputs[4].value.nodes[1].inputs[4].value.nodes.length == 1",
+    "link":     "https://github.com/lutzroeder/netron/issues/1211"
+  },
   {
     "type":     "pytorch",
     "target":   "netron_issue_313_v1.pt",