Przeglądaj źródła

Update python.js (#1211)

Lutz Roeder 2 lat temu
rodzic
commit
0717a89bb3
1 zmienionych plików z 72 dodań i 2 usunięć
  1. 72 2
      source/python.js

+ 72 - 2
source/python.js

@@ -3430,6 +3430,9 @@ python.Execution = class {
         this.registerFunction('builtins.isinstance', function(obj, type) {
             return obj.__class__ ? builtins.issubclass(obj.__class__, type) : false;
         });
+        this.registerFunction('builtins.hasattr', function(obj, name) {
+            return Object.prototype.hasOwnProperty.call(obj, name);
+        });
         this.registerFunction('builtins.getattr', function(obj, name, defaultValue) {
             if (Object.prototype.hasOwnProperty.call(obj, name)) {
                 return obj[name];
@@ -5666,7 +5669,31 @@ python.Execution = class {
             const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants);
             return torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
         });
+        this.registerFunction('torch._export.serde.serialize._dict_to_dataclass', function(cls, data) {
+            if (data === null) {
+                return data;
+            }
+            if (data.$type) {
+                const res = {};
+                res[data.$type] = data.$value;
+                return res;
+            }
+            if (Array.isArray(data)) {
+                for (let i = 0; i < data.length; i++) {
+                    data[i] = torch._export.serde.serialize._dict_to_dataclass(null, data[i]);
+                }
+                return data;
+            }
+            if (data === Object(data)) {
+                for (const key of Object.keys(data)) {
+                    data[key] = torch._export.serde.serialize._dict_to_dataclass(null, data[key]);
+                }
+                return data;
+            }
+            return data;
+        });
         this.registerFunction('torch._export.serde.serialize.deserialize', function(artifact, expected_opset_version) {
+            artifact.exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program);
             return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact);
         });
         this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class {
@@ -5746,7 +5773,7 @@ python.Execution = class {
                     this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => [ k, torch.load(v) ]));
                 }
                 for (const [name, tensor_value] of Object.entries(serialized_graph.tensor_values)) {
-                    const meta_val = this.deserialize_tensor_meta(tensor_value.meta, this.fake_tensor_mode);
+                    const meta_val = this.deserialize_tensor_meta(tensor_value.meta || tensor_value, this.fake_tensor_mode);
                     this.serialized_name_to_meta.set(name, meta_val);
                 }
                 for (const [name, sym_int_value] of Object.entries(serialized_graph.sym_int_values)) {
@@ -5825,7 +5852,50 @@ python.Execution = class {
             deserialize_graph_output() {
             }
             deserialize_metadata(metadata) {
-                return metadata; // TODO
+                const ret = {};
+                const stack_trace = metadata['stack_trace'];
+                if (stack_trace) {
+                    ret['stack_trace'] = stack_trace;
+                }
+                const deserialize_meta_func = (serialized_target) => {
+                    let module = null;
+                    let serialized_target_names = [];
+                    if (serialized_target.startsWith('torch.nn')) {
+                        module = torch.nn;
+                        serialized_target_names = serialized_target.split('.').slice(1);
+                    } else if (serialized_target.startsWith('torch')) {
+                        module = torch;
+                        serialized_target_names = serialized_target.split('.').slice(1);
+                    } else {
+                        return this.deserialize_operator(serialized_target);
+                    }
+                    let target = module;
+                    for (const name of serialized_target_names) {
+                        if (!builtins.hasattr(target, name)) {
+                            return serialized_target;
+                        }
+                        target = builtins.getattr(target, name);
+                    }
+                    return target;
+                };
+                const nn_module_stack_str = metadata['nn_module_stack'];
+                if (nn_module_stack_str) {
+                    const import_nn_module_stack = (key, path, ty) => {
+                        return [ key, [ path, ty ] ];
+                    };
+                    const nn_module_stack = new Map(nn_module_stack_str.split(';').map((item) => import_nn_module_stack(...item.split(','))));
+                    ret['nn_module_stack'] = nn_module_stack;
+                }
+                const source_fn_st_str = metadata['source_fn_stack'];
+                if (source_fn_st_str) {
+                    const source_fn_st = [];
+                    for (const source_fn_str of source_fn_st_str.split(';')) {
+                        const [name, target_str] = source_fn_str.split(',');
+                        source_fn_st.push([ name, deserialize_meta_func(target_str) ]);
+                    }
+                    ret['source_fn_stack'] = source_fn_st;
+                }
+                return ret;
             }
             deserialize_tensor_meta(tensor_meta) {
                 const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));