Lutz Roeder 3 rokov pred
rodič
commit
c1b431073e
1 zmenil súbory, kde vykonal 11 pridanie a 0 odobranie
  1. 11 0
      source/tf.js

+ 11 - 0
source/tf.js

@@ -1221,6 +1221,15 @@ tf.Tensor = class {
             else {
                 const DataType = tf.proto.tensorflow.DataType;
                 switch (tensor.dtype) {
+                    case DataType.DT_BFLOAT16: {
+                        const values = tensor.half_val || [];
+                        this._buffer = new Uint8Array(values.length << 2);
+                        const view = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
+                        for (let i = 0; i < values.length; i++) {
+                            view.setUint32(i << 2, values[i] << 16, true);
+                        }
+                        break;
+                    }
                     case DataType.DT_HALF: {
                         const values = tensor.half_val || [];
                         this._buffer = new Uint8Array(values.length << 1);
@@ -1344,6 +1353,7 @@ tf.Tensor = class {
         if (this._buffer) {
             const DataType = tf.proto.tensorflow.DataType;
             switch (this._tensor.dtype) {
+                case DataType.DT_BFLOAT16:
                 case DataType.DT_HALF:
                 case DataType.DT_FLOAT:
                 case DataType.DT_DOUBLE:
@@ -1419,6 +1429,7 @@ tf.Tensor = class {
                                 context.index += 2;
                                 context.count++;
                                 break;
+                            case tf.proto.tensorflow.DataType.DT_BFLOAT16:
                             case tf.proto.tensorflow.DataType.DT_FLOAT:
                                 results.push(context.rawData.getFloat32(context.index, true));
                                 context.index += 4;