|
|
@@ -678,9 +678,11 @@ pytorch.Tensor = class {
|
|
|
case 'float32':
|
|
|
case 'float64':
|
|
|
case 'bfloat16':
|
|
|
+ case 'complex64':
|
|
|
+ case 'complex128':
|
|
|
break;
|
|
|
default:
|
|
|
- context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
|
|
|
+ context.state = "Tensor data type '" + this._type.dataType + "' is not implemented.";
|
|
|
return context;
|
|
|
}
|
|
|
if (!this._type.shape) {
|
|
|
@@ -702,7 +704,7 @@ pytorch.Tensor = class {
|
|
|
|
|
|
context.dataType = this._type.dataType;
|
|
|
context.dimensions = this._type.shape.dimensions;
|
|
|
- context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
|
|
|
+ context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
|
|
|
return context;
|
|
|
}
|
|
|
|
|
|
@@ -718,56 +720,66 @@ pytorch.Tensor = class {
|
|
|
}
|
|
|
switch (context.dataType) {
|
|
|
case 'boolean':
|
|
|
- results.push(context.dataView.getUint8(context.index) === 0 ? false : true);
|
|
|
+ results.push(context.view.getUint8(context.index) === 0 ? false : true);
|
|
|
context.index++;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'uint8':
|
|
|
- results.push(context.dataView.getUint8(context.index));
|
|
|
+ results.push(context.view.getUint8(context.index));
|
|
|
context.index++;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'qint8':
|
|
|
case 'int8':
|
|
|
- results.push(context.dataView.getInt8(context.index));
|
|
|
+ results.push(context.view.getInt8(context.index));
|
|
|
context.index++;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'int16':
|
|
|
- results.push(context.dataView.getInt16(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getInt16(context.index, this._littleEndian));
|
|
|
context.index += 2;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'int32':
|
|
|
- results.push(context.dataView.getInt32(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getInt32(context.index, this._littleEndian));
|
|
|
context.index += 4;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'int64':
|
|
|
- results.push(context.dataView.getInt64(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getInt64(context.index, this._littleEndian));
|
|
|
context.index += 8;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'float16':
|
|
|
- results.push(context.dataView.getFloat16(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getFloat16(context.index, this._littleEndian));
|
|
|
context.index += 2;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'float32':
|
|
|
- results.push(context.dataView.getFloat32(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getFloat32(context.index, this._littleEndian));
|
|
|
context.index += 4;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'float64':
|
|
|
- results.push(context.dataView.getFloat64(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getFloat64(context.index, this._littleEndian));
|
|
|
context.index += 8;
|
|
|
context.count++;
|
|
|
break;
|
|
|
case 'bfloat16':
|
|
|
- results.push(context.dataView.getBfloat16(context.index, this._littleEndian));
|
|
|
+ results.push(context.view.getBfloat16(context.index, this._littleEndian));
|
|
|
context.index += 2;
|
|
|
context.count++;
|
|
|
break;
|
|
|
+ case 'complex64':
|
|
|
+ results.push(context.view.getComplex64(i << 3, this._littleEndian));
|
|
|
+ context.index += 8;
|
|
|
+ context.count++;
|
|
|
+ break;
|
|
|
+ case 'complex128':
|
|
|
+ results.push(context.view.getComplex128(i << 4, this._littleEndian));
|
|
|
+ context.index += 16;
|
|
|
+ context.count++;
|
|
|
+ break;
|
|
|
default:
|
|
|
throw new pytorch.Error("Unsupported tensor data type '" + context.dataType + "'.");
|
|
|
}
|
|
|
@@ -799,22 +811,26 @@ pytorch.Tensor = class {
|
|
|
result.push(indentation + ']');
|
|
|
return result.join('\n');
|
|
|
}
|
|
|
- if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
|
|
|
- return indentation + value.toString();
|
|
|
- }
|
|
|
- if (typeof value == 'string') {
|
|
|
- return indentation + value;
|
|
|
- }
|
|
|
- if (value == Infinity) {
|
|
|
- return indentation + 'Infinity';
|
|
|
- }
|
|
|
- if (value == -Infinity) {
|
|
|
- return indentation + '-Infinity';
|
|
|
- }
|
|
|
- if (isNaN(value)) {
|
|
|
- return indentation + 'NaN';
|
|
|
+ switch (typeof value) {
|
|
|
+ case 'string':
|
|
|
+ return indentation + value;
|
|
|
+ case 'number':
|
|
|
+ if (value == Infinity) {
|
|
|
+ return indentation + 'Infinity';
|
|
|
+ }
|
|
|
+ if (value == -Infinity) {
|
|
|
+ return indentation + '-Infinity';
|
|
|
+ }
|
|
|
+ if (isNaN(value)) {
|
|
|
+ return indentation + 'NaN';
|
|
|
+ }
|
|
|
+ return indentation + value.toString();
|
|
|
+ default:
|
|
|
+ if (value && value.toString) {
|
|
|
+ return indentation + value.toString();
|
|
|
+ }
|
|
|
+ return indentation + '(undefined)';
|
|
|
}
|
|
|
- return indentation + value.toString();
|
|
|
}
|
|
|
};
|
|
|
|
|
|
@@ -1833,20 +1849,20 @@ pytorch.Execution = class extends python.Execution {
|
|
|
this._device = null;
|
|
|
}
|
|
|
get device() {
|
|
|
- return null;
|
|
|
+ return this._device;
|
|
|
}
|
|
|
get dtype() {
|
|
|
return this._dtype;
|
|
|
}
|
|
|
- get data() {
|
|
|
- return this._cdata;
|
|
|
- }
|
|
|
element_size() {
|
|
|
return this._dtype.element_size;
|
|
|
}
|
|
|
size() {
|
|
|
return this._size;
|
|
|
}
|
|
|
+ get data() {
|
|
|
+ return this._cdata;
|
|
|
+ }
|
|
|
_set_cdata(data) {
|
|
|
const length = this.size() * this.dtype.itemsize();
|
|
|
if (length !== data.length) {
|
|
|
@@ -1876,6 +1892,33 @@ pytorch.Execution = class extends python.Execution {
|
|
|
return storage;
|
|
|
}
|
|
|
});
|
|
|
+ this.registerType('torch.storage._UntypedStorage', class extends torch_storage._StorageBase {
|
|
|
+ constructor() {
|
|
|
+ super();
|
|
|
+ throw new python.Error('_UntypedStorage not implemented.');
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.storage._TypedStorage', class {
|
|
|
+ constructor() {
|
|
|
+ throw new python.Error('_TypedStorage not implemented.');
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.storage._LegacyStorage', class extends torch_storage._TypedStorage {
|
|
|
+ constructor() {
|
|
|
+ super();
|
|
|
+ throw new python.Error('_LegacyStorage not implemented.');
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.ComplexFloatStorage', class extends torch_storage._StorageBase {
|
|
|
+ constructor(size) {
|
|
|
+ super(size, torch.complex64);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ this.registerType('torch.ComplexDoubleStorage', class extends torch_storage._StorageBase {
|
|
|
+ constructor(size) {
|
|
|
+ super(size, torch.complex128);
|
|
|
+ }
|
|
|
+ });
|
|
|
this.registerType('torch.BoolStorage', class extends torch_storage._StorageBase {
|
|
|
constructor(size) {
|
|
|
super(size, torch.bool);
|
|
|
@@ -2058,6 +2101,8 @@ pytorch.Execution = class extends python.Execution {
|
|
|
this.registerType('torch.HalfTensor', class extends torch.Tensor {});
|
|
|
this.registerType('torch.FloatTensor', class extends torch.Tensor {});
|
|
|
this.registerType('torch.DoubleTensor', class extends torch.Tensor {});
|
|
|
+ this.registerType('torch.ComplexFloatTensor', class extends torch.Tensor {});
|
|
|
+ this.registerType('torch.ComplexDoubleTensor', class extends torch.Tensor {});
|
|
|
this.registerType('torch.QInt8Tensor', class extends torch.Tensor {});
|
|
|
this.registerType('torch.QUInt8Tensor', class extends torch.Tensor {});
|
|
|
this.registerType('torch.QInt32Tensor', class extends torch.Tensor {});
|