|
|
@@ -56,6 +56,9 @@ pytorch.ModelFactory = class {
|
|
|
constructorTable['torch.autograd.variable.Variable'] = function() {};
|
|
|
constructorTable['torch.backends.cudnn.rnn.Unserializable'] = function() {};
|
|
|
constructorTable['torch.nn.backends.thnn._get_thnn_function_backend'] = function() {};
|
|
|
+ constructorTable['torch.nn.quantized.modules.functional_modules.FloatFunctional'] = function() {};
|
|
|
+ constructorTable['torch.quantization.stubs.DeQuantStub'] = function() {};
|
|
|
+ constructorTable['torch.quantization.stubs.QuantStub'] = function() {};
|
|
|
constructorTable['torch.nn.modules.activation.ELU'] = function() {};
|
|
|
constructorTable['torch.nn.modules.activation.GLU'] = function() {};
|
|
|
constructorTable['torch.nn.modules.activation.Hardtanh'] = function() {};
|
|
|
@@ -94,6 +97,7 @@ pytorch.ModelFactory = class {
|
|
|
constructorTable['torch.nn.modules.instancenorm.InstanceNorm2d'] = function() {};
|
|
|
constructorTable['torch.nn.modules.instancenorm.InstanceNorm3d'] = function() {};
|
|
|
constructorTable['torch.nn.modules.linear.Linear'] = function() {};
|
|
|
+ constructorTable['torch.nn.modules.linear.Identity'] = function() {};
|
|
|
constructorTable['torch.nn.modules.loss.BCELoss'] = function() {};
|
|
|
constructorTable['torch.nn.modules.loss.BCEWithLogitsLoss'] = function() {};
|
|
|
constructorTable['torch.nn.modules.loss.CrossEntropyLoss'] = function() {};
|
|
|
@@ -232,6 +236,22 @@ pytorch.ModelFactory = class {
|
|
|
this.stride = state[3];
|
|
|
};
|
|
|
};
|
|
|
+ constructorTable['torch.cuda.FloatTensor'] = function () {
|
|
|
+ this.__setstate__ = function(state) {
|
|
|
+ this.storage = state[0];
|
|
|
+ this.storage_offset = state[1];
|
|
|
+ this.size = state[2];
|
|
|
+ this.stride = state[3];
|
|
|
+ };
|
|
|
+ };
|
|
|
+ constructorTable['torch.cuda.DoubleTensor'] = function () {
|
|
|
+ this.__setstate__ = function(state) {
|
|
|
+ this.storage = state[0];
|
|
|
+ this.storage_offset = state[1];
|
|
|
+ this.size = state[2];
|
|
|
+ this.stride = state[3];
|
|
|
+ };
|
|
|
+ };
|
|
|
constructorTable['numpy.dtype'] = function(obj, align, copy) {
|
|
|
switch (obj) {
|
|
|
case 'i1': this.name = 'int8'; this.itemsize = 1; break;
|
|
|
@@ -665,10 +685,10 @@ pytorch.ModelFactory = class {
|
|
|
root.state_dict_stylepredictor, root.state_dict_ghiasi
|
|
|
];
|
|
|
for (let dict of candidates) {
|
|
|
- const state_dict =
|
|
|
- pytorch.ModelFactory._convertStateDictList(dict) ||
|
|
|
- pytorch.ModelFactory._convertStateDictMap(dict) ||
|
|
|
- pytorch.ModelFactory._convertStateDictGroupMap(dict);
|
|
|
+ let state_dict = null;
|
|
|
+ state_dict = state_dict || pytorch.ModelFactory._convertStateDictList(dict);
|
|
|
+ state_dict = state_dict || pytorch.ModelFactory._convertStateDictMap(dict);
|
|
|
+ state_dict = state_dict || pytorch.ModelFactory._convertStateDictGroupMap(dict);
|
|
|
if (state_dict) {
|
|
|
return state_dict;
|
|
|
}
|
|
|
@@ -777,21 +797,27 @@ pytorch.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
else if (Object(item) === item) {
|
|
|
+ let hasTensors = false;
|
|
|
for (let key in item) {
|
|
|
const value = item[key];
|
|
|
if (pytorch.ModelFactory._isTensor(value)) {
|
|
|
state_group.states.push({ name: key, value: value, id: state_group_name + '.' + key });
|
|
|
+ hasTensors = true;
|
|
|
}
|
|
|
else if (value !== Object(value)) {
|
|
|
state_group.attributes.push({ name: key, value: value });
|
|
|
}
|
|
|
else if (value && value.data && value.__module__ === 'torch.nn.parameter' && value.__name__ === 'Parameter') {
|
|
|
state_group.states.push({ name: key, value: value.data, id: state_group_name + '.' + key });
|
|
|
+ hasTensors = true;
|
|
|
}
|
|
|
else {
|
|
|
return null;
|
|
|
}
|
|
|
}
|
|
|
+ if (!hasTensors) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
}
|
|
|
else {
|
|
|
return null;
|
|
|
@@ -852,7 +878,7 @@ pytorch.ModelFactory = class {
|
|
|
}
|
|
|
|
|
|
static _isTensor(obj) {
|
|
|
- return obj && obj.__module__ === 'torch' && obj.__name__ && obj.__name__.endsWith('Tensor');
|
|
|
+ return obj && (obj.__module__ === 'torch' || obj.__module__ === 'torch.cuda') && obj.__name__ && obj.__name__.endsWith('Tensor');
|
|
|
}
|
|
|
};
|
|
|
|