|
|
@@ -4967,6 +4967,35 @@ python.Execution = class {
|
|
|
}
|
|
|
return this[name];
|
|
|
}
|
|
|
+ __delattr__(name) {
|
|
|
+ if (this._modules.has(name)) {
|
|
|
+ this._modules.delete(name);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ children() {
|
|
|
+ return this._modules.values();
|
|
|
+ }
|
|
|
+ named_children() {
|
|
|
+ return this._modules;
|
|
|
+ }
|
|
|
+ parameters() {
|
|
|
+ return this._parameters.values();
|
|
|
+ }
|
|
|
+ named_parameters(recurse) {
|
|
|
+ if (recurse) {
|
|
|
+ throw new python.Error('Named parameters with recurse not implemented.');
|
|
|
+ }
|
|
|
+ return this._parameters;
|
|
|
+ }
|
|
|
+ buffers() {
|
|
|
+ return this._buffers.values();
|
|
|
+ }
|
|
|
+ named_buffers(recurse) {
|
|
|
+ if (recurse) {
|
|
|
+ throw new python.Error('Named parameters with recurse not implemented.');
|
|
|
+ }
|
|
|
+ return this._buffers;
|
|
|
+ }
|
|
|
});
|
|
|
torch.nn.Module = torch.nn.modules.module.Module;
|
|
|
torch.nn.modules.Module = torch.nn.modules.module.Module;
|
|
|
@@ -6210,14 +6239,14 @@ python.Execution = class {
|
|
|
});
|
|
|
this.registerFunction('torch._utils._rebuild_tensor_v3');
|
|
|
this.registerFunction('torch._utils._rebuild_parameter', (data, requires_grad, backward_hooks) => {
|
|
|
- const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
|
|
|
+ const param = new torch.nn.parameter.Parameter(data, requires_grad);
|
|
|
param.backward_hooks = backward_hooks;
|
|
|
return param;
|
|
|
});
|
|
|
this.registerFunction('torch._utils._rebuild_parameter_v2', (data, requires_grad, backward_hooks, state) => {
|
|
|
- const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
|
|
|
+ const param = new torch.nn.parameter.Parameter(data, requires_grad);
|
|
|
param.backward_hooks = backward_hooks;
|
|
|
- execution.invoke('torch._utils._set_obj_state', [param, state]);
|
|
|
+ torch._utils._set_obj_state(param, state);
|
|
|
return param;
|
|
|
});
|
|
|
this.registerFunction('torch._utils._rebuild_parameter_with_state', (data, requires_grad, backward_hooks, state) => {
|
|
|
@@ -6225,16 +6254,16 @@ python.Execution = class {
|
|
|
const [dict_state, slots_state] = Array.isArray(state) ? state : [state, null];
|
|
|
if (dict_state) {
|
|
|
for (const [k, v] of Object.entries(dict_state)) {
|
|
|
- self.invoke('builtins.setattr', [obj, k, v]);
|
|
|
+ builtins.setattr(obj, k, v);
|
|
|
}
|
|
|
}
|
|
|
if (slots_state) {
|
|
|
for (const [k, v] of Object.entries(slots_state)) {
|
|
|
- self.invoke('builtins.setattr', [obj, k, v]);
|
|
|
+ builtins.setattr(obj, k, v);
|
|
|
}
|
|
|
}
|
|
|
};
|
|
|
- const param = self.invoke('torch.nn.parameter.Parameter', [data, requires_grad]);
|
|
|
+ const param = new torch.nn.parameter.Parameter(data, requires_grad);
|
|
|
param._backward_hooks = backward_hooks;
|
|
|
_set_obj_state(param, state);
|
|
|
return param;
|
|
|
@@ -6255,12 +6284,12 @@ python.Execution = class {
|
|
|
}
|
|
|
if (dict_state) {
|
|
|
for (const [name, value] of Object.entries(dict_state)) {
|
|
|
- execution.invoke('builtins.setattr', [obj, name, value]);
|
|
|
+ builtins.setattr(obj, name, value);
|
|
|
}
|
|
|
}
|
|
|
if (slots_state) {
|
|
|
for (const [name, value] of Object.entries(slots_state)) {
|
|
|
- execution.invoke('builtins.setattr', [obj, name, value]);
|
|
|
+ builtins.setattr(obj, name, value);
|
|
|
}
|
|
|
}
|
|
|
return obj;
|
|
|
@@ -6374,6 +6403,9 @@ python.Execution = class {
|
|
|
return value;
|
|
|
});
|
|
|
this.registerFunction('torch.clear', (value) => {
|
|
|
+ if (value instanceof torch.Value) {
|
|
|
+ throw new python.Error('Invalid value.');
|
|
|
+ }
|
|
|
if (Object(value) === value) {
|
|
|
for (const key of Object.keys(value)) {
|
|
|
delete value[key];
|
|
|
@@ -6958,6 +6990,9 @@ python.Execution = class {
|
|
|
}
|
|
|
return this.equals(rhs);
|
|
|
}
|
|
|
+ is_module() {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
expect(type) {
|
|
|
if (this instanceof type === false) {
|
|
|
throw new python.Error(`Expected '${type.kind()}' but got '${this.kind()}'.`);
|
|
|
@@ -7004,6 +7039,12 @@ python.Execution = class {
|
|
|
is_module() {
|
|
|
return this._is_module;
|
|
|
}
|
|
|
+ is_parameter(slot) {
|
|
|
+ return this._attributes[slot].is_parameter === true;
|
|
|
+ }
|
|
|
+ is_buffer(slot) {
|
|
|
+ return this._attributes[slot].is_buffer === true;
|
|
|
+ }
|
|
|
addMethod(func) {
|
|
|
this._methods.set(func.name(), func);
|
|
|
}
|
|
|
@@ -7023,6 +7064,9 @@ python.Execution = class {
|
|
|
findStaticMethod(name) {
|
|
|
return this._staticmethods.get(name);
|
|
|
}
|
|
|
+ numAttributes() {
|
|
|
+ return this._attributes.length;
|
|
|
+ }
|
|
|
addAttribute(name, type, is_parameter, is_buffer) {
|
|
|
is_parameter = is_parameter || false;
|
|
|
is_buffer = is_buffer || false;
|
|
|
@@ -7057,10 +7101,16 @@ python.Execution = class {
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
- getAttribute(name) {
|
|
|
- const slot = this.findAttributeSlot(name);
|
|
|
+ hasAttribute(name) {
|
|
|
+ return this._attributes.find((attr) => attr.name === name);
|
|
|
+ }
|
|
|
+ getAttribute(arg) {
|
|
|
+ const slot = Number.isInteger(arg) ? arg : this.findAttributeSlot(arg);
|
|
|
return this._attributes[slot].type;
|
|
|
}
|
|
|
+ getAttributeName(slot) {
|
|
|
+ return this._attributes[slot].name;
|
|
|
+ }
|
|
|
hasConstant(/* name */) {
|
|
|
}
|
|
|
methods() {
|
|
|
@@ -8131,7 +8181,10 @@ python.Execution = class {
|
|
|
parseBroadcastList(/* expr */) {
|
|
|
return null;
|
|
|
}
|
|
|
-
|
|
|
+ parseType(str) {
|
|
|
+ const expr = ast.parse(str);
|
|
|
+ return this.parseTypeFromExpr(expr.body[0]);
|
|
|
+ }
|
|
|
});
|
|
|
this.registerType('torch._ops.OpOverload', class extends torch._ops.OperatorBase {
|
|
|
constructor(overloadpacket, op, op_dk, schema, tags) {
|
|
|
@@ -8874,8 +8927,9 @@ python.Execution = class {
|
|
|
this._loaded_sources = new Set();
|
|
|
this._to_be_defined = new Map();
|
|
|
}
|
|
|
- loadType(/* name */) {
|
|
|
- //
|
|
|
+ loadType(name) {
|
|
|
+ const type_parser = new torch.jit.ScriptTypeParser(this);
|
|
|
+ return type_parser.parseType(name.qualifiedName());
|
|
|
}
|
|
|
resolveType(name) {
|
|
|
name = new torch.jit.QualifiedName(name);
|
|
|
@@ -9120,12 +9174,32 @@ python.Execution = class {
|
|
|
for (let i = 0; i < constants.length; i++) {
|
|
|
execution.builtins.CONSTANTS[`c${i}`] = constants[i];
|
|
|
}
|
|
|
- const module = this.readArchive('data');
|
|
|
- const name = `${module.__class__.__module__}.${module.__class__.__name__}`;
|
|
|
- const type = torch.ClassType.create(name, null, true);
|
|
|
- const result = new torch.ScriptModule(type);
|
|
|
- result.data = module;
|
|
|
- return result;
|
|
|
+ const obj = this.readArchive('data');
|
|
|
+ const convertModule = (obj) => {
|
|
|
+ if (obj.__class__) {
|
|
|
+ const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
|
|
|
+ const type = this._source_importer.loadType(new torch.jit.QualifiedName(name));
|
|
|
+ const module = new torch.ScriptModule(type, this._compilation_unit);
|
|
|
+ for (let i = 0; i < type.numAttributes(); i++) {
|
|
|
+ const k = type.getAttributeName(i);
|
|
|
+ const t = type.getAttribute(i);
|
|
|
+ const v = obj[k];
|
|
|
+ if (t.is_module()) {
|
|
|
+ module.__setattr__(k, convertModule(v));
|
|
|
+ } else {
|
|
|
+ module.__setattr__(k, obj[k]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const [key, value] of Object.entries(Object.getPrototypeOf(obj))) {
|
|
|
+ if (value && value.__class__ === builtins.method) {
|
|
|
+ module[key] = value;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return module;
|
|
|
+ }
|
|
|
+ throw new python.Error('Module class not found.');
|
|
|
+ };
|
|
|
+ return convertModule(obj);
|
|
|
}
|
|
|
LEGACY_deserialize() {
|
|
|
const execution = this._compilation_unit.execution;
|
|
|
@@ -9186,8 +9260,8 @@ python.Execution = class {
|
|
|
for (const tensor of tensor_table) {
|
|
|
this._constant_table.push(tensor);
|
|
|
}
|
|
|
- const temp = this.LEGACY_convertModule(module_def);
|
|
|
- const data = obj.mainModule || {};
|
|
|
+ return this.LEGACY_convertModule(module_def);
|
|
|
+ /* const data = obj.mainModule || {};
|
|
|
const queue = [data];
|
|
|
while (queue.length > 0) {
|
|
|
const module = queue.shift();
|
|
|
@@ -9237,6 +9311,8 @@ python.Execution = class {
|
|
|
const result = new torch.ScriptModule(temp.type());
|
|
|
result.data = data;
|
|
|
return result;
|
|
|
+ return module;
|
|
|
+ */
|
|
|
}
|
|
|
LEGACY_convertModule(module_def) {
|
|
|
const atoms = new torch.jit.QualifiedName(module_def.name).atoms();
|
|
|
@@ -9245,13 +9321,14 @@ python.Execution = class {
|
|
|
const sanitized = /^\d+$/.test(atom) ? `_${atom}` : atom;
|
|
|
this._LEGACY_moduleStack.push(sanitized);
|
|
|
}
|
|
|
- const module = new torch.ScriptModule(new torch.jit.QualifiedName(this._LEGACY_moduleStack), this._compilation_unit);
|
|
|
+ const qn = new torch.jit.QualifiedName(this._LEGACY_moduleStack);
|
|
|
+ const module = new torch.ScriptModule(qn, this._compilation_unit);
|
|
|
for (const sub_def of module_def.submodules || []) {
|
|
|
const submodule = this.LEGACY_convertModule(sub_def);
|
|
|
module.register_module(sub_def.name, submodule);
|
|
|
}
|
|
|
for (const param_def of module_def.parameters || []) {
|
|
|
- const tensor = this._constant_table[Number(param_def.tensorId)];
|
|
|
+ const tensor = this._constant_table[Number(param_def.tensor_id)];
|
|
|
if (param_def.isBuffer) {
|
|
|
module.register_buffer(param_def.name, tensor);
|
|
|
} else {
|
|
|
@@ -9263,11 +9340,21 @@ python.Execution = class {
|
|
|
if (module.hasattr(attr_def.name)) {
|
|
|
continue;
|
|
|
}
|
|
|
+ throw new python.Error('Not implemented.');
|
|
|
// IValue ivalue;
|
|
|
// if (attr_def.id() >= 0) {
|
|
|
// ivalue = LEGACY_pickled_ivalues_.at(attr_def.id());
|
|
|
// }
|
|
|
- // module.register_attribute(attr_def.name(), typeParser.parseType(attr_def.type()), ivalue);
|
|
|
+ // module.register_attribute(attr_def.name, typeParser.parseType(attr_def.type), ivalue);
|
|
|
+ }
|
|
|
+ if (module_def.torchscript_arena) {
|
|
|
+ const key = module_def.torchscript_arena.key;
|
|
|
+ const file = key.substring('code/'.length);
|
|
|
+ const name = file.replace(/\.py$/, '').split('/').join('.');
|
|
|
+ const code = execution.import(name);
|
|
|
+ if (code.forward.__class__ === execution.builtins.function) {
|
|
|
+ module.forward = code.forward;
|
|
|
+ }
|
|
|
}
|
|
|
/*
|
|
|
std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr;
|
|
|
@@ -9299,9 +9386,13 @@ python.Execution = class {
|
|
|
return module;
|
|
|
}
|
|
|
readArchive(archive_name) {
|
|
|
- const type_resolver = null;
|
|
|
- const obj_loader = null;
|
|
|
- return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, obj_loader, this._device, this._reader, null, this._storage_context);
|
|
|
+ const type_resolver = (qn) => {
|
|
|
+ const cls = this._source_importer.loadType(qn);
|
|
|
+ return cls;
|
|
|
+ };
|
|
|
+ const ObjLoaderFunc = (/* type, ivalue */) => {
|
|
|
+ };
|
|
|
+ return this.readArchiveAndTensors(archive_name, this._pickle_dir_prefix, this._tensor_dir_prefix, type_resolver, ObjLoaderFunc, this._device, this._reader, null, this._storage_context);
|
|
|
}
|
|
|
readArchiveAndTensors(archive_name, pickle_prefix, tensor_prefix, type_resolver, obj_loader, device, stream_reader, type_parser, storage_context) {
|
|
|
const picklename = `${pickle_prefix + archive_name}.pkl`;
|
|
|
@@ -9405,7 +9496,9 @@ python.Execution = class {
|
|
|
const cu = new torch.jit.CompilationUnit();
|
|
|
cu.execution = execution;
|
|
|
const cpp_module = torch._C.import_ir_module(cu, file, map_location, extra_files);
|
|
|
- return new torch.jit._script.RecursiveScriptModule(cpp_module);
|
|
|
+ const module = torch.jit._script.wrap_cpp_module(cpp_module);
|
|
|
+ module.forward = cpp_module.forward; // remove
|
|
|
+ return module;
|
|
|
});
|
|
|
this.registerFunction('torch._C.import_ir_module', function(cu, reader, ...args) {
|
|
|
switch (arguments.length) {
|
|
|
@@ -9495,7 +9588,7 @@ python.Execution = class {
|
|
|
const cu = new torch.jit.CompilationUnit();
|
|
|
cu.execution = execution;
|
|
|
const cpp_module = torch._C._import_ir_module_from_package(cu, importer.zip_reader, importer.storage_context, importer.last_map_location, script_module_id);
|
|
|
- return new torch.jit._script.RecursiveScriptModule(cpp_module);
|
|
|
+ return torch.jit._script.wrap_cpp_module(cpp_module);
|
|
|
});
|
|
|
this.registerFunction('torch.jit._script.wrap_cpp_module', (cpp_module) => {
|
|
|
const init_fn = (script_module) => {
|
|
|
@@ -9552,7 +9645,8 @@ python.Execution = class {
|
|
|
});
|
|
|
this.registerType('torch.ScriptObject', class {
|
|
|
constructor(type) {
|
|
|
- this._type = type;
|
|
|
+ this._typ = type;
|
|
|
+ this._ivalue = {};
|
|
|
}
|
|
|
static create(type) {
|
|
|
if (type.is_module()) {
|
|
|
@@ -9561,10 +9655,10 @@ python.Execution = class {
|
|
|
return new torch.ScriptObject(type);
|
|
|
}
|
|
|
type() {
|
|
|
- return this._type;
|
|
|
+ return this._typ;
|
|
|
}
|
|
|
_type() {
|
|
|
- return this._type; // torch.ClassType
|
|
|
+ return this._typ; // torch.ClassType
|
|
|
}
|
|
|
_get_method(name) {
|
|
|
for (const fn of this._type.methods()) {
|
|
|
@@ -9579,13 +9673,16 @@ python.Execution = class {
|
|
|
}
|
|
|
__setattr__(name, value) {
|
|
|
// if (this._type.hasContant(name))
|
|
|
- this[name] = value;
|
|
|
+ this._ivalue[name] = value;
|
|
|
}
|
|
|
__getattr__(name) {
|
|
|
- return this[name];
|
|
|
+ return this._ivalue[name];
|
|
|
}
|
|
|
hasattr(name) {
|
|
|
- return this._type.hasAttribute(name) || this._type.hasConstant(name);
|
|
|
+ return this._typ.hasAttribute(name) || this._typ.hasConstant(name);
|
|
|
+ }
|
|
|
+ getattr(name) {
|
|
|
+ return this.__getattr__(name);
|
|
|
}
|
|
|
_properties() {
|
|
|
throw new python.Error();
|
|
|
@@ -9601,7 +9698,7 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
get qualified_name() {
|
|
|
- return this._type.qualified_name();
|
|
|
+ return this.type().qualified_name();
|
|
|
}
|
|
|
get code_with_constants() {
|
|
|
const const_map = {};
|
|
|
@@ -9622,25 +9719,22 @@ python.Execution = class {
|
|
|
return false;
|
|
|
}
|
|
|
};
|
|
|
- if (!this.data) {
|
|
|
+ if (!this.forward) {
|
|
|
return null;
|
|
|
}
|
|
|
- if (!this.data.forward) {
|
|
|
- throw new python.Error("Module 'forward' not implemented.");
|
|
|
- }
|
|
|
execution.traceAttr = false;
|
|
|
const args = [];
|
|
|
if (!execution.traceAttr) {
|
|
|
- args.push(this.data); // self
|
|
|
+ args.push(this); // self
|
|
|
}
|
|
|
- if (this.data.forward.__code__ && this.data.forward.__code__.args) {
|
|
|
- const params = this.data.forward.__code__.args.args;
|
|
|
+ if (this.forward.__code__ && this.forward.__code__.args) {
|
|
|
+ const params = this.forward.__code__.args.args;
|
|
|
for (let i = 0; i < params.length; i++) {
|
|
|
const arg = params[i];
|
|
|
if (execution.traceAttr || arg.arg !== 'self') {
|
|
|
const value = execution.graph.addInput(arg.arg);
|
|
|
if (i === 0 && arg.arg === 'self' && !arg.annotation) {
|
|
|
- value.setType(this._type);
|
|
|
+ value.setType(this.type());
|
|
|
} else {
|
|
|
value.setType(execution.type(arg.annotation));
|
|
|
}
|
|
|
@@ -9653,7 +9747,7 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
execution.purge = new Set();
|
|
|
- const result = this.data.forward.__call__(args);
|
|
|
+ const result = this.forward.__call__(args);
|
|
|
const queue = Array.from(execution.purge);
|
|
|
const visited = new Set();
|
|
|
while (queue.length > 0) {
|
|
|
@@ -9715,15 +9809,15 @@ python.Execution = class {
|
|
|
}
|
|
|
register_module(name, module) {
|
|
|
this.type().addOrCheckAttribute(name, module.type());
|
|
|
- // _ivalue()->setAttr(name, module._ivalue());
|
|
|
+ this.__setattr__(name, module); // _ivalue()->setAttr(name, module._ivalue());
|
|
|
}
|
|
|
- register_buffer(name /* , v */) {
|
|
|
+ register_buffer(name, v) {
|
|
|
this.type().addOrCheckAttribute(name, torch.TensorType.get(), false, true);
|
|
|
- // _ivalue()->setAttr(name, std::move(v));
|
|
|
+ this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v));
|
|
|
}
|
|
|
register_parameter(name, v, is_buffer) {
|
|
|
this.type().addOrCheckAttribute(name, torch.TensorType.get(), !is_buffer, is_buffer);
|
|
|
- // _ivalue()->setAttr(name, std::move(v));
|
|
|
+ this.__setattr__(name, v); // _ivalue()->setAttr(name, std::move(v));
|
|
|
}
|
|
|
register_attribute(name, t, v, is_param, is_buffer) {
|
|
|
this.type().addOrCheckAttribute(name, t, is_param, is_buffer);
|
|
|
@@ -9731,11 +9825,59 @@ python.Execution = class {
|
|
|
}
|
|
|
});
|
|
|
this.registerType('torch.ModuleDict', class {
|
|
|
- constructor(module) {
|
|
|
- this._items = Object.entries(module).filter(([, value]) => value instanceof torch.ScriptModule);
|
|
|
+ constructor(mod) {
|
|
|
+ this._module = mod;
|
|
|
+ }
|
|
|
+ items() {
|
|
|
+ const result = new Map();
|
|
|
+ const type = this._module.type();
|
|
|
+ for (let i = 0; i < type.numAttributes(); i++) {
|
|
|
+ const k = type.getAttributeName(i);
|
|
|
+ const t = type.getAttribute(i);
|
|
|
+ if (t && t.is_module()) {
|
|
|
+ result.set(k, this._module.__getattr__(k));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.ParameterDict', class {
|
|
|
+ constructor(mod) {
|
|
|
+ this._module = mod;
|
|
|
+ }
|
|
|
+ items() {
|
|
|
+ const result = new Map();
|
|
|
+ const type = this._module.type();
|
|
|
+ for (let i = 0; i < type.numAttributes(); i++) {
|
|
|
+ if (type.is_parameter(i)) {
|
|
|
+ const k = type.getAttributeName(i);
|
|
|
+ const v = this._module.__getattr__(k);
|
|
|
+ if (v instanceof torch.Tensor) {
|
|
|
+ result.set(k, v);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.BufferDict', class {
|
|
|
+ constructor(mod) {
|
|
|
+ this._module = mod;
|
|
|
}
|
|
|
items() {
|
|
|
- return this._items;
|
|
|
+ const result = new Map();
|
|
|
+ const type = this._module.type();
|
|
|
+ for (let i = 0; i < type.numAttributes(); i++) {
|
|
|
+ if (type.is_buffer(i)) {
|
|
|
+ const t = type.getAttribute(i);
|
|
|
+ if (t.isSubtypeOf(torch.TensorType.get())) {
|
|
|
+ const k = type.getAttributeName(i);
|
|
|
+ const v = this._module.__getattr__(k);
|
|
|
+ result.set(k, v);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return result;
|
|
|
}
|
|
|
});
|
|
|
this.registerType('torch.jit.to_ir', class {
|
|
|
@@ -9846,11 +9988,11 @@ python.Execution = class {
|
|
|
torch.jit._script.RecursiveScriptModule._finalize_scriptmodule(script_module);
|
|
|
return script_module;
|
|
|
}
|
|
|
- static _finalize_scriptmodule() {
|
|
|
- this._initializing = false;
|
|
|
- }
|
|
|
- get data() {
|
|
|
- return this._c.data;
|
|
|
+ static _finalize_scriptmodule(script_module) {
|
|
|
+ script_module._parameters = new torch.ParameterDict(script_module._c).items();
|
|
|
+ script_module._buffers = new torch.BufferDict(script_module._c).items();
|
|
|
+ // script_module._modules = OrderedModuleDict(script_module._c, script_module._modules)
|
|
|
+ script_module._initializing = false;
|
|
|
}
|
|
|
get graph() {
|
|
|
// return this._c._get_method("forward").graph;
|
|
|
@@ -9863,8 +10005,8 @@ python.Execution = class {
|
|
|
__setattr__(name, value) {
|
|
|
if (this._initializing) {
|
|
|
super.__setattr__(name, value);
|
|
|
- } else if (this.modules.has(name)) {
|
|
|
- this.modules.set(name, value);
|
|
|
+ } else if (this._modules.has(name)) {
|
|
|
+ this._modules.set(name, value);
|
|
|
} else if (this._c.hasattr(name)) {
|
|
|
this._c.setattr(name, value);
|
|
|
} else {
|
|
|
@@ -9875,8 +10017,8 @@ python.Execution = class {
|
|
|
if (this._initializing) {
|
|
|
return super.__getattr__(name);
|
|
|
}
|
|
|
- if (this.modules.has(name)) {
|
|
|
- return this.modules.get(name);
|
|
|
+ if (this._modules.has(name)) {
|
|
|
+ return this._modules.get(name);
|
|
|
}
|
|
|
if (this._c.hasattr(name)) {
|
|
|
return this._c.getattr(name);
|
|
|
@@ -11646,10 +11788,7 @@ python.Execution = class {
|
|
|
this.registerType('torch.nn.parameter.Parameter', class extends torch.Tensor {
|
|
|
constructor(data, requires_grad) {
|
|
|
super();
|
|
|
- if (!data) {
|
|
|
- data = self.invoke('torch.Tensor', [[]]);
|
|
|
- }
|
|
|
- this.data = data;
|
|
|
+ this.data = data || new torch.Tensor([]);
|
|
|
this.requires_grad = requires_grad === undefined ? true : requires_grad;
|
|
|
}
|
|
|
});
|
|
|
@@ -12544,7 +12683,12 @@ python.Execution = class {
|
|
|
if (path) {
|
|
|
let target = null;
|
|
|
for (let i = path.length - 1; i >= 0; i--) {
|
|
|
- target = target ? target[path[i]] : context.get(path[i]);
|
|
|
+ const name = path[i];
|
|
|
+ if (target) {
|
|
|
+ target = target.__getattr__ ? target.__getattr__(name) : target[name];
|
|
|
+ } else {
|
|
|
+ target = context.get(name);
|
|
|
+ }
|
|
|
if (!target) {
|
|
|
break;
|
|
|
}
|