|
@@ -1221,34 +1221,52 @@ tf.Tensor = class {
|
|
|
else {
|
|
else {
|
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
|
switch (tensor.dtype) {
|
|
switch (tensor.dtype) {
|
|
|
- case DataType.DT_FLOAT:
|
|
|
|
|
|
|
+ case DataType.DT_HALF: {
|
|
|
|
|
+ const values = tensor.half_val || [];
|
|
|
|
|
+ this._buffer = new Uint8Array(values.length << 1);
|
|
|
|
|
+ const view = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
|
|
|
|
|
+ for (let i = 0; i < values.length; i++) {
|
|
|
|
|
+ view.setUint16(i << 1, values[i], true);
|
|
|
|
|
+ }
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_FLOAT: {
|
|
|
this._data = tensor.float_val || null;
|
|
this._data = tensor.float_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_DOUBLE:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_DOUBLE: {
|
|
|
this._data = tensor.double_val || null;
|
|
this._data = tensor.double_val || null;
|
|
|
break;
|
|
break;
|
|
|
|
|
+ }
|
|
|
case DataType.DT_INT8:
|
|
case DataType.DT_INT8:
|
|
|
case DataType.DT_UINT8:
|
|
case DataType.DT_UINT8:
|
|
|
- case DataType.DT_INT32:
|
|
|
|
|
|
|
+ case DataType.DT_INT32: {
|
|
|
this._data = tensor.int_val || null;
|
|
this._data = tensor.int_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_UINT32:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_UINT32: {
|
|
|
this._data = tensor.uint32_val || null;
|
|
this._data = tensor.uint32_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_INT64:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_INT64: {
|
|
|
this._data = tensor.int64_val || null;
|
|
this._data = tensor.int64_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_UINT64:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_UINT64: {
|
|
|
this._data = tensor.uint64_val || null;
|
|
this._data = tensor.uint64_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_BOOL:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_BOOL: {
|
|
|
this._data = tensor.bool_val || null;
|
|
this._data = tensor.bool_val || null;
|
|
|
break;
|
|
break;
|
|
|
- case DataType.DT_STRING:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ case DataType.DT_STRING: {
|
|
|
this._data = tensor.string_val || null;
|
|
this._data = tensor.string_val || null;
|
|
|
break;
|
|
break;
|
|
|
- default:
|
|
|
|
|
|
|
+ }
|
|
|
|
|
+ default: {
|
|
|
throw new tf.Error("Unsupported tensor data type '" + tensor.dtype + "'.");
|
|
throw new tf.Error("Unsupported tensor data type '" + tensor.dtype + "'.");
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -1326,6 +1344,7 @@ tf.Tensor = class {
|
|
|
if (this._buffer) {
|
|
if (this._buffer) {
|
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
|
switch (this._tensor.dtype) {
|
|
switch (this._tensor.dtype) {
|
|
|
|
|
+ case DataType.DT_HALF:
|
|
|
case DataType.DT_FLOAT:
|
|
case DataType.DT_FLOAT:
|
|
|
case DataType.DT_DOUBLE:
|
|
case DataType.DT_DOUBLE:
|
|
|
case DataType.DT_QINT8:
|
|
case DataType.DT_QINT8:
|
|
@@ -1395,6 +1414,11 @@ tf.Tensor = class {
|
|
|
else {
|
|
else {
|
|
|
if (context.rawData) {
|
|
if (context.rawData) {
|
|
|
switch (this._tensor.dtype) {
|
|
switch (this._tensor.dtype) {
|
|
|
|
|
+ case tf.proto.tensorflow.DataType.DT_HALF:
|
|
|
|
|
+ results.push(context.rawData.getFloat16(context.index, true));
|
|
|
|
|
+ context.index += 2;
|
|
|
|
|
+ context.count++;
|
|
|
|
|
+ break;
|
|
|
case tf.proto.tensorflow.DataType.DT_FLOAT:
|
|
case tf.proto.tensorflow.DataType.DT_FLOAT:
|
|
|
results.push(context.rawData.getFloat32(context.index, true));
|
|
results.push(context.rawData.getFloat32(context.index, true));
|
|
|
context.index += 4;
|
|
context.index += 4;
|