|
|
@@ -244,12 +244,7 @@ torch.Node = class {
|
|
|
for (const key of Object.keys(module)) {
|
|
|
const obj = module[key];
|
|
|
if (obj && obj.__type__ && obj.__type__.startsWith('torch.') && obj.__type__.endsWith('Storage')) {
|
|
|
- const array = [];
|
|
|
- obj.reset();
|
|
|
- for (let i = 0; i < obj.size; i++) {
|
|
|
- array.push(obj.read());
|
|
|
- }
|
|
|
- module[key] = array;
|
|
|
+ module[key] = obj.data();
|
|
|
}
|
|
|
}
|
|
|
delete module.iSize;
|
|
|
@@ -476,6 +471,7 @@ torch.Tensor = class {
|
|
|
constructor(tensor) {
|
|
|
this._type = new torch.TensorType(tensor);
|
|
|
this._storage = tensor.storage;
|
|
|
+ this._offset = tensor.storage_offset;
|
|
|
}
|
|
|
|
|
|
get type() {
|
|
|
@@ -510,7 +506,13 @@ torch.Tensor = class {
|
|
|
context.state = null;
|
|
|
context.index = 0;
|
|
|
context.count = 0;
|
|
|
- if (!this._storage || !this._storage.reader) {
|
|
|
+ if (!this._storage) {
|
|
|
+ context.state = 'Tensor data is empty.';
|
|
|
+ return context;
|
|
|
+ }
|
|
|
+ context.data = this._storage.data();
|
|
|
+ context.index = this._offset;
|
|
|
+ if (!context.data) {
|
|
|
context.state = 'Tensor data is empty.';
|
|
|
return context;
|
|
|
}
|
|
|
@@ -532,8 +534,6 @@ torch.Tensor = class {
|
|
|
context.state = 'Tensor has no dimensions.';
|
|
|
return context;
|
|
|
}
|
|
|
- context.storage = this._storage;
|
|
|
- context.storage.reset();
|
|
|
return context;
|
|
|
}
|
|
|
|
|
|
@@ -546,7 +546,7 @@ torch.Tensor = class {
|
|
|
results.push('...');
|
|
|
return results;
|
|
|
}
|
|
|
- results.push(context.storage.read());
|
|
|
+ results.push(context.data[context.index]);
|
|
|
context.index++;
|
|
|
context.count++;
|
|
|
}
|
|
|
@@ -994,27 +994,41 @@ torch.T7Reader = class {
|
|
|
obj.itemSize = itemSize;
|
|
|
obj.size = this.int64();
|
|
|
obj.reader = this._reader.storage(obj.size, obj.itemSize, dataType);
|
|
|
- obj.reset = function() {
|
|
|
- this.reader.reset();
|
|
|
- };
|
|
|
- obj.read = function() {
|
|
|
- switch (dataType) {
|
|
|
- case 'uint8':
|
|
|
- return this.reader.byte();
|
|
|
- case 'int8':
|
|
|
- return this.reader.int8();
|
|
|
- case 'int16':
|
|
|
- return this.reader.int16();
|
|
|
- case 'int32':
|
|
|
- return this.reader.int32();
|
|
|
- case 'int64':
|
|
|
- return this.reader.int64();
|
|
|
- case 'float32':
|
|
|
- return this.reader.float32();
|
|
|
- case 'float64':
|
|
|
- return this.reader.float64();
|
|
|
+ obj.data = function() {
|
|
|
+ if (this.reader) {
|
|
|
+ const reader = this.reader;
|
|
|
+ reader.reset();
|
|
|
+ const size = obj.size;
|
|
|
+ const array = new Array(size);
|
|
|
+ for (let i = 0; i < size; i++) {
|
|
|
+ switch (dataType) {
|
|
|
+ case 'uint8':
|
|
|
+ array[i] = this.reader.byte();
|
|
|
+ break;
|
|
|
+ case 'int8':
|
|
|
+ array[i] = this.reader.int8();
|
|
|
+ break;
|
|
|
+ case 'int16':
|
|
|
+ array[i] = this.reader.int16();
|
|
|
+ break;
|
|
|
+ case 'int32':
|
|
|
+ array[i] = this.reader.int32();
|
|
|
+ break;
|
|
|
+ case 'int64':
|
|
|
+ array[i] = this.reader.int64();
|
|
|
+ break;
|
|
|
+ case 'float32':
|
|
|
+ array[i] = this.reader.float32();
|
|
|
+ break;
|
|
|
+ case 'float64':
|
|
|
+ array[i] = this.reader.float64();
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ obj._data = array;
|
|
|
+ delete obj.reader;
|
|
|
}
|
|
|
- return null;
|
|
|
+ return obj._data;
|
|
|
};
|
|
|
}
|
|
|
};
|