|
|
@@ -3430,6 +3430,9 @@ python.Execution = class {
|
|
|
this.registerFunction('builtins.isinstance', function(obj, type) {
|
|
|
return obj.__class__ ? builtins.issubclass(obj.__class__, type) : false;
|
|
|
});
|
|
|
+ this.registerFunction('builtins.hasattr', function(obj, name) {
|
|
|
+ return Object.prototype.hasOwnProperty.call(obj, name);
|
|
|
+ });
|
|
|
this.registerFunction('builtins.getattr', function(obj, name, defaultValue) {
|
|
|
if (Object.prototype.hasOwnProperty.call(obj, name)) {
|
|
|
return obj[name];
|
|
|
@@ -5666,7 +5669,31 @@ python.Execution = class {
|
|
|
const artifact = new torch._export.serde.serialize.SerializedArtifact(serialized_exported_program, serialized_state_dict, serialized_constants);
|
|
|
return torch._export.serde.serialize.deserialize(artifact, expected_opset_version);
|
|
|
});
|
|
|
+ this.registerFunction('torch._export.serde.serialize._dict_to_dataclass', function(cls, data) {
|
|
|
+ if (data === null) {
|
|
|
+ return data;
|
|
|
+ }
|
|
|
+ if (data.$type) {
|
|
|
+ const res = {};
|
|
|
+ res[data.$type] = data.$value;
|
|
|
+ return res;
|
|
|
+ }
|
|
|
+ if (Array.isArray(data)) {
|
|
|
+ for (let i = 0; i < data.length; i++) {
|
|
|
+ data[i] = torch._export.serde.serialize._dict_to_dataclass(null, data[i]);
|
|
|
+ }
|
|
|
+ return data;
|
|
|
+ }
|
|
|
+ if (data === Object(data)) {
|
|
|
+ for (const key of Object.keys(data)) {
|
|
|
+ data[key] = torch._export.serde.serialize._dict_to_dataclass(null, data[key]);
|
|
|
+ }
|
|
|
+ return data;
|
|
|
+ }
|
|
|
+ return data;
|
|
|
+ });
|
|
|
this.registerFunction('torch._export.serde.serialize.deserialize', function(artifact, expected_opset_version) {
|
|
|
+ artifact.exported_program = torch._export.serde.serialize._dict_to_dataclass(null, artifact.exported_program);
|
|
|
return new torch._export.serde.serialize.ExportedProgramDeserializer(expected_opset_version).deserialize(artifact);
|
|
|
});
|
|
|
this.registerType('torch._export.serde.serialize.ExportedProgramDeserializer', class {
|
|
|
@@ -5746,7 +5773,7 @@ python.Execution = class {
|
|
|
this.constants = new Map(Object.entries(serialized_graph.constants).map(([k, v]) => [ k, torch.load(v) ]));
|
|
|
}
|
|
|
for (const [name, tensor_value] of Object.entries(serialized_graph.tensor_values)) {
|
|
|
- const meta_val = this.deserialize_tensor_meta(tensor_value.meta, this.fake_tensor_mode);
|
|
|
+ const meta_val = this.deserialize_tensor_meta(tensor_value.meta || tensor_value, this.fake_tensor_mode);
|
|
|
this.serialized_name_to_meta.set(name, meta_val);
|
|
|
}
|
|
|
for (const [name, sym_int_value] of Object.entries(serialized_graph.sym_int_values)) {
|
|
|
@@ -5825,7 +5852,50 @@ python.Execution = class {
|
|
|
deserialize_graph_output() {
|
|
|
}
|
|
|
deserialize_metadata(metadata) {
|
|
|
- return metadata; // TODO
|
|
|
+ const ret = {};
|
|
|
+ const stack_trace = metadata['stack_trace'];
|
|
|
+ if (stack_trace) {
|
|
|
+ ret['stack_trace'] = stack_trace;
|
|
|
+ }
|
|
|
+ const deserialize_meta_func = (serialized_target) => {
|
|
|
+ let module = null;
|
|
|
+ let serialized_target_names = [];
|
|
|
+ if (serialized_target.startsWith('torch.nn')) {
|
|
|
+ module = torch.nn;
|
|
|
+ serialized_target_names = serialized_target.split('.').slice(1);
|
|
|
+ } else if (serialized_target.startsWith('torch')) {
|
|
|
+ module = torch;
|
|
|
+ serialized_target_names = serialized_target.split('.').slice(1);
|
|
|
+ } else {
|
|
|
+ return this.deserialize_operator(serialized_target);
|
|
|
+ }
|
|
|
+ let target = module;
|
|
|
+ for (const name of serialized_target_names) {
|
|
|
+ if (!builtins.hasattr(target, name)) {
|
|
|
+ return serialized_target;
|
|
|
+ }
|
|
|
+ target = builtins.getattr(target, name);
|
|
|
+ }
|
|
|
+ return target;
|
|
|
+ };
|
|
|
+ const nn_module_stack_str = metadata['nn_module_stack'];
|
|
|
+ if (nn_module_stack_str) {
|
|
|
+ const import_nn_module_stack = (key, path, ty) => {
|
|
|
+ return [ key, [ path, ty ] ];
|
|
|
+ };
|
|
|
+ const nn_module_stack = new Map(nn_module_stack_str.split(';').map((item) => import_nn_module_stack(...item.split(','))));
|
|
|
+ ret['nn_module_stack'] = nn_module_stack;
|
|
|
+ }
|
|
|
+ const source_fn_st_str = metadata['source_fn_stack'];
|
|
|
+ if (source_fn_st_str) {
|
|
|
+ const source_fn_st = [];
|
|
|
+ for (const source_fn_str of source_fn_st_str.split(';')) {
|
|
|
+ const [name, target_str] = source_fn_str.split(',');
|
|
|
+ source_fn_st.push([ name, deserialize_meta_func(target_str) ]);
|
|
|
+ }
|
|
|
+ ret['source_fn_stack'] = source_fn_st;
|
|
|
+ }
|
|
|
+ return ret;
|
|
|
}
|
|
|
deserialize_tensor_meta(tensor_meta) {
|
|
|
const sizes = tensor_meta.sizes.map((val) => this.deserialize_sym_int(val));
|