|
|
@@ -5450,6 +5450,28 @@ python.Execution = class {
|
|
|
this.registerType('torch._ops.OperatorBase', class {});
|
|
|
this.registerType('torch._ops.HigherOrderOperator', class extends torch._ops.OperatorBase {});
|
|
|
this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase {});
|
|
|
+ this.registerType('torch.export.unflatten.UnflattenedModule', class extends torch.nn.modules.module.Module {
|
|
|
+ constructor(/* export_module, flat_args_adapter */) {
|
|
|
+ super();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.export.exported_program.ExportedProgram', class {
|
|
|
+ constructor(/* root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs, verifier, tensor_constants */) {
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerFunction('torch.export.unflatten', function(/* module, flat_args_adapter */) {
|
|
|
+ throw new python.Error("'torch.export.unflatten' not implemented.");
|
|
|
+ });
|
|
|
+ this.registerFunction('torch._export.exported_program._create_graph_module_for_export', function(root, graph) {
|
|
|
+ return new torch.fx.graph_module.GraphModule(root, graph);
|
|
|
+ });
|
|
|
+ this.registerType('torch._export.serde.serialize.SerializedArtifact', class {
|
|
|
+ constructor(exported_program, state_dict, constants) {
|
|
|
+ this.exported_program = exported_program;
|
|
|
+ this.state_dict = state_dict;
|
|
|
+ this.constants = constants;
|
|
|
+ }
|
|
|
+ });
|
|
|
this.registerType('torch.fx.experimental.symbolic_shapes.ShapeEnv', class {
|
|
|
constructor() {
|
|
|
}
|
|
|
@@ -5619,11 +5641,61 @@ python.Execution = class {
|
|
|
}
|
|
|
});
|
|
|
torch.fx.Graph = torch.fx.graph.Graph;
|
|
|
- this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {});
|
|
|
+ this.registerType('torch.fx.graph_module.GraphModule', class extends torch.nn.modules.module.Module {
|
|
|
+ constructor(root, graph) {
|
|
|
+ super();
|
|
|
+ this.graph = graph;
|
|
|
+ }
|
|
|
+ });
|
|
|
this.registerFunction('torch.fx._symbolic_trace.wrap', function(fn_or_name) {
|
|
|
return fn_or_name;
|
|
|
});
|
|
|
this.registerType('torch.fx._symbolic_trace.Tracer', class {});
|
|
|
+ this.registerFunction('torch._export.load', function(f, expected_opset_version) {
|
|
|
+ const serialized_exported_program = f.get('serialized_exported_program.json');
|
|
|
+ const serialized_state_dict = f.get('serialized_state_dict.pt');
|
|
|
+ const serialized_constants = f.get('serialized_constants.pt');
|
|
|
+ const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants);
|
|
|
+ return torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
|
|
|
+ });
|
|
|
+ this.registerFunction('torch._export.serde.serialize.deserialize', function(artifact, expected_opset_version) {
|
|
|
+ return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact);
|
|
|
+ });
|
|
|
+ this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class {
|
|
|
+ constructor(expected_opset_version) {
|
|
|
+ this.expected_opset_version = expected_opset_version;
|
|
|
+ }
|
|
|
+ deserialize(serialized_artifact) {
|
|
|
+ const symbol_name_to_range = new Map(Object.entries(serialized_artifact.exported_program.range_constraints));
|
|
|
+ /*
|
|
|
+ symbol_name_to_range = {
|
|
|
+ k: symbolic_shapes.ValueRanges(_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val))
|
|
|
+ for k, v in serialized_artifact.exported_program.range_constraints.items()
|
|
|
+ }
|
|
|
+ */
|
|
|
+ const constants = serialized_artifact.constants ? torch.load(serialized_artifact.constants) : null;
|
|
|
+ const tensor_constants = constants ? new Map(Object.entries(constants).filter(([, tensor]) => tensor instanceof torch.Tensor)) : null;
|
|
|
+ const deserializer = new torch._export.serde.serialize.GraphModuleDeserializer();
|
|
|
+ const res = deserializer.deserialize(serialized_artifact.exported_program.graph_module, symbol_name_to_range, constants);
|
|
|
+ const range_constraints = null;
|
|
|
+ /*
|
|
|
+ range_constraints = self.deserialize_range_constraints(
|
|
|
+ symbol_name_to_range, res.names_to_symbols,
|
|
|
+ )
|
|
|
+ model_opset_version: Optional[Dict[str, int]] = serialized_artifact.exported_program.opset_version
|
|
|
+ self._validate_model_opset_version(model_opset_version)
|
|
|
+ upgrader = GraphModuleOpUpgrader(self.expected_opset_version, model_opset_version)
|
|
|
+ */
|
|
|
+ const state_dict = serialized_artifact.state_dict ? torch.load(serialized_artifact.state_dict) : null;
|
|
|
+ const exported_program = new torch.export.exported_program.ExportedProgram(
|
|
|
+ res.graph_module, res.graph_module.graph, res.signature,
|
|
|
+ state_dict, range_constraints, res.module_call_graph, null,
|
|
|
+ null, // verifier=load_verifier(serialized_artifact.exported_program.dialect),
|
|
|
+ tensor_constants);
|
|
|
+ return exported_program;
|
|
|
+ // return upgrader.upgrade(exported_program)
|
|
|
+ }
|
|
|
+ });
|
|
|
this.registerType('torch._export.serde.serialize.GraphModuleDeserializer', class {
|
|
|
constructor() {
|
|
|
this.serialized_name_to_node = new Map();
|
|
|
@@ -5652,20 +5724,21 @@ python.Execution = class {
|
|
|
this.symbol_name_to_range = symbol_name_to_range || new Map();
|
|
|
this.constants = constants || new Map();
|
|
|
this.deserialize_graph(serialized_graph_module.graph);
|
|
|
+ const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
|
|
|
+ const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
|
|
|
+ return {
|
|
|
+ graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
|
|
|
+ signature: sig,
|
|
|
+ module_call_graph: module_call_graph,
|
|
|
+ names_to_symbols: this.symbol_name_to_symbol
|
|
|
+ };
|
|
|
}
|
|
|
deserialize_graph(serialized_graph) {
|
|
|
- this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => {
|
|
|
- const str = atob(v);
|
|
|
- const buffer = new Uint8Array(str.length);
|
|
|
- for (let i = 0; i < str.length; i++) {
|
|
|
- buffer[i] = str.charCodeAt(i);
|
|
|
- }
|
|
|
- const archive = self.zip.Archive.open(buffer);
|
|
|
- const value = torch.load(archive.entries);
|
|
|
- return [ k, value ];
|
|
|
- }));
|
|
|
+ if (serialized_graph.constants) {
|
|
|
+ this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => [ k, torch.load(v) ]));
|
|
|
+ }
|
|
|
for (const [name, tensor_value] of Object.entries(serialized_graph.tensor_values)) {
|
|
|
- const meta_val = this.deserialize_tensor_meta(tensor_value.meta, self.fake_tensor_mode);
|
|
|
+ const meta_val = this.deserialize_tensor_meta(tensor_value.meta, this.fake_tensor_mode);
|
|
|
this.serialized_name_to_meta.set(name, meta_val);
|
|
|
}
|
|
|
for (const [name, sym_int_value] of Object.entries(serialized_graph.sym_int_values)) {
|