Pārlūkot izejas kodu

Update python.js (#1211)

Lutz Roeder 2 gadi atpakaļ
vecāks
revīzija
3ceef5cd94
3 mainītis faili ar 138 papildinājumiem un 39 dzēšanām
  1. 85 12
      source/python.js
  2. 41 27
      source/pytorch.js
  3. 12 0
      source/view.js

+ 85 - 12
source/python.js

@@ -5450,6 +5450,28 @@ python.Execution = class {
         this.registerType('torch._ops.OperatorBase', class {});
         this.registerType('torch._ops.HigherOrderOperator', class extends torch._ops.OperatorBase {});
         this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase {});
+        this.registerType('torch.export.unflatten.UnflattenedModule', class extends torch.nn.modules.module.Module {
+            constructor(/* export_module, flat_args_adapter */) {
+                super();
+            }
+        });
+        this.registerType('torch.export.exported_program.ExportedProgram', class {
+            constructor(/* root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, verifier, tensor_constants */) {
+            }
+        });
+        this.registerFunction('torch.export.unflatten', function(/* module, flat_args_adapter */) {
+            throw new python.Error("'torch.export.unflatten' not implemented.");
+        });
+        this.registerFunction('torch._export.exported_program._create_graph_module_for_export', function(root, graph) {
+            return new torch.fx.graph_module.GraphModule(root, graph);
+        });
+        this.registerType('torch._export.serde.serialize.SerializedArtifact', class {
+            constructor(exported_program, state_dict, constants) {
+                this.exported_program = exported_program;
+                this.state_dict = state_dict;
+                this.constants = constants;
+            }
+        });
         this.registerType('torch.fx.experimental.symbolic_shapes.ShapeEnv', class {
             constructor() {
             }
@@ -5619,11 +5641,61 @@ python.Execution = class {
             }
         });
         torch.fx.Graph = torch.fx.graph.Graph;
-        this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {});
+        this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {
+            constructor(root, graph) {
+                super();
+                this.graph = graph;
+            }
+        });
         this.registerFunction('torch.fx._symbolic_trace.wrap', function(fn_or_name) {
             return fn_or_name;
         });
         this.registerType('torch.fx._symbolic_trace.Tracer', class {});
