Lutz Roeder 2 лет назад
Родитель
Commit
e7771ec383
1 измененных файлов с 151 добавлено и 47 удалено
  1. 151 47
      source/python.js

+ 151 - 47
source/python.js

@@ -5746,27 +5746,19 @@ python.Execution = class {
                     torch.sym_not
                 ]);
             }
-            deserialize(serialized_graph_module, symbol_name_to_range, constants) {
-                this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */);
-                /*
-                this.fake_tensor_mode = FakeTensorMode(
-                    allow_fallback_kernels=False,
-                    allow_non_fake_inputs=True,
-                    shape_env=self.shape_env,
-                )
+            deserialize_graph_output(/* output */) {
+                /* TODO
+                if (output.type == 'as_tensor') {
+                    return self.serialized_name_to_node[output.as_tensor.name]
+                }
+                else if (output.type == 'as_sym_int') {
+                    return self.serialized_name_to_node[output.as_sym_int.as_name]
+                }
+                elif output.type == 'as_sym_bool':
+                    return self.serialized_name_to_node[output.as_sym_bool.as_name]
+                else:
+                    raise SerializeError(f'Unable to deserialize output node {output}')
                 */
-                this.symbol_name_to_symbol = new Map();
-                this.symbol_name_to_range = symbol_name_to_range || new Map();
-                this.constants = constants || new Map();
-                this.deserialize_graph(serialized_graph_module.graph);
-                const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
-                const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
-                return {
-                    graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
-                    signature: sig,
-                    module_call_graph: module_call_graph,
-                    names_to_symbols: this.symbol_name_to_symbol
-                };
             }
             deserialize_graph(serialized_graph) {
                 if (serialized_graph.constants) {
@@ -5797,9 +5789,9 @@ python.Execution = class {
             }
             deserialize_operator(serialized_target) {
                 let target = null;
-                if (serialized_target.startsWith("_operator")) {
+                if (serialized_target.startsWith('_operator')) {
                     target = operator;
-                } else if (serialized_target.startsWith("torch")) {
+                } else if (serialized_target.startsWith('torch')) {
                     target = torch;
                 } else {
                     return serialized_target;
@@ -5815,22 +5807,20 @@ python.Execution = class {
                 return target;
             }
             deserialize_node(serialized_node, target) {
-                let name;
-                const args = [];
-                const kwargs = {};
                 let fx_node = null;
-                /*
                 if (this._SYM_BOOL_OPS.has(target) || this._SYM_INT_OPS.has(target)) {
-                    name = serialized_node.outputs[0].value.as_name;
-                    args = self.deserialize_sym_op_inputs(serialized_node.inputs);
+                    /*
+                    const name = serialized_node.outputs[0].value.as_name;
+                    const args = self.deserialize_sym_op_inputs(serialized_node.inputs);
                     fx_node = self.graph.create_node("call_function", target, args, {}, name);
-                    this.deserialize_sym_op_outputs(serialized_node, fx_node);
+                    self.deserialize_sym_op_outputs(serialized_node, fx_node);
+                    */
                 } else if (builtins.isinstance(target, torch._ops.HigherOrderOperator)) {
-                    // assert(len(serialized_node.outputs) == 1 && serialized_node.outputs[0].type in ("as_tensors", "as_tensor")), "Only single tensor output or list of tensor output is supported for higher order operators.")
+                    // assert(len(serialized_node.outputs) == 1 && serialized_node.outputs[0].type in ('as_tensors', 'as_tensor')), 'Only single tensor output or list of tensor output is supported for higher order operators.')
                     const [output] = serialized_node.outputs;
-                    name = output.type == 'as_tensor' ? output.value.name : null;
-                    args = serialized_node.inputs.map((input) => this.deserialize_input(input.arg));
-                    fx_node = this.graph.create_node("call_function", target, args, {}, name);
+                    const name = output.type == 'as_tensor' ? output.value.name : null;
+                    const args = serialized_node.inputs.map((input) => this.deserialize_input(input.arg));
+                    fx_node = this.graph.create_node('call_function', target, args, {}, name);
                     if (output.as_tensor !== null) {
                         this.sync_fx_node(name, fx_node);
                     }
@@ -5838,18 +5828,139 @@ python.Execution = class {
                         this.deserialize_multiple_outputs(serialized_node, fx_node);
                     }
                 } else if (builtins.isinstance(target, torch._ops.OpOverload)) {
-                    name = this._is_single_tensor_return(target) ? serialized_node.outputs[0].as_tensor.name : null;
-                    [args, kwargs] = this.deserialize_inputs(target, serialized_node);
-                    fx_node = self.graph.create_node("call_function", target, args, kwargs, name);
+                    const name = this._is_single_tensor_return(target) ? serialized_node.outputs[0].as_tensor.name : null;
+                    const [args, kwargs] = this.deserialize_inputs(target, serialized_node);
+                    fx_node = self.graph.create_node('call_function', target, args, kwargs, name);
                     this.deserialize_outputs(serialized_node, fx_node);
                 } else {
-                    throw new python.Error(`Unsupported target type '${target}'.`);
+                    // TODO
+                    // throw new python.Error(`Unsupported target type '${target}'.`);
+                }
+                fx_node && Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
+            }
+            deserialize(serialized_graph_module, symbol_name_to_range, constants) {
+                this.shape_env = new torch.fx.experimental.symbolic_shapes.ShapeEnv(/* assume_static_by_default = True */);
+                /*
+                this.fake_tensor_mode = FakeTensorMode(
+                    allow_fallback_kernels=False,
+                    allow_non_fake_inputs=True,
+                    shape_env=self.shape_env,
+                )
+                */
+                this.symbol_name_to_symbol = new Map();
+                this.symbol_name_to_range = symbol_name_to_range || new Map();
+                this.constants = constants || new Map();
+                this.deserialize_graph(serialized_graph_module.graph);
+                const sig = null; // self.deserialize_signature(serialized_graph_module.signature)
+                const module_call_graph = null; // self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
+                return {
+                    graph_module: torch._export.exported_program._create_graph_module_for_export(this.module, this.graph),
+                    signature: sig,
+                    module_call_graph: module_call_graph,
+                    names_to_symbols: this.symbol_name_to_symbol
+                };
+            }
+            sync_fx_node(name, fx_node) {
+                if (this.serialized_name_to_node.has(name)) {
+                    throw new python.Error(`Node ${name} has already been deserialized before.`);
+                }
+                this.serialized_name_to_node.set(name, fx_node);
+                fx_node.meta['val'] = this.serialized_name_to_meta.get(name);
+            }
+            deserialize_sym_op_inputs(inputs) {
+                return inputs.map((input) => this.deserialize_input(input.arg));
+            }
+            deserialize_inputs(target /* , serialized_node */) {
+                const schema_args = target._schema.arguments;
+                const actual_args = null;
+                /*
+                actual_args = {
+                    input.name: self.deserialize_input(input.arg) for input in serialized_node.inputs
                 }
                 */
-                fx_node = this.graph.create_node("call_function", target, args, kwargs, name);
-                Object.assign(fx_node.meta, this.deserialize_metadata(serialized_node.metadata));
+                const args = [];
+                const kwargs = {};
+                for (const schema_arg of schema_args) {
+                    const is_positional = !schema_arg.has_default_value() && !schema_arg.kwarg_only;
+                    if (is_positional) {
+                        args.push(actual_args[schema_arg.name]);
+                    } else if (schema_arg.name in actual_args) {
+                        kwargs[schema_arg.name] = actual_args[schema_arg.name];
+                    }
+                }
+                return [ args, kwargs ];
             }
-            deserialize_graph_output() {
+            deserialize_input(/* inp */) {
+                /*
+                value = inp.value
+                typ_ = inp.type
+                if typ_ == 'as_none':
+                    # None should converted as None, but is encoded as bool in serialized
+                    # Convert serialized object to torch equivalent
+                    return None
+                elif typ_ == 'as_tensor':
+                    return self.serialized_name_to_node[inp.as_tensor.name]
+                elif typ_ == 'as_scalar_type':
+                    return _SERIALIZE_TO_TORCH_DTYPE[inp.as_scalar_type]
+                elif typ_ == 'as_memory_format':
+                    return _SERIALIZE_TO_TORCH_MEMORY_FORMAT[inp.as_memory_format]
+                elif typ_ == 'as_layout':
+                    return _SERIALIZE_TO_TORCH_LAYOUT[inp.as_layout]
+                elif typ_ == 'as_graph':
+                    assert isinstance(value, GraphArgument)
+                    with self.save_graph_module():
+                        self.deserialize_graph(value.graph)
+                        submodule = torch._export.exported_program._create_graph_module_for_export(self.module, self.graph)
+                    self.module.register_module(value.name, submodule)
+                    return self.graph.create_node(
+                        'get_attr',
+                        value.name,
+                        name=value.name,
+                    )
+                elif typ_ == 'as_device':
+                    return deserialize_device(inp.as_device)
+                elif typ_ == 'as_int':
+                    return inp.as_int
+                elif typ_ == 'as_float':
+                    return inp.as_float
+                elif typ_ == 'as_bool':
+                    return inp.as_bool
+                elif typ_ == 'as_string':
+                    return inp.as_string
+                elif typ_ == 'as_sym_int':
+                    return self.deserialize_sym_argument(inp.as_sym_int)
+                elif typ_ == 'as_sym_bool':
+                    return self.deserialize_sym_argument(inp.as_sym_bool)
+                elif isinstance(value, list):
+                    if len(value) == 0:
+                        return []
+                    elif isinstance(value[0], TensorArgument):
+                        result = []
+                        for arg in value:
+                            result.append(self.serialized_name_to_node[arg.name])
+                        return result
+                    elif isinstance(value[0], (int, float, bool)):
+                        # convert from serialized.python.types.List to python list
+                        return list(value)
+                    elif isinstance(value[0], (SymIntArgument, SymBoolArgument)):
+                        return [self.deserialize_sym_argument(arg) for arg in value]
+                    elif isinstance(value[0], OptionalTensorArgument):
+                        def deserialize_optional_tensor_args(a):
+                            if a.type == 'as_none':
+                                return None
+                            elif a.type == 'as_tensor':
+                                return self.serialized_name_to_node[a.value]
+                            else:
+                                raise SerializeError(f'Unhandled argument {inp}')
+                        return list(map(deserialize_optional_tensor_args, value))
+                    else:
+                        raise SerializeError(f'Unhandled argument {inp}')
+                elif typ_ == 'as_custom_obj':
+                    return self.constants[inp.as_custom_obj.name]
+                else {
+                    raise SerializeError(`Unhandled argument ${inp}.`);
+                }
+                */
             }
             deserialize_metadata(metadata) {
                 const ret = {};
@@ -5940,13 +6051,6 @@ python.Execution = class {
                 }
                 return new torch.device(d.type);
             }
-            sync_fx_node(name, fx_node) {
-                if (this.serialized_name_to_node.has(name)) {
-                    throw new python.Error(`Node ${name} has already been deserialized before.`);
-                }
-                this.serialized_name_to_node.set(name, fx_node);
-                fx_node.meta['val'] = this.serialized_name_to_meta.get(name);
-            }
         });
         this.registerFunction('torch_utils.persistence._reconstruct_persistent_obj', function(meta) {
             const name = `_imported_module_${Math.floor(Math.random() * 10000)}`;