|
|
@@ -1221,6 +1221,15 @@ tf.Tensor = class {
|
|
|
else {
|
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
|
switch (tensor.dtype) {
|
|
|
+ case DataType.DT_BFLOAT16: {
|
|
|
+ const values = tensor.half_val || [];
|
|
|
+ this._buffer = new Uint8Array(values.length << 2);
|
|
|
+ const view = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
|
|
|
+ for (let i = 0; i < values.length; i++) {
|
|
|
+ view.setUint32(i << 2, values[i] << 16, true);
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
case DataType.DT_HALF: {
|
|
|
const values = tensor.half_val || [];
|
|
|
this._buffer = new Uint8Array(values.length << 1);
|
|
|
@@ -1344,6 +1353,7 @@ tf.Tensor = class {
|
|
|
if (this._buffer) {
|
|
|
const DataType = tf.proto.tensorflow.DataType;
|
|
|
switch (this._tensor.dtype) {
|
|
|
+ case DataType.DT_BFLOAT16:
|
|
|
case DataType.DT_HALF:
|
|
|
case DataType.DT_FLOAT:
|
|
|
case DataType.DT_DOUBLE:
|
|
|
@@ -1419,6 +1429,7 @@ tf.Tensor = class {
|
|
|
context.index += 2;
|
|
|
context.count++;
|
|
|
break;
|
|
|
+ case tf.proto.tensorflow.DataType.DT_BFLOAT16:
|
|
|
case tf.proto.tensorflow.DataType.DT_FLOAT:
|
|
|
results.push(context.rawData.getFloat32(context.index, true));
|
|
|
context.index += 4;
|