|
|
@@ -2133,6 +2133,7 @@ pytorch.Container.Tar = class {
|
|
|
|
|
|
constructor(entries) {
|
|
|
this._entries = entries;
|
|
|
+ this._graphs = [ this ];
|
|
|
}
|
|
|
|
|
|
set metadata(value) {
|
|
|
@@ -2147,6 +2148,10 @@ pytorch.Container.Tar = class {
|
|
|
return 'PyTorch v0.1.1';
|
|
|
}
|
|
|
|
|
|
+ get graphs() {
|
|
|
+ return this._graphs;
|
|
|
+ }
|
|
|
+
|
|
|
get type() {
|
|
|
this._unpickle();
|
|
|
return this._type;
|
|
|
@@ -2267,6 +2272,7 @@ pytorch.Container.Pickle = class {
|
|
|
|
|
|
constructor(stream) {
|
|
|
this._stream = stream;
|
|
|
+ this._graphs = [ this ];
|
|
|
}
|
|
|
|
|
|
set metadata(value) {
|
|
|
@@ -2281,6 +2287,10 @@ pytorch.Container.Pickle = class {
|
|
|
return 'PyTorch v0.1.10';
|
|
|
}
|
|
|
|
|
|
+ get graphs() {
|
|
|
+ return this._graphs;
|
|
|
+ }
|
|
|
+
|
|
|
get type() {
|
|
|
this._unpickle();
|
|
|
return this._type;
|
|
|
@@ -2394,35 +2404,49 @@ pytorch.Container.Pickle = class {
|
|
|
pytorch.Container.Zip = class {
|
|
|
|
|
|
static open(entries) {
|
|
|
- const name = Array.from(entries.keys()).find((name) => name == 'model.json' || name == 'data.pkl' || name.endsWith('/model.json') || name.endsWith('/data.pkl'));
|
|
|
- if (!name) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- let model = null;
|
|
|
- if (name.endsWith('.json')) {
|
|
|
- try {
|
|
|
- const stream = entries.get(name);
|
|
|
- const buffer = stream.peek();
|
|
|
- const decoder = new TextDecoder('utf-8');
|
|
|
- const content = decoder.decode(buffer);
|
|
|
- model = JSON.parse(content);
|
|
|
- if (!model.mainModule) {
|
|
|
- return null;
|
|
|
+ if (entries.size > 0) {
|
|
|
+ let prefix = [];
|
|
|
+ const paths = Array.from(entries.keys()).map((path) => path.split('/').reverse());
|
|
|
+ for (;;) {
|
|
|
+ const set = new Set(paths.map((path) => path.length > 0 ? path.pop() : null));
|
|
|
+ if (set.size !== 1 || set.keys().next().value === null) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ prefix.push(set.keys().next().value);
|
|
|
+ }
|
|
|
+ prefix = prefix.join('/');
|
|
|
+ prefix = prefix.length > 0 ? prefix + '/' : prefix;
|
|
|
+ entries = new Map(Array.from(entries).map((entry) => [ entry[0].substring(prefix.length), entry[1] ]));
|
|
|
+ if (entries.has('model.json')) {
|
|
|
+ try {
|
|
|
+ const stream = entries.get('model.json');
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const decoder = new TextDecoder('utf-8');
|
|
|
+ const content = decoder.decode(buffer);
|
|
|
+ const model = JSON.parse(content);
|
|
|
+ if (model.mainModule) {
|
|
|
+ return new pytorch.Container.Zip.Json(entries, model);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ catch (error) {
|
|
|
+ // continue regardless of error
|
|
|
}
|
|
|
}
|
|
|
- catch (error) {
|
|
|
- return null;
|
|
|
+ if (entries.has('data.pkl')) {
|
|
|
+ return new pytorch.Container.Zip.Pickle(entries);
|
|
|
+ }
|
|
|
+ if (Array.from(entries.keys()).find((name) => name.startsWith('.data/'))) {
|
|
|
+ return new pytorch.Container.Zip.Package(entries);
|
|
|
}
|
|
|
}
|
|
|
- return new pytorch.Container.Zip(entries, name, model);
|
|
|
+ return null;
|
|
|
}
|
|
|
|
|
|
- constructor(entries, name, model) {
|
|
|
- this._entries = entries;
|
|
|
+ constructor(entries) {
|
|
|
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
|
|
|
- this._model = model;
|
|
|
- const lastIndex = name.lastIndexOf('/');
|
|
|
- this._prefix = lastIndex === -1 ? '' : name.substring(0, lastIndex + 1);
|
|
|
+ this._entries = entries;
|
|
|
+ this._producer = '';
|
|
|
+ this._graphs = [ this ];
|
|
|
}
|
|
|
|
|
|
set metadata(value) {
|
|
|
@@ -2433,48 +2457,12 @@ pytorch.Container.Zip = class {
|
|
|
this._exceptionCallback = value;
|
|
|
}
|
|
|
|
|
|
- get format() {
|
|
|
- if (this._format === undefined) {
|
|
|
- if (this._entry('model.json')) {
|
|
|
- this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
|
|
|
- }
|
|
|
- else if (this._entry('data.pkl')) {
|
|
|
- // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
|
|
|
- // kProducedFileFormatVersion
|
|
|
- const versions = new Map([
|
|
|
- [ '1', 'v1.3' ],
|
|
|
- [ '2', 'v1.5' ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
|
|
|
- [ '3', 'v1.6' ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
|
|
|
- [ '4', 'v1.6' ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
|
|
|
- [ '5', 'v1.7' ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
|
|
|
- [ '6', 'v1.9' ], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
|
|
|
- [ '7', 'v1.10' ] // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
|
|
|
- ]);
|
|
|
- const value = this.version;
|
|
|
- if (!versions.has(value)) {
|
|
|
- this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
|
|
|
- }
|
|
|
- const version = versions.get(value);
|
|
|
- const constants = this._entry('constants.pkl');
|
|
|
- this._format = (constants ? 'TorchScript' : 'PyTorch') + ' ' + (version || 'v-' + value.toString() );
|
|
|
- }
|
|
|
- }
|
|
|
- return this._format;
|
|
|
- }
|
|
|
-
|
|
|
- get version() {
|
|
|
- const stream = this._entry('version');
|
|
|
- if (stream) {
|
|
|
- const decoder = new TextDecoder('utf-8');
|
|
|
- const buffer = stream.peek();
|
|
|
- const value = decoder.decode(buffer);
|
|
|
- return value.split('\n').shift();
|
|
|
- }
|
|
|
- return '';
|
|
|
+ get producer() {
|
|
|
+ return this._producer;
|
|
|
}
|
|
|
|
|
|
- get producer() {
|
|
|
- return this.data ? this._producer : '';
|
|
|
+ get graphs() {
|
|
|
+ return this.graphs;
|
|
|
}
|
|
|
|
|
|
get name() {
|
|
|
@@ -2486,19 +2474,19 @@ pytorch.Container.Zip = class {
|
|
|
}
|
|
|
|
|
|
get type() {
|
|
|
- this._load();
|
|
|
+ this.read();
|
|
|
return this._type;
|
|
|
}
|
|
|
|
|
|
get data() {
|
|
|
- this._load();
|
|
|
+ this.read();
|
|
|
return this._data;
|
|
|
}
|
|
|
|
|
|
get constants() {
|
|
|
if (this._constants === undefined) {
|
|
|
this._constants = [];
|
|
|
- const stream = this._entry('constants.pkl');
|
|
|
+ const stream = this._entries.get('constants.pkl');
|
|
|
if (stream) {
|
|
|
const buffer = stream.peek();
|
|
|
this._constants = this._unpickle(buffer, this._storage('constants'));
|
|
|
@@ -2540,8 +2528,8 @@ pytorch.Container.Zip = class {
|
|
|
const sources = new Map();
|
|
|
for (const entry of this._entries) {
|
|
|
const name = entry[0];
|
|
|
- if (name.startsWith(this._prefix + 'code')) {
|
|
|
- const file = name.substring(this._prefix.length);
|
|
|
+ if (name.startsWith('code')) {
|
|
|
+ const file = name;
|
|
|
if (sources.has(file)) {
|
|
|
throw new pytorch.Error("Duplicate source file '" + file + "'.");
|
|
|
}
|
|
|
@@ -2560,136 +2548,50 @@ pytorch.Container.Zip = class {
|
|
|
return this._execution;
|
|
|
}
|
|
|
|
|
|
- _entry(name) {
|
|
|
- return this._entries.get(this._prefix + name);
|
|
|
+ version(name) {
|
|
|
+ const stream = this._entries.get(name);
|
|
|
+ if (stream) {
|
|
|
+ const decoder = new TextDecoder('utf-8');
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const text = decoder.decode(buffer);
|
|
|
+ const value = text.split('\n').shift();
|
|
|
+ // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
|
|
|
+ // kProducedFileFormatVersion
|
|
|
+ const versions = new Map([
|
|
|
+ [ '1', 'v1.3' ],
|
|
|
+ [ '2', 'v1.5' ], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
|
|
|
+ [ '3', 'v1.6' ], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
|
|
|
+ [ '4', 'v1.6' ], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
|
|
|
+ [ '5', 'v1.7' ], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
|
|
|
+ [ '6', 'v1.9' ], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
|
|
|
+ [ '7', 'v1.10' ] // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
|
|
|
+ ]);
|
|
|
+ if (!versions.has(value)) {
|
|
|
+ this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + value + "'."));
|
|
|
+ }
|
|
|
+ return versions.get(value) || 'v-' + value.toString();
|
|
|
+ }
|
|
|
+ return '';
|
|
|
}
|
|
|
|
|
|
- _load() {
|
|
|
- if (this._data === undefined) {
|
|
|
- this._data = null;
|
|
|
- const stream = this._entry('data.pkl');
|
|
|
- if (stream) {
|
|
|
- const buffer = stream.peek();
|
|
|
- this._data = this._unpickle(buffer, this._storage('data'));
|
|
|
- }
|
|
|
- else if (this._model) {
|
|
|
- this._producer = this._model.producerName + (this._model.producerVersion ? ' v' + this._model.producerVersion : '');
|
|
|
- this._data = this._model.mainModule || {};
|
|
|
- this._name = this._data.name || '';
|
|
|
- if (this._data.torchscriptArena) {
|
|
|
- this._torchscriptArena = this._data.torchscriptArena.key;
|
|
|
- }
|
|
|
- const queue = [ this._data ];
|
|
|
- const entries = new Map();
|
|
|
- for (const entry of this._entries) {
|
|
|
- const name = entry[0];
|
|
|
- const stream = entry[1];
|
|
|
- const buffer = stream.peek();
|
|
|
- entries.set(name, buffer);
|
|
|
- }
|
|
|
- const tensorTypeMap = new Map([
|
|
|
- [ 'FLOAT', 'Float' ],
|
|
|
- [ 'FLOAT16', 'Half' ],
|
|
|
- [ 'DOUBLE', 'Double' ],
|
|
|
- [ 'INT8', 'Char' ],
|
|
|
- [ 'INT32', 'Int' ],
|
|
|
- [ 'INT64', 'Long' ]
|
|
|
- ]);
|
|
|
- const constants = this._model.tensors || [];
|
|
|
- this._constants = constants.map((constant) => {
|
|
|
- const key = this._prefix + constant.data.key;
|
|
|
- if (!tensorTypeMap.has(constant.dataType)) {
|
|
|
- throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
|
|
|
- }
|
|
|
- const type = tensorTypeMap.get(constant.dataType);
|
|
|
- const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
|
|
|
- const storage_type = this.execution.type('torch.' + type + 'Storage');
|
|
|
- const size = (shape || []).reduce((a, b) => a * b, 1);
|
|
|
- const offset = parseInt(constant.offset, 10) || 0;
|
|
|
- const storage = new storage_type([ size ]);
|
|
|
- const itemsize = storage.dtype.itemsize();
|
|
|
- const buffer = entries.get(key);
|
|
|
- const length = size * itemsize;
|
|
|
- const data = buffer.slice(offset, offset + length);
|
|
|
- storage._set_cdata(data);
|
|
|
- const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
|
|
|
- tensor.name = constant.data.key;
|
|
|
- return tensor;
|
|
|
- });
|
|
|
- this._attributes = [];
|
|
|
- const stream = this._entry('attributes.pkl');
|
|
|
- if (stream) {
|
|
|
- const buffer = stream.peek();
|
|
|
- const unpickler = python.Unpickler.open(buffer);
|
|
|
- this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
|
|
|
- }
|
|
|
- while (queue.length > 0) {
|
|
|
- const module = queue.shift();
|
|
|
- if (!module.__class__) {
|
|
|
- module.__class__ = {
|
|
|
- __module__: 'torch.nn.modules.module',
|
|
|
- __name__: 'Module'
|
|
|
- };
|
|
|
- }
|
|
|
- if (module.name) {
|
|
|
- module.__id__ = module.name;
|
|
|
- }
|
|
|
- if (module.submodules) {
|
|
|
- for (const submodule of module.submodules) {
|
|
|
- module[submodule.name] = submodule;
|
|
|
- submodule.__parent__ = module;
|
|
|
- queue.push(submodule);
|
|
|
- }
|
|
|
- delete module.submodules;
|
|
|
- }
|
|
|
- const attributes = [];
|
|
|
- if (module.attributes) {
|
|
|
- attributes.push(...module.attributes);
|
|
|
- delete module.attributes;
|
|
|
- }
|
|
|
- const parameters = [];
|
|
|
- if (module.parameters) {
|
|
|
- parameters.push(...module.parameters);
|
|
|
- delete module.parameters;
|
|
|
- }
|
|
|
- if (module.arguments) {
|
|
|
- parameters.push(...module.arguments);
|
|
|
- delete module.arguments;
|
|
|
- }
|
|
|
- for (const parameter of parameters) {
|
|
|
- const tensor = this._constants[parameter.tensorId];
|
|
|
- module[parameter.name] = tensor;
|
|
|
- if (!parameter.__class__) {
|
|
|
- parameter.__class__ = {
|
|
|
- __module__: 'torch',
|
|
|
- __name__: 'Tensor'
|
|
|
- };
|
|
|
- }
|
|
|
- }
|
|
|
- for (const attribute of attributes) {
|
|
|
- module[attribute.name] = this._attributes[attribute.id];
|
|
|
- }
|
|
|
- }
|
|
|
- delete this._model;
|
|
|
- }
|
|
|
- if (this.format.startsWith('TorchScript ') && (this._torchscriptArena || this._data.forward)) {
|
|
|
- this._type = 'script';
|
|
|
- return;
|
|
|
- }
|
|
|
- const root = pytorch.Utility.findModule(this._data);
|
|
|
- if (root) {
|
|
|
- this._type = 'module';
|
|
|
- this._data = root;
|
|
|
+ read() {
|
|
|
+ if (this.format.startsWith('TorchScript ') && (this._torchscriptArena || this._data.forward)) {
|
|
|
+ this._type = 'script';
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ const root = pytorch.Utility.findModule(this._data);
|
|
|
+ if (root) {
|
|
|
+ this._type = 'module';
|
|
|
+ this._data = root;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ const weights = pytorch.Utility.findWeights(this._data);
|
|
|
+ if (weights) {
|
|
|
+ this._type = 'weights';
|
|
|
+ this._data = weights;
|
|
|
}
|
|
|
else {
|
|
|
- const weights = pytorch.Utility.findWeights(this._data);
|
|
|
- if (weights) {
|
|
|
- this._type = 'weights';
|
|
|
- this._data = weights;
|
|
|
- }
|
|
|
- else {
|
|
|
- throw new pytorch.Error('File does not contain root module or state dictionary.');
|
|
|
- }
|
|
|
+ throw new pytorch.Error('File does not contain root module or state dictionary.');
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -2742,7 +2644,7 @@ pytorch.Container.Zip = class {
|
|
|
|
|
|
_storage(dirname) {
|
|
|
const map = new Map();
|
|
|
- const prefix = this._prefix + dirname + '/';
|
|
|
+ const prefix = dirname + '/';
|
|
|
for (const entry of this._entries) {
|
|
|
if (entry[0].startsWith(prefix)) {
|
|
|
const key = entry[0].substring(prefix.length);
|
|
|
@@ -2884,6 +2786,221 @@ pytorch.Container.Zip = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+pytorch.Container.Zip.Json = class extends pytorch.Container.Zip {
|
|
|
+
|
|
|
+ constructor(entries, model) {
|
|
|
+ super(entries);
|
|
|
+ this._producer = model && model.producerName ? model.producerName + (model.producerVersion ? ' v' + model.producerVersion : '') : '';
|
|
|
+ this._model = model;
|
|
|
+ }
|
|
|
+
|
|
|
+ get format() {
|
|
|
+ return this._entries.get('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
|
|
|
+ }
|
|
|
+
|
|
|
+ read() {
|
|
|
+ if (!this._data) {
|
|
|
+ this._data = this._model.mainModule || {};
|
|
|
+ this._name = this._data.name || '';
|
|
|
+ if (this._data.torchscriptArena) {
|
|
|
+ this._torchscriptArena = this._data.torchscriptArena.key;
|
|
|
+ }
|
|
|
+ const queue = [ this._data ];
|
|
|
+ const entries = new Map();
|
|
|
+ for (const entry of this._entries) {
|
|
|
+ const name = entry[0];
|
|
|
+ const stream = entry[1];
|
|
|
+ const buffer = stream.peek();
|
|
|
+ entries.set(name, buffer);
|
|
|
+ }
|
|
|
+ const tensorTypeMap = new Map([
|
|
|
+ [ 'FLOAT', 'Float' ],
|
|
|
+ [ 'FLOAT16', 'Half' ],
|
|
|
+ [ 'DOUBLE', 'Double' ],
|
|
|
+ [ 'INT8', 'Char' ],
|
|
|
+ [ 'INT32', 'Int' ],
|
|
|
+ [ 'INT64', 'Long' ]
|
|
|
+ ]);
|
|
|
+ const constants = this._model.tensors || [];
|
|
|
+ this._constants = constants.map((constant) => {
|
|
|
+ const key = constant.data.key;
|
|
|
+ if (!tensorTypeMap.has(constant.dataType)) {
|
|
|
+ throw new pytorch.Error("Unsupported tensor data type '" + constant.dataType + "'.");
|
|
|
+ }
|
|
|
+ const type = tensorTypeMap.get(constant.dataType);
|
|
|
+ const shape = constant.dims ? constant.dims.map((dim) => parseInt(dim, 10)) : null;
|
|
|
+ const storage_type = this.execution.type('torch.' + type + 'Storage');
|
|
|
+ const size = (shape || []).reduce((a, b) => a * b, 1);
|
|
|
+ const offset = parseInt(constant.offset, 10) || 0;
|
|
|
+ const storage = new storage_type([ size ]);
|
|
|
+ const itemsize = storage.dtype.itemsize();
|
|
|
+ const buffer = entries.get(key);
|
|
|
+ const length = size * itemsize;
|
|
|
+ const data = buffer.slice(offset, offset + length);
|
|
|
+ storage._set_cdata(data);
|
|
|
+ const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
|
|
|
+ tensor.name = constant.data.key;
|
|
|
+ return tensor;
|
|
|
+ });
|
|
|
+ this._attributes = [];
|
|
|
+ const stream = this._entries.get('attributes.pkl');
|
|
|
+ if (stream) {
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const unpickler = python.Unpickler.open(buffer);
|
|
|
+ this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
|
|
|
+ }
|
|
|
+ while (queue.length > 0) {
|
|
|
+ const module = queue.shift();
|
|
|
+ if (!module.__class__) {
|
|
|
+ module.__class__ = {
|
|
|
+ __module__: 'torch.nn.modules.module',
|
|
|
+ __name__: 'Module'
|
|
|
+ };
|
|
|
+ }
|
|
|
+ if (module.name) {
|
|
|
+ module.__id__ = module.name;
|
|
|
+ }
|
|
|
+ if (module.submodules) {
|
|
|
+ for (const submodule of module.submodules) {
|
|
|
+ module[submodule.name] = submodule;
|
|
|
+ submodule.__parent__ = module;
|
|
|
+ queue.push(submodule);
|
|
|
+ }
|
|
|
+ delete module.submodules;
|
|
|
+ }
|
|
|
+ const attributes = [];
|
|
|
+ if (module.attributes) {
|
|
|
+ attributes.push(...module.attributes);
|
|
|
+ delete module.attributes;
|
|
|
+ }
|
|
|
+ const parameters = [];
|
|
|
+ if (module.parameters) {
|
|
|
+ parameters.push(...module.parameters);
|
|
|
+ delete module.parameters;
|
|
|
+ }
|
|
|
+ if (module.arguments) {
|
|
|
+ parameters.push(...module.arguments);
|
|
|
+ delete module.arguments;
|
|
|
+ }
|
|
|
+ for (const parameter of parameters) {
|
|
|
+ const tensor = this._constants[parameter.tensorId];
|
|
|
+ module[parameter.name] = tensor;
|
|
|
+ if (!parameter.__class__) {
|
|
|
+ parameter.__class__ = {
|
|
|
+ __module__: 'torch',
|
|
|
+ __name__: 'Tensor'
|
|
|
+ };
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const attribute of attributes) {
|
|
|
+ module[attribute.name] = this._attributes[attribute.id];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ delete this._model;
|
|
|
+ super.read();
|
|
|
+ }
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip {
|
|
|
+
|
|
|
+ constructor(entries) {
|
|
|
+ super(entries);
|
|
|
+ }
|
|
|
+
|
|
|
+ get format() {
|
|
|
+ return (this._entries.get('constants.pkl') ? 'TorchScript' : 'PyTorch') + ' ' + this.version('version');
|
|
|
+ }
|
|
|
+
|
|
|
+ read() {
|
|
|
+ if (!this._data) {
|
|
|
+ const stream = this._entries.get('data.pkl');
|
|
|
+ const buffer = stream.peek();
|
|
|
+ this._data = this._unpickle(buffer, this._storage('data'));
|
|
|
+ super.read();
|
|
|
+ }
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+pytorch.Container.Zip.Package = class extends pytorch.Container.Zip {
|
|
|
+
|
|
|
+ constructor(entries) {
|
|
|
+ super(entries);
|
|
|
+ }
|
|
|
+
|
|
|
+ get format() {
|
|
|
+ return 'PyTorch Package' + ' ' + this.version('.data/version');
|
|
|
+ }
|
|
|
+
|
|
|
+ read() {
|
|
|
+ const entries = Array.from(this._entries).filter((entry) => !entry[0].startsWith('.data/') && !entry[0].endsWith('py'));
|
|
|
+ for (const entry of entries) {
|
|
|
+ /* const name = */ entry[0];
|
|
|
+ const stream = entry[1];
|
|
|
+ const loaded_reduces = new Map();
|
|
|
+ // const loaded_storages = new Map();
|
|
|
+ const persistent_load = (saved_id) => {
|
|
|
+ const typename = saved_id.shift();
|
|
|
+ switch (typename) {
|
|
|
+ case 'storage': {
|
|
|
+ /*
|
|
|
+ const storage_type = saved_id[0];
|
|
|
+ const key = saved_id[1];
|
|
|
+ const location = saved_id[2];
|
|
|
+ const size = saved_id[3];
|
|
|
+ dtype = storage_type.dtype
|
|
|
+ if key not in loaded_storages:
|
|
|
+ load_tensor(
|
|
|
+ dtype,
|
|
|
+ size,
|
|
|
+ key,
|
|
|
+ _maybe_decode_ascii(location),
|
|
|
+ restore_location,
|
|
|
+ )
|
|
|
+ storage = loaded_storages[key]
|
|
|
+ # TODO: Once we decide to break serialization FC, we can
|
|
|
+ # stop wrapping with _TypedStorage
|
|
|
+ return torch.storage._TypedStorage(
|
|
|
+ wrap_storage=storage._untyped(), dtype=dtype
|
|
|
+ )
|
|
|
+ */
|
|
|
+ throw new pytorch.Error('');
|
|
|
+ }
|
|
|
+ case 'reduce_package': {
|
|
|
+ if (saved_id.left === 2) {
|
|
|
+ const func = saved_id[0];
|
|
|
+ const args = saved_id[1];
|
|
|
+ return execution.invoke(func, args);
|
|
|
+ }
|
|
|
+ const reduce_id = saved_id[0];
|
|
|
+ const func = saved_id[1];
|
|
|
+ const args = saved_id[2];
|
|
|
+ if (!loaded_reduces.has(reduce_id)) {
|
|
|
+ const value = execution.invoke(func, args);
|
|
|
+ loaded_reduces.set(reduce_id, value);
|
|
|
+ }
|
|
|
+ return loaded_reduces.get(reduce_id);
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ throw new python.Error("Unknown package typename '" + typename + "'.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+ const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
|
|
|
+ execution.registerFunction('torch.jit._script.unpackage_script_module', function(script_module_id) {
|
|
|
+ /* const data = */ '.data/ts_code' + script_module_id + 'data.pkl';
|
|
|
+ // const constants = '.data/ts_code' + script_module_id + 'constants.pkl.pkl';
|
|
|
+ return { __TODO__: 'unpackage_script_module' };
|
|
|
+ });
|
|
|
+ const unpickler = python.Unpickler.open(stream);
|
|
|
+ /* const obj = */ unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
|
|
|
+ }
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+pytorch.Container.Zip.Script = class {
|
|
|
+};
|
|
|
+
|
|
|
pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
|
|
|
constructor(sources, exceptionCallback, metadata) {
|