|
|
@@ -5746,27 +5746,19 @@ python.Execution = class {
|
|
|
torch.sym_not
|
|
|
]);
|
|
|
}
|
|
|
- deserialize(serialized_graph_module, symbol_name_to_range, constants) {
|
|
|
- this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */);
|
|
|
- /*
|
|
|
- this.fake_tensor_mode = FakeTensorMode(
|
|
|
- allow_fallback_kernels=False,
|
|
|
- allow_non_fake_inputs=True,
|
|
|
- shape_env=self.shape_env,
|
|
|
- )
|
|
|
+ deserialize_graph_output(/* output */) {
|
|
|
+ /* TODO
|
|
|
+ if (output.type == 'as_tensor') {
|
|
|
+ return self.serialized_name_to_node[output.as_tensor.name]
|
|
|
+ }
|
|
|
+ else if (output.type == 'as_sym_int') {
|
|
|
+ return self.serialized_name_to_node[output.as_sym_int.as_name]
|
|
|
+ }
|
|
|
+ elif output.type == 'as_sym_bool':
|
|
|
+ return self.serialized_name_to_node[output.as_sym_bool.as_name]
|
|
|
+ else:
|
|
|
+ raise SerializeError(f'Unable to deserialize output node {output}')
|
|
|
*/
|
|
|
- this.symbol_name_to_symbol = new Map();
|
|
|
- this.symbol_name_to_range = symbol_name_to_range || new Map();
|
|
|
- this.constants = constants || new Map();
|
|
|
- this.deserialize_graph(serialized_graph_module.graph);
|
|
|
- const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
|
|
|
- const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
|
|
|
- return {
|
|
|
- graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
|
|
|
- signature: sig,
|
|
|
- module_call_graph: module_call_graph,
|
|
|
- names_to_symbols: this.symbol_name_to_symbol
|
|
|
- };
|
|
|
}
|
|
|
deserialize_graph(serialized_graph) {
|
|
|
if (serialized_graph.constants) {
|
|
|
@@ -5797,9 +5789,9 @@ python.Execution = class {
|
|
|
}
|
|
|
deserialize_operator(serialized_target) {
|
|
|
let target = null;
|
|
|
- if (serialized_target.startsWith("_operator")) {
|
|
|
+ if (serialized_target.startsWith('_operator')) {
|
|
|
target = operator;
|
|
|
- } else if (serialized_target.startsWith("torch")) {
|
|
|
+ } else if (serialized_target.startsWith('torch')) {
|
|
|
target = torch;
|
|
|
} else {
|
|
|
return serialized_target;
|
|
|
@@ -5815,22 +5807,20 @@ python.Execution = class {
|
|
|
return target;
|
|
|
}
|
|
|
deserialize_node(serialized_node, target) {
|
|
|
- let name;
|
|
|
- const args = [];
|
|
|
- const kwargs = {};
|
|
|
let fx_node = null;
|
|
|
- /*
|
|
|
if (this._SYM_BOOL_OPS.has(target) || this._SYM_INT_OPS.has(target)) {
|
|
|
- name = serialized_node.outputs[0].value.as_name;
|
|
|
- args = self.deserialize_sym_op_inputs(serialized_node.inputs);
|
|
|
+ /*
|
|
|
+ const name = serialized_node.outputs[0].value.as_name;
|
|
|
+ const args = self.deserialize_sym_op_inputs(serialized_node.inputs);
|
|
|
fx_node = self.graph.create_node("call_function", target, args, {}, name);
|
|
|
- this.deserialize_sym_op_outputs(serialized_node, fx_node);
|
|
|
+ self.deserialize_sym_op_outputs(serialized_node, fx_node);
|
|
|
+ */
|
|
|
} else if (builtins.isinstance(target, torch._ops.HigherOrderOperator)) {
|
|
|
- // assert(len(serialized_node.outputs) == 1 && serialized_node.outputs[0].type in ("as_tensors", "as_tensor")), "Only single tensor output or list of tensor output is supported for higher order operators.")
|
|
|
+ // assert(len(serialized_node.outputs) == 1 && serialized_node.outputs[0].type in ('as_tensors', 'as_tensor')), 'Only single tensor output or list of tensor output is supported for higher order operators.')
|
|
|
const [output] = serialized_node.outputs;
|
|
|
- name = output.type == 'as_tensor' ? output.value.name : null;
|
|
|
- args = serialized_node.inputs.map((input) => this.deserialize_input(input.arg));
|
|
|
- fx_node = this.graph.create_node("call_function", target, args, {}, name);
|
|
|
+ const name = output.type == 'as_tensor' ? output.value.name : null;
|
|
|
+ const args = serialized_node.inputs.map((input) => this.deserialize_input(input.arg));
|
|
|
+ fx_node = this.graph.create_node('call_function', target, args, {}, name);
|
|
|
if (output.as_tensor !== null) {
|
|
|
this.sync_fx_node(name, fx_node);
|
|
|
}
|
|
|
@@ -5838,18 +5828,139 @@ python.Execution = class {
|
|
|
this.deserialize_multiple_outputs(serialized_node, fx_node);
|
|
|
}
|
|
|
} else if (builtins.isinstance(target, torch._ops.OpOverload)) {
|
|
|
- name = this._is_single_tensor_return(target) ? serialized_node.outputs[0].as_tensor.name : null;
|
|
|
- [args, kwargs] = this.deserialize_inputs(target, serialized_node);
|
|
|
- fx_node = self.graph.create_node("call_function", target, args, kwargs, name);
|
|
|
+ const name = this._is_single_tensor_return(target) ? serialized_node.outputs[0].as_tensor.name : null;
|
|
|
+ const [args, kwargs] = this.deserialize_inputs(target, serialized_node);
|
|
|
+ fx_node = self.graph.create_node('call_function', target, args, kwargs, name);
|
|
|
this.deserialize_outputs(serialized_node, fx_node);
|
|
|
} else {
|
|
|
- throw new python.Error(`Unsupported target type '${target}'.`);
|
|
|
+ // TODO
|
|
|
+ // throw new python.Error(`Unsupported target type '${target}'.`);
|
|
|
+ }
|
|
|
+ fx_node && Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
|
|
|
+ }
|
|
|
+ deserialize(serialized_graph_module, symbol_name_to_range, constants) {
|
|
|
+ this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */);
|
|
|
+ /*
|
|
|
+ this.fake_tensor_mode = FakeTensorMode(
|
|
|
+ allow_fallback_kernels=False,
|
|
|
+ allow_non_fake_inputs=True,
|
|
|
+ shape_env=self.shape_env,
|
|
|
+ )
|
|
|
+ */
|
|
|
+ this.symbol_name_to_symbol = new Map();
|
|
|
+ this.symbol_name_to_range = symbol_name_to_range || new Map();
|
|
|
+ this.constants = constants || new Map();
|
|
|
+ this.deserialize_graph(serialized_graph_module.graph);
|
|
|
+ const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
|
|
|
+ const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
|
|
|
+ return {
|
|
|
+ graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
|
|
|
+ signature: sig,
|
|
|
+ module_call_graph: module_call_graph,
|
|
|
+ names_to_symbols: this.symbol_name_to_symbol
|
|
|
+ };
|
|
|
+ }
|
|
|
+ sync_fx_node(name, fx_node) {
|
|
|
+ if (this.serialized_name_to_node.has(name)) {
|
|
|
+ throw new python.Error(`Node ${name} has already been deserialized before.`);
|
|
|
+ }
|
|
|
+ this.serialized_name_to_node.set(name, fx_node);
|
|
|
+ fx_node.meta['val'] = this.serialized_name_to_meta.get(name);
|
|
|
+ }
|
|
|
+ deserialize_sym_op_inputs(inputs) {
|
|
|
+ return inputs.map((input) => this.deserialize_input(input.arg));
|
|
|
+ }
|
|
|
+ deserialize_inputs(target /* , serialized_node */) {
|
|
|
+ const schema_args = target._schema.arguments;
|
|
|
+ const actual_args = null;
|
|
|
+ /*
|
|
|
+ actual_args = {
|
|
|
+ input.name: self.deserialize_input(input.arg) for input in serialized_node.inputs
|
|
|
}
|
|
|
*/
|
|
|
- fx_node = this.graph.create_node("call_function", target, args, kwargs, name);
|
|
|
- Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
|
|
|
+ const args = [];
|
|
|
+ const kwargs = {};
|
|
|
+ for (const schema_arg of schema_args) {
|
|
|
+ const is_positional = !schema_arg.has_default_value() && !schema_arg.kwarg_only;
|
|
|
+ if (is_positional) {
|
|
|
+ args.push(actual_args[schema_arg.name]);
|
|
|
+ } else if (schema_arg.name in actual_args) {
|
|
|
+ kwargs[schema_arg.name] = actual_args[schema_arg.name];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return [ args, kwargs ];
|
|
|
}
|
|
|
- deserialize_graph_output() {
|
|
|
+ deserialize_input(/* inp */) {
|
|
|
+ /*
|
|
|
+ value = inp.value
|
|
|
+ typ_ = inp.type
|
|
|
+ if typ_ == 'as_none':
|
|
|
+ # None should converted as None, but is encoded as bool in serialized
|
|
|
+ # Convert serialized object to torch equivalent
|
|
|
+ return None
|
|
|
+ elif typ_ == 'as_tensor':
|
|
|
+ return self.serialized_name_to_node[inp.as_tensor.name]
|
|
|
+ elif typ_ == 'as_scalar_type':
|
|
|
+ return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]
|
|
|
+ elif typ_ == 'as_memory_format':
|
|
|
+ return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]
|
|
|
+ elif typ_ == 'as_layout':
|
|
|
+ return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]
|
|
|
+ elif typ_ == 'as_graph':
|
|
|
+ assert isinstance(value, GraphArgument)
|
|
|
+ with self.save_graph_module():
|
|
|
+ self.deserialize_graph(value.graph)
|
|
|
+ submodule = torch._export.exported_program._create_graph_module_for_export(self.module, self.graph)
|
|
|
+ self.module.register_module(value.name, submodule)
|
|
|
+ return self.graph.create_node(
|
|
|
+ 'get_attr',
|
|
|
+ value.name,
|
|
|
+ name=value.name,
|
|
|
+ )
|
|
|
+ elif typ_ == 'as_device':
|
|
|
+ return deserialize_device(inp.as_device)
|
|
|
+ elif typ_ == 'as_int':
|
|
|
+ return inp.as_int
|
|
|
+ elif typ_ == 'as_float':
|
|
|
+ return inp.as_float
|
|
|
+ elif typ_ == 'as_bool':
|
|
|
+ return inp.as_bool
|
|
|
+ elif typ_ == 'as_string':
|
|
|
+ return inp.as_string
|
|
|
+ elif typ_ == 'as_sym_int':
|
|
|
+ return self.deserialize_sym_argument(inp.as_sym_int)
|
|
|
+ elif typ_ == 'as_sym_bool':
|
|
|
+ return self.deserialize_sym_argument(inp.as_sym_bool)
|
|
|
+ elif isinstance(value, list):
|
|
|
+ if len(value) == 0:
|
|
|
+ return []
|
|
|
+ elif isinstance(value[0], TensorArgument):
|
|
|
+ result = []
|
|
|
+ for arg in value:
|
|
|
+ result.append(self.serialized_name_to_node[arg.name])
|
|
|
+ return result
|
|
|
+ elif isinstance(value[0], (int, float, bool)):
|
|
|
+ # convert from serialized.python.types.List to python list
|
|
|
+ return list(value)
|
|
|
+ elif isinstance(value[0], (SymIntArgument, SymBoolArgument)):
|
|
|
+ return [self.deserialize_sym_argument(arg) for arg in value]
|
|
|
+ elif isinstance(value[0], OptionalTensorArgument):
|
|
|
+ def deserialize_optional_tensor_args(a):
|
|
|
+ if a.type == 'as_none':
|
|
|
+ return None
|
|
|
+ elif a.type == 'as_tensor':
|
|
|
+ return self.serialized_name_to_node[a.value]
|
|
|
+ else:
|
|
|
+ raise SerializeError(f'Unhandled argument {inp}')
|
|
|
+ return list(map(deserialize_optional_tensor_args, value))
|
|
|
+ else:
|
|
|
+ raise SerializeError(f'Unhandled argument {inp}')
|
|
|
+ elif typ_ == 'as_custom_obj':
|
|
|
+ return self.constants[inp.as_custom_obj.name]
|
|
|
+ else {
|
|
|
+ raise SerializeError(`Unhandled argument ${inp}.`);
|
|
|
+ }
|
|
|
+ */
|
|
|
}
|
|
|
deserialize_metadata(metadata) {
|
|
|
const ret = {};
|
|
|
@@ -5940,13 +6051,6 @@ python.Execution = class {
|
|
|
}
|
|
|
return new torch.device(d.type);
|
|
|
}
|
|
|
- sync_fx_node(name, fx_node) {
|
|
|
- if (this.serialized_name_to_node.has(name)) {
|
|
|
- throw new python.Error(`Node ${name} has already been deserialized before.`);
|
|
|
- }
|
|
|
- this.serialized_name_to_node.set(name, fx_node);
|
|
|
- fx_node.meta['val'] = this.serialized_name_to_meta.get(name);
|
|
|
- }
|
|
|
});
|
|
|
this.registerFunction('torch_utils.persistence._reconstruct_persistent_obj', function(meta) {
|
|
|
const name = `_imported_module_${Math.floor(Math.random() * 10000)}`;
|