|
|
@@ -40,7 +40,7 @@ torchscript.ModelFactory = class {
|
|
|
}
|
|
|
if (container.model) {
|
|
|
container.model = JSON.parse(textDecoder.decode(container.model));
|
|
|
- }
|
|
|
+ }
|
|
|
return torchscript.Metadata.open(host).then((metadata) => {
|
|
|
try {
|
|
|
return new torchscript.Model(metadata, host, python, container);
|
|
|
@@ -65,9 +65,9 @@ torchscript.ModelFactory = class {
|
|
|
|
|
|
static _openContainer(entries) {
|
|
|
if (entries && entries.length > 0) {
|
|
|
- let container = {};
|
|
|
const version = entries.find((entry) => entry.name == 'version' || entry.name.endsWith('/version'));
|
|
|
if (version) {
|
|
|
+ let container = {};
|
|
|
container.entries = entries;
|
|
|
container.prefix = version.name.substring(0, version.name.length - 7);
|
|
|
let find = (name) => {
|
|
|
@@ -104,8 +104,8 @@ torchscript.ModelFactory = class {
|
|
|
if (!data) {
|
|
|
return null;
|
|
|
}
|
|
|
- let functionTable = {};
|
|
|
- functionTable['collections.OrderedDict'] = function(args) {
|
|
|
+ let functionTable = new Map();
|
|
|
+ functionTable.set('collections.OrderedDict', function(args) {
|
|
|
let obj = [];
|
|
|
obj.__setitem__ = function(key, value) {
|
|
|
obj.push({ key: key, value: value });
|
|
|
@@ -116,8 +116,8 @@ torchscript.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
return obj;
|
|
|
- };
|
|
|
- functionTable['torch._utils._rebuild_tensor_v2'] = function(storage, storage_offset, size, stride, requires_grad, backward_hooks) {
|
|
|
+ });
|
|
|
+ functionTable.set('torch._utils._rebuild_tensor_v2', function(storage, storage_offset, size, stride, requires_grad, backward_hooks) {
|
|
|
return {
|
|
|
__type__: storage.__type__.replace('Storage', 'Tensor'),
|
|
|
storage: storage,
|
|
|
@@ -127,8 +127,8 @@ torchscript.ModelFactory = class {
|
|
|
requires_grad:requires_grad,
|
|
|
backward_hooks: backward_hooks
|
|
|
};
|
|
|
- };
|
|
|
- functionTable['torch._utils._rebuild_qtensor'] = function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
|
|
|
+ });
|
|
|
+ functionTable.set('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
|
|
|
return {
|
|
|
__type__: storage.__type__.replace('Storage', 'Tensor'),
|
|
|
storage: storage,
|
|
|
@@ -139,46 +139,46 @@ torchscript.ModelFactory = class {
|
|
|
requires_grad:requires_grad,
|
|
|
backward_hooks: backward_hooks
|
|
|
};
|
|
|
- }
|
|
|
- functionTable['torch.jit._pickle.build_intlist'] = function(data) {
|
|
|
+ });
|
|
|
+ functionTable.set('torch.jit._pickle.build_intlist', function(data) {
|
|
|
return data;
|
|
|
- }
|
|
|
- let constructorTable = {};
|
|
|
- constructorTable['torch.ByteStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ let constructorTable = new Map();
|
|
|
+ constructorTable.set('torch.ByteStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8';
|
|
|
- };
|
|
|
- constructorTable['torch.CharStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.CharStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 1; this.dataType = 'int8';
|
|
|
- };
|
|
|
- constructorTable['torch.ShortStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.ShortStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 2; this.dataType = 'int16';
|
|
|
- };
|
|
|
- constructorTable['torch.IntStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.IntStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 4; this.dataType = 'int32';
|
|
|
- };
|
|
|
- constructorTable['torch.LongStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.LongStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 8; this.dataType = 'int64';
|
|
|
- };
|
|
|
- constructorTable['torch.HalfStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.HalfStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 2; this.dataType = 'float16';
|
|
|
- };
|
|
|
- constructorTable['torch.FloatStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.FloatStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 4; this.dataType = 'float32';
|
|
|
- };
|
|
|
- constructorTable['torch.DoubleStorage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.DoubleStorage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
|
|
|
- };
|
|
|
- constructorTable['torch.QInt8Storage'] = function (size) {
|
|
|
+ });
|
|
|
+ constructorTable.set('torch.QInt8Storage', function (size) {
|
|
|
this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
|
|
|
- };
|
|
|
+ });
|
|
|
let function_call = (name, args) => {
|
|
|
- let func = functionTable[name];
|
|
|
- if (func) {
|
|
|
+ if (functionTable.has(name)) {
|
|
|
+ const func = functionTable.get(name);
|
|
|
return func.apply(null, args);
|
|
|
}
|
|
|
let obj = { __type__: name };
|
|
|
- let constructor = constructorTable[name];
|
|
|
- if (constructor) {
|
|
|
+ if (constructorTable.has(name)) {
|
|
|
+ const constructor = constructorTable.get(name);
|
|
|
constructor.apply(obj, args);
|
|
|
}
|
|
|
else if (!name.startsWith('__torch__.')) {
|
|
|
@@ -1604,7 +1604,7 @@ torchscript.GraphContext = class {
|
|
|
}
|
|
|
// exponential_average_factor = 0.10000000000000001
|
|
|
if (expression.type === 'number') {
|
|
|
- this._state[target.value] = expression.value;
|
|
|
+ this._state[target.value] = expression;
|
|
|
return true;
|
|
|
}
|
|
|
const valueExpression = this._evaluateExpression(expression);
|