Procházet zdrojové kódy

Add TensorFlow test file (#187)

Lutz Roeder před 3 roky
rodič
revize
863262309c
2 změnil soubory, kde provedl 40 přidání a 9 odebrání
  1. 33 9
      source/tf.js
  2. 7 0
      test/models.json

+ 33 - 9
source/tf.js

@@ -1221,34 +1221,52 @@ tf.Tensor = class {
             else {
                 const DataType = tf.proto.tensorflow.DataType;
                 switch (tensor.dtype) {
-                    case DataType.DT_FLOAT:
+                    case DataType.DT_HALF: {
+                        const values = tensor.half_val || [];
+                        this._buffer = new Uint8Array(values.length << 1);
+                        const view = new DataView(this._buffer.buffer, this._buffer.byteOffset, this._buffer.byteLength);
+                        for (let i = 0; i < values.length; i++) {
+                            view.setUint16(i << 1, values[i], true);
+                        }
+                        break;
+                    }
+                    case DataType.DT_FLOAT: {
                         this._data = tensor.float_val || null;
                         break;
-                    case DataType.DT_DOUBLE:
+                    }
+                    case DataType.DT_DOUBLE: {
                         this._data = tensor.double_val || null;
                         break;
+                    }
                     case DataType.DT_INT8:
                     case DataType.DT_UINT8:
-                    case DataType.DT_INT32:
+                    case DataType.DT_INT32: {
                         this._data = tensor.int_val || null;
                         break;
-                    case DataType.DT_UINT32:
+                    }
+                    case DataType.DT_UINT32: {
                         this._data = tensor.uint32_val || null;
                         break;
-                    case DataType.DT_INT64:
+                    }
+                    case DataType.DT_INT64: {
                         this._data = tensor.int64_val || null;
                         break;
-                    case DataType.DT_UINT64:
+                    }
+                    case DataType.DT_UINT64: {
                         this._data = tensor.uint64_val || null;
                         break;
-                    case DataType.DT_BOOL:
+                    }
+                    case DataType.DT_BOOL: {
                         this._data = tensor.bool_val || null;
                         break;
-                    case DataType.DT_STRING:
+                    }
+                    case DataType.DT_STRING: {
                         this._data = tensor.string_val || null;
                         break;
-                    default:
+                    }
+                    default: {
                         throw new tf.Error("Unsupported tensor data type '" + tensor.dtype + "'.");
+                    }
                 }
             }
         }
@@ -1326,6 +1344,7 @@ tf.Tensor = class {
         if (this._buffer) {
             const DataType = tf.proto.tensorflow.DataType;
             switch (this._tensor.dtype) {
+                case DataType.DT_HALF:
                 case DataType.DT_FLOAT:
                 case DataType.DT_DOUBLE:
                 case DataType.DT_QINT8:
@@ -1395,6 +1414,11 @@ tf.Tensor = class {
                 else {
                     if (context.rawData) {
                         switch (this._tensor.dtype) {
+                            case tf.proto.tensorflow.DataType.DT_HALF:
+                                results.push(context.rawData.getFloat16(context.index, true));
+                                context.index += 2;
+                                context.count++;
+                                break;
                             case tf.proto.tensorflow.DataType.DT_FLOAT:
                                 results.push(context.rawData.getFloat32(context.index, true));
                                 context.index += 4;

+ 7 - 0
test/models.json

@@ -5626,6 +5626,13 @@
     "action":   "skip-render",
     "format":   "TensorFlow MetaGraph"
   },
+  {
+    "type":     "tf",
+    "target":   "float16.txt",
+    "source":   "https://github.com/lutzroeder/netron/files/8547644/float16.txt.zip[float16.txt]",
+    "format":   "TensorFlow Graph",
+    "link":     "https://github.com/lutzroeder/netron/issues/187"
+  },
   {
     "type":     "tf",
     "target":   "graph_missing_function.pbtxt",