Jelajahi Sumber

Torch storage offset support (#200) (#691)

Lutz Roeder 5 tahun lalu
induk
melakukan
d92135d7eb
1 mengubah file dengan 44 tambahan dan 30 penghapusan
  1. 44 30
      source/torch.js

+ 44 - 30
source/torch.js

@@ -244,12 +244,7 @@ torch.Node = class {
         for (const key of Object.keys(module)) {
             const obj = module[key];
             if (obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Storage')) {
-                const array = [];
-                obj.reset();
-                for (let i = 0; i < obj.size; i++) {
-                    array.push(obj.read());
-                }
-                module[key] = array;
+                module[key] = obj.data();
             }
         }
         delete module.iSize;
@@ -476,6 +471,7 @@ torch.Tensor = class {
     constructor(tensor) {
         this._type = new torch.TensorType(tensor);
         this._storage = tensor.storage;
+        this._offset = tensor.storage_offset;
     }
 
     get type() {
@@ -510,7 +506,13 @@ torch.Tensor = class {
         context.state = null;
         context.index = 0;
         context.count = 0;
-        if (!this._storage || !this._storage.reader) {
+        if (!this._storage) {
+            context.state = 'Tensor data is empty.';
+            return context;
+        }
+        context.data = this._storage.data();
+        context.index = this._offset;
+        if (!context.data) {
             context.state = 'Tensor data is empty.';
             return context;
         }
@@ -532,8 +534,6 @@ torch.Tensor = class {
             context.state =  'Tensor has no dimensions.';
             return context;
         }
-        context.storage = this._storage;
-        context.storage.reset();
         return context;
     }
 
@@ -546,7 +546,7 @@ torch.Tensor = class {
                     results.push('...');
                     return results;
                 }
-                results.push(context.storage.read());
+                results.push(context.data[context.index]);
                 context.index++;
                 context.count++;
             }
@@ -994,27 +994,41 @@ torch.T7Reader = class {
         obj.itemSize = itemSize;
         obj.size = this.int64();
         obj.reader = this._reader.storage(obj.size, obj.itemSize, dataType);
-        obj.reset = function() {
-            this.reader.reset();
-        };
-        obj.read = function() {
-            switch (dataType) {
-                case 'uint8':
-                    return this.reader.byte();
-                case 'int8':
-                    return this.reader.int8();
-                case 'int16':
-                    return this.reader.int16();
-                case 'int32':
-                    return this.reader.int32();
-                case 'int64':
-                    return this.reader.int64();
-                case 'float32':
-                    return this.reader.float32();
-                case 'float64':
-                    return this.reader.float64();
+        obj.data = function() {
+            if (this.reader) {
+                const reader = this.reader;
+                reader.reset();
+                const size = obj.size;
+                const array = new Array(size);
+                for (let i = 0; i < size; i++) {
+                    switch (dataType) {
+                        case 'uint8':
+                            array[i] = this.reader.byte();
+                            break;
+                        case 'int8':
+                            array[i] = this.reader.int8();
+                            break;
+                        case 'int16':
+                            array[i] = this.reader.int16();
+                            break;
+                        case 'int32':
+                            array[i] = this.reader.int32();
+                            break;
+                        case 'int64':
+                            array[i] = this.reader.int64();
+                            break;
+                        case 'float32':
+                            array[i] = this.reader.float32();
+                            break;
+                        case 'float64':
+                            array[i] = this.reader.float64();
+                            break;
+                    }
+                }
+                obj._data = array;
+                delete obj.reader;
             }
-            return null;
+            return obj._data;
         };
     }
 };