|
|
@@ -2042,14 +2042,13 @@ pytorch.Container = class {
|
|
|
if (zip) {
|
|
|
return zip;
|
|
|
}
|
|
|
- const stream = context.stream;
|
|
|
- const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
|
|
|
- if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
|
|
|
- return new pytorch.Container.Pickle(stream);
|
|
|
+ const pickle = pytorch.Container.Pickle.open(context.stream);
|
|
|
+ if (pickle) {
|
|
|
+ return pickle;
|
|
|
}
|
|
|
- const entries = context.entries('tar');
|
|
|
- if (entries.has('pickle')) {
|
|
|
- return new pytorch.Container.Tar(entries);
|
|
|
+ const tar = pytorch.Container.Tar.open(context.entries('tar'));
|
|
|
+ if (tar) {
|
|
|
+ return tar;
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
@@ -2057,6 +2056,13 @@ pytorch.Container = class {
|
|
|
|
|
|
pytorch.Container.Tar = class {
|
|
|
|
|
|
+ static open(entries) {
|
|
|
+ if (entries.has('pickle')) {
|
|
|
+ return new pytorch.Container.Tar(entries);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
constructor(entries) {
|
|
|
this._entries = entries;
|
|
|
}
|
|
|
@@ -2183,6 +2189,14 @@ pytorch.Container.Tar = class {
|
|
|
|
|
|
pytorch.Container.Pickle = class {
|
|
|
|
|
|
+ static open(stream) {
|
|
|
+ const signature = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
|
|
|
+ if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
|
|
|
+ return new pytorch.Container.Pickle(stream);
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+
|
|
|
constructor(stream) {
|
|
|
this._stream = stream;
|
|
|
}
|