|
|
@@ -2560,107 +2560,105 @@ pytorch.Container.Zip = class {
|
|
|
const buffer = stream.peek();
|
|
|
this._data = this._unpickle(buffer, this._storage('data'));
|
|
|
}
|
|
|
- else {
|
|
|
- if (this._model) {
|
|
|
- this._producer = this._model.producerName + (this._model.producerVersion ? ' v' + this._model.producerVersion : '');
|
|
|
- this._data = this._model.mainModule || {};
|
|
|
- this._name = this._data.name || '';
|
|
|
- if (this._data.torchscriptArena) {
|
|
|
- this._torchscriptArena = this._data.torchscriptArena.key;
|
|
|
+ else if (this._model) {
|
|
|
+ this._producer = this._model.producerName + (this._model.producerVersion ? ' v' + this._model.producerVersion : '');
|
|
|
+ this._data = this._model.mainModule || {};
|
|
|
+ this._name = this._data.name || '';
|
|
|
+ if (this._data.torchscriptArena) {
|
|
|
+ this._torchscriptArena = this._data.torchscriptArena.key;
|
|
|
+ }
|
|
|
+ const queue = [ this._data ];
|
|
|
+ const entries = new Map();
|
|
|
+ for (const entry of this._entries) {
|
|
|
+ const name = entry[0];
|
|
|
+ const stream = entry[1];
|
|
|
+ const buffer = stream.peek();
|
|
|
+ entries.set(name, buffer);
|
|
|
+ }
|
|
|
+ const tensorTypeMap = new Map([
|
|
|
+ [ 'FLOAT', 'Float' ],
|
|
|
+ [ 'FLOAT16', 'Half' ],
|
|
|
+ [ 'DOUBLE', 'Double' ],
|
|
|
+ [ 'INT8', 'Char' ],
|
|
|
+ [ 'INT32', 'Int' ],
|
|
|
+ [ 'INT64', 'Long' ]
|
|
|
+ ]);
|
|
|
+ const constants = this._model.tensors || [];
|
|
|
+ this._constants = constants.map((constant) => {
|
|
|
+ const key = this._prefix + constant.data.key;
|
|
|
+ if (!tensorTypeMap.has(constant.dataType)) {
|
|
|
+ throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
|
|
|
}
|
|
|
- const queue = [ this._data ];
|
|
|
- const entries = new Map();
|
|
|
- for (const entry of this._entries) {
|
|
|
- const name = entry[0];
|
|
|
- const stream = entry[1];
|
|
|
- const buffer = stream.peek();
|
|
|
- entries.set(name, buffer);
|
|
|
+ const type = tensorTypeMap.get(constant.dataType);
|
|
|
+ const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
|
|
|
+ const storage_type = this.execution.type('torch.' + type + 'Storage');
|
|
|
+ const size = (shape || []).reduce((a, b) => a * b, 1);
|
|
|
+ const offset = parseInt(constant.offset, 10) || 0;
|
|
|
+ const storage = new storage_type([ size ]);
|
|
|
+ const itemsize = storage.dtype.itemsize();
|
|
|
+ const buffer = entries.get(key);
|
|
|
+ const length = size * itemsize;
|
|
|
+ const data = buffer.slice(offset, offset + length);
|
|
|
+ storage._set_cdata(data);
|
|
|
+ const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
|
|
|
+ tensor.name = constant.data.key;
|
|
|
+ return tensor;
|
|
|
+ });
|
|
|
+ this._attributes = [];
|
|
|
+ const stream = this._entry('attributes.pkl');
|
|
|
+ if (stream) {
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const unpickler = python.Unpickler.open(buffer);
|
|
|
+ this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
|
|
|
+ }
|
|
|
+ while (queue.length > 0) {
|
|
|
+ const module = queue.shift();
|
|
|
+ if (!module.__class__) {
|
|
|
+ module.__class__ = {
|
|
|
+ __module__: 'torch.nn.modules.module',
|
|
|
+ __name__: 'Module'
|
|
|
+ };
|
|
|
+ }
|
|
|
+ if (module.name) {
|
|
|
+ module.__id__ = module.name;
|
|
|
}
|
|
|
- const tensorTypeMap = new Map([
|
|
|
- [ 'FLOAT', 'Float' ],
|
|
|
- [ 'FLOAT16', 'Half' ],
|
|
|
- [ 'DOUBLE', 'Double' ],
|
|
|
- [ 'INT8', 'Char' ],
|
|
|
- [ 'INT32', 'Int' ],
|
|
|
- [ 'INT64', 'Long' ]
|
|
|
- ]);
|
|
|
- const constants = this._model.tensors || [];
|
|
|
- this._constants = constants.map((constant) => {
|
|
|
- const key = this._prefix + constant.data.key;
|
|
|
- if (!tensorTypeMap.has(constant.dataType)) {
|
|
|
- throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
|
|
|
+ if (module.submodules) {
|
|
|
+ for (const submodule of module.submodules) {
|
|
|
+ module[submodule.name] = submodule;
|
|
|
+ submodule.__parent__ = module;
|
|
|
+ queue.push(submodule);
|
|
|
}
|
|
|
- const type = tensorTypeMap.get(constant.dataType);
|
|
|
- const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
|
|
|
- const storage_type = this.execution.type('torch.' + type + 'Storage');
|
|
|
- const size = (shape || []).reduce((a, b) => a * b, 1);
|
|
|
- const offset = parseInt(constant.offset, 10) || 0;
|
|
|
- const storage = new storage_type([ size ]);
|
|
|
- const itemsize = storage.dtype.itemsize();
|
|
|
- const buffer = entries.get(key);
|
|
|
- const length = size * itemsize;
|
|
|
- const data = buffer.slice(offset, offset + length);
|
|
|
- storage._set_cdata(data);
|
|
|
- const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
|
|
|
- tensor.name = constant.data.key;
|
|
|
- return tensor;
|
|
|
- });
|
|
|
- this._attributes = [];
|
|
|
- const stream = this._entry('attributes.pkl');
|
|
|
- if (stream) {
|
|
|
- const buffer = stream.peek();
|
|
|
- const unpickler = python.Unpickler.open(buffer);
|
|
|
- this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
|
|
|
+ delete module.submodules;
|
|
|
}
|
|
|
- while (queue.length > 0) {
|
|
|
- const module = queue.shift();
|
|
|
- if (!module.__class__) {
|
|
|
- module.__class__ = {
|
|
|
- __module__: 'torch.nn.modules.module',
|
|
|
- __name__: 'Module'
|
|
|
+ const attributes = [];
|
|
|
+ if (module.attributes) {
|
|
|
+ attributes.push(...module.attributes);
|
|
|
+ delete module.attributes;
|
|
|
+ }
|
|
|
+ const parameters = [];
|
|
|
+ if (module.parameters) {
|
|
|
+ parameters.push(...module.parameters);
|
|
|
+ delete module.parameters;
|
|
|
+ }
|
|
|
+ if (module.arguments) {
|
|
|
+ parameters.push(...module.arguments);
|
|
|
+ delete module.arguments;
|
|
|
+ }
|
|
|
+ for (const parameter of parameters) {
|
|
|
+ const tensor = this._constants[parameter.tensorId];
|
|
|
+ module[parameter.name] = tensor;
|
|
|
+ if (!parameter.__class__) {
|
|
|
+ parameter.__class__ = {
|
|
|
+ __module__: 'torch',
|
|
|
+ __name__: 'Tensor'
|
|
|
};
|
|
|
}
|
|
|
- if (module.name) {
|
|
|
- module.__id__ = module.name;
|
|
|
- }
|
|
|
- if (module.submodules) {
|
|
|
- for (const submodule of module.submodules) {
|
|
|
- module[submodule.name] = submodule;
|
|
|
- submodule.__parent__ = module;
|
|
|
- queue.push(submodule);
|
|
|
- }
|
|
|
- delete module.submodules;
|
|
|
- }
|
|
|
- const attributes = [];
|
|
|
- if (module.attributes) {
|
|
|
- attributes.push(...module.attributes);
|
|
|
- delete module.attributes;
|
|
|
- }
|
|
|
- const parameters = [];
|
|
|
- if (module.parameters) {
|
|
|
- parameters.push(...module.parameters);
|
|
|
- delete module.parameters;
|
|
|
- }
|
|
|
- if (module.arguments) {
|
|
|
- parameters.push(...module.arguments);
|
|
|
- delete module.arguments;
|
|
|
- }
|
|
|
- for (const parameter of parameters) {
|
|
|
- const tensor = this._constants[parameter.tensorId];
|
|
|
- module[parameter.name] = tensor;
|
|
|
- if (!parameter.__class__) {
|
|
|
- parameter.__class__ = {
|
|
|
- __module__: 'torch',
|
|
|
- __name__: 'Tensor'
|
|
|
- };
|
|
|
- }
|
|
|
- }
|
|
|
- for (const attribute of attributes) {
|
|
|
- module[attribute.name] = this._attributes[attribute.id];
|
|
|
- }
|
|
|
}
|
|
|
- delete this._model;
|
|
|
+ for (const attribute of attributes) {
|
|
|
+ module[attribute.name] = this._attributes[attribute.id];
|
|
|
+ }
|
|
|
}
|
|
|
+ delete this._model;
|
|
|
}
|
|
|
if (this.format.startsWith('TorchScript ') && (this._torchscriptArena || this._data.forward)) {
|
|
|
this._type = 'script';
|