Przeglądaj źródła

Update pytorch.js

Lutz Roeder 5 lat temu
rodzic
commit
6239a3e2d8
2 zmienionych plików z 34 dodań i 46 usunięć
  1. 13 12
      source/python.js
  2. 21 34
      source/pytorch.js

+ 13 - 12
source/python.js

@@ -3192,6 +3192,14 @@ python.Unpickler = class {
         return this._reader.stream(size);
     }
 
+    int32() {
+        return this._reader.int32();
+    }
+
+    int64() {
+        return this._reader.int64();
+    }
+
     unescape(token, size) {
         const length = token.length;
         const a = new Uint8Array(length);
@@ -3424,12 +3432,9 @@ python.Unpickler.BinaryReader = class {
     }
 
     int64() {
-        const low = this.uint32();
-        const high = this.uint32();
-        if (high !== 0) {
-            throw new python.Error('Unsupported 64-bit integer value.');
-        }
-        return low;
+        const position = this._position;
+        this.skip(8);
+        return this._dataView.getInt64(position, true).toNumber();
     }
 
     float32() {
@@ -3521,12 +3526,8 @@ python.Unpickler.StreamReader = class {
     }
 
     int64() {
-        const low = this.uint32();
-        const high = this.uint32();
-        if (high !== 0) {
-            throw new python.Error('Unsupported 64-bit integer value.');
-        }
-        return low;
+        const position = this._fill(8);
+        return this._dataView.getInt64(position, true).toNumber();
     }
 
     float32() {

+ 21 - 34
source/pytorch.js

@@ -1579,20 +1579,20 @@ pytorch.Execution = class extends python.Execution {
                 }
                 this._cdata = data;
             }
-            _set_from_file(file) {
-                const size = pytorch.Utility.readInt64(file.read(8));
+            _set_from_file(unpickler) {
+                const size = unpickler.int64();
                 if (size !== this.size()) {
                     throw new pytorch.Error('Storage size mismatch.');
                 }
                 const itemsize = this.dtype.itemsize();
-                const data = file.stream(itemsize * size);
+                const data = unpickler.stream(itemsize * size);
                 this._set_cdata(data);
             }
-            static _new_with_file(file) {
-                const size = pytorch.Utility.readInt64(file.read(8));
+            static _new_with_file(unpickler) {
+                const size = unpickler.int64();
                 const storage = new this(size);
                 const itemsize = storage.dtype.itemsize();
-                const data = file.stream(itemsize * size);
+                const data = unpickler.stream(itemsize * size);
                 storage._set_cdata(data);
                 return storage;
             }
@@ -1850,26 +1850,24 @@ pytorch.Container.Tar = class {
         if (entries.tensors) {
             const unpickler = new python.Unpickler(entries.tensors);
             const num_tensors = unpickler.load((name, args) => execution.invoke(name, args));
-            for (let j = 0; j < num_tensors; j++) {
-                const tensor_args = unpickler.load();
-                const tensor_key = tensor_args[0];
-                const storage_id = tensor_args[1];
+            for (let i = 0; i < num_tensors; i++) {
+                const args = unpickler.load();
+                const key = args[0];
+                const storage_id = args[1];
                 const storage = deserialized_objects[storage_id];
-                const ndim = pytorch.Utility.readInt32(unpickler.read(4));
+                const ndim = unpickler.int32();
                 unpickler.read(4);
-                const shape = [];
-                for (let k = 0; k < ndim; k++) {
-                    shape.push(pytorch.Utility.readInt64(unpickler.read(8)));
+                const shape = new Array(ndim);
+                for (let j = 0; j < ndim; j++) {
+                    shape[j] = unpickler.int64();
                 }
-                const stride = [];
-                for (let l = 0; l < ndim; l++) {
-                    stride.push(pytorch.Utility.readInt64(unpickler.read(8)));
+                const stride = new Array(ndim);
+                for (let j = 0; j < ndim; j++) {
+                    stride[j] = unpickler.int64();
                 }
-                const storage_offset = pytorch.Utility.readInt64(unpickler.read(8));
-                const tensor_type_name = storage.__class__.__name__.replace('Storage', 'Tensor');
-                const tensor = execution.invoke(storage.__class__.__module__ + '.' + tensor_type_name, []);
-                tensor.__setstate__([ storage, storage_offset, shape, stride ]);
-                deserialized_objects[tensor_key] = tensor;
+                const storage_offset = unpickler.int64();
+                const tensor = execution.invoke('torch._utils._rebuild_tensor', [ storage, storage_offset, shape, stride ]);
+                deserialized_objects[key] = tensor;
             }
         }
 
 
@@ -2162,10 +2160,8 @@ pytorch.Container.Zip = class {
                         const length = size * itemsize;
                         const data = buffer.slice(offset, offset + length);
                         storage._set_cdata(data);
-                        const tensor_type = this.execution.type('torch.' + type + 'Tensor');
-                        const tensor = new tensor_type();
+                        const tensor = this.execution.invoke('torch._utils._rebuild_tensor', [ storage, 0, shape, 0 ]);
                         tensor.name = constant.data.key;
-                        tensor.__setstate__([ storage, 0, shape, 0 ]);
                         return tensor;
                     });
                     this._attributes = [];
@@ -3183,15 +3179,6 @@ pytorch.Utility = class {
         return state_dict;
     }
 
-    static readInt32(buffer) {
-        const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-        return view.getInt32(0, true);
-    }
-
-    static readInt64(buffer) {
-        const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-        return view.getInt64(0, true).toNumber();
-    }
 };
 
 pytorch.nnapi = {};