+        this.registerFunction('torch._export.load', function(f, expected_opset_version) {
+            const serialized_exported_program = f.get('serialized_exported_program.json');
+            const serialized_state_dict = f.get('serialized_state_dict.pt');
+            const serialized_constants = f.get('serialized_constants.pt');
+            const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants);
+            return torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
+        });
+        this.registerFunction('torch._export.serde.serialize.deserialize', function(artifact, expected_opset_version) {
+            return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact);
+        });
+        this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class {
+            constructor(expected_opset_version) {
+                this.expected_opset_version = expected_opset_version;
+            }
+            deserialize(serialized_artifact) {
+                const symbol_name_to_range = new Map(Object.entries(serialized_artifact.exported_program.range_constraints));
+                /*
+                symbol_name_to_range = {
+                    k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val))
+                    for k, v in serialized_artifact.exported_program.range_constraints.items()
+                }
+                */
+                const constants = serialized_artifact.constants ? torch.load(serialized_artifact.constants) : null;
+                const tensor_constants = constants ? new Map(Object.entries(constants).filter(([, tensor]) => tensor instanceof torch.Tensor)) : null;
+                const deserializer = new torch._export.serde.serialize.GraphModuleDeserializer();
+                const res = deserializer.deserialize(serialized_artifact.exported_program.graph_module, symbol_name_to_range, constants);
+                const range_constraints = null;
+                /*
+                range_constraints = self.deserialize_range_constraints(
+                    symbol_name_to_range, res.names_to_symbols,
+                )
+                model_opset_version: Optional[Dict[str, int]] = serialized_artifact.exported_program.opset_version
+                self._validate_model_opset_version(model_opset_version)
+                upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version)
+                */
+                const state_dict = serialized_artifact.state_dict ? torch.load(serialized_artifact.state_dict) : null;
+                const exported_program = new torch.export.exported_program.ExportedProgram(
+                    res.graph_module, res.graph_module.graph, res.signature,
+                    state_dict, range_constraints, res.module_call_graph, null,
+                    null, // verifier=load_verifier(serialized_artifact.exported_program.dialect),
+                    tensor_constants);
+                return exported_program;
+                // return upgrader.upgrade(exported_program)
+            }
+        });
         this.registerType('torch._export.serde.serialize.GraphModuleDeserializer', class {
             constructor() {
                 this.serialized_name_to_node = new Map();
@@ -5652,20 +5724,21 @@ python.Execution = class {
                 this.symbol_name_to_range = symbol_name_to_range || new Map();
                 this.constants = constants || new Map();
                 this.deserialize_graph(serialized_graph_module.graph);
+                const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
+                const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
+                return {
+                    graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
+                    signature: sig,
+                    module_call_graph: module_call_graph,
+                    names_to_symbols: this.symbol_name_to_symbol
+                };
             }
             deserialize_graph(serialized_graph) {
-                this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => {
-                    const str = atob(v);
-                    const buffer = new Uint8Array(str.length);
-                    for (let i = 0; i < str.length; i++) {
-                        buffer[i] = str.charCodeAt(i);
-                    }
-                    const archive = self.zip.Archive.open(buffer);
-                    const value = torch.load(archive.entries);
-                    return [ k, value ];
-                }));
+                if (serialized_graph.constants) {
+                    this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => [ k, torch.load(v) ]));
+                }
                 for (const [name, tensor_value] of Object.entries(serialized_graph.tensor_values)) {
-                    const meta_val = this.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode);
+                    const meta_val = this.deserialize_tensor_meta(tensor_value.meta, this.fake_tensor_mode);
                     this.serialized_name_to_meta.set(name, meta_val);
                 }
                 for (const [name, sym_int_value] of Object.entries(serialized_graph.sym_int_values)) {

+ 41 - 27
source/pytorch.js

@@ -764,7 +764,7 @@ pytorch.Container = class {
         if (index) {
             return index;
         }
-        const dynamo = pytorch.Container.Dynamo.open(context);
+        const dynamo = pytorch.Container.ExportedProgram.open(context);
         if (dynamo) {
             return dynamo;
         }
@@ -1219,48 +1219,50 @@ pytorch.Container.Index = class extends pytorch.Container {
     }
 };
 
-pytorch.Container.Dynamo = class extends pytorch.Container {
+pytorch.Container.ExportedProgram = class extends pytorch.Container {
 
     static open(context) {
         const program = context.peek('json');
         if (program && program.schema_version && program.graph_module) {
-            return new pytorch.Container.Dynamo(context, program);
+            return new pytorch.Container.ExportedProgram(context, program);
         }
         return null;
     }
 
-    constructor(context, program) {
+    constructor(context, serialized_exported_program) {
         super();
         this._context = context;
-        this._program = program;
+        this._serialized_exported_program = serialized_exported_program;
     }
 
     async read() {
         this._format = 'PyTorch Export';
-        let content = null;
-        try {
-            content = await this._context.fetch('serialized_state_dict.json');
-        } catch (error) {
-            // continue regardless of error
+        const serialized_state_dict = await this._fetch('serialized_state_dict.pt') || await this._fetch('serialized_state_dict.json');
+        const serialized_constants = await this._fetch('serialized_constants.pt') || await this._fetch('serialized_constants.json');
+        const f = new Map();
+        f.set('serialized_exported_program.json', this._serialized_exported_program);
+        f.set('serialized_state_dict.pt', serialized_state_dict);
+        f.set('serialized_constants.pt', serialized_constants);
+        const execution = new pytorch.Execution();
+        for (const event of this._events) {
+            execution.on(event[0], event[1]);
         }
-        if (content) {
-            const state_dict = content.peek('zip');
-            const execution = new pytorch.Execution();
-            execution.zip = await import('./zip.js');
-            for (const event of this._events) {
-                execution.on(event[0], event[1]);
-            }
-            const torch = execution.__import__('torch');
-            this._data = torch.load(state_dict);
-            const serialized_exported_program = this._program;
-            const deserializer = new torch._export.serde.serialize.GraphModuleDeserializer();
-            const symbol_name_to_range = new Map(Object.entries(serialized_exported_program.range_constraints));
-            /* TODO
-                k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val))
-                for k, v in serialized_exported_program.range_constraints.items()
-            */
-            deserializer.deserialize(serialized_exported_program.graph_module, symbol_name_to_range);
+        const torch = execution.__import__('torch');
+        if (this._serialized_exported_program.graph_module.graph.constants) {
+            const zip = await import('./zip.js');
+            const constants = this._serialized_exported_program.graph_module.graph.constants;
+            for (const key of Object.keys(constants)) {
+                const value = constants[key];
+                const str = atob(value);
+                const buffer = new Uint8Array(str.length);
+                for (let i = 0; i < str.length; i++) {
+                    buffer[i] = str.charCodeAt(i);
+                }
+                const archive = zip.Archive.open(buffer);
+                constants[key] = archive.entries;
+            }
         }
+        /* const exported_program = */ torch._export.load(f);
         throw new pytorch.Error(`'torch.export' not supported.`);
     }
 
@@ -1271,6 +1273,18 @@ pytorch.Container.Dynamo = class extends pytorch.Container {
     get modules() {
         return this._modules;
     }
+
+    async _fetch(name) {
+        try {
+            const context = await this._context.fetch(name);
+            if (context) {
+                return context.peek('zip');
+            }
+        } catch (error) {
+            // continue regardless of error
+        }
+        return null;
+    }
 };
 
 pytorch.Execution = class extends python.Execution {

+ 12 - 0
source/view.js

@@ -5675,6 +5675,18 @@ view.ModelFactoryService = class {
                 if (matches.length === 0) {
                     return null;
                 }
+                // PyTorch
+                if (matches.length === 2 &&
+                    matches.some((context) => context.identifier === 'serialized_exported_program.json') &&
+                    matches.some((context) => context.identifier === 'serialized_state_dict.pt')) {
+                    matches = matches.filter((context) => context.identifier === 'serialized_exported_program.json');
+                }
+                if (matches.length === 3 &&
+                    matches.some((context) => context.identifier === 'serialized_exported_program.json') &&
+                    matches.some((context) => context.identifier === 'serialized_state_dict.pt') &&
+                    matches.some((context) => context.identifier === 'serialized_constants.pt')) {
+                    matches = matches.filter((context) => context.identifier === 'serialized_exported_program.json');
+                }
                 // MXNet
                 if (matches.length === 2 &&
                     matches.some((context) => context.identifier.toLowerCase().endsWith('.params')) &&