Przeglądaj źródła

Add ONNX test file (#6)

Lutz Roeder 3 lat temu
rodzic
commit
f9beec5d03
3 zmienionych plików z 32 dodań i 0 usunięć
  1. 16 0
      source/base.js
  2. 8 0
      source/onnx.js
  3. 8 0
      test/models.json

+ 16 - 0
source/base.js

@@ -501,6 +501,22 @@ if (!DataView.prototype.setFloat16) {
     }
 }
 
+if (!DataView.prototype.getBfloat16) {
+    DataView.prototype.getBfloat16 = function(byteOffset, littleEndian) {
+        if (littleEndian) {
+            DataView.__bfloat16_uint16[0] = 0;
+            DataView.__bfloat16_uint16[1] = this.getUint16(byteOffset, littleEndian);
+        }
+        else {
+            DataView.__bfloat16_uint16[0] = this.getUint16(byteOffset, littleEndian);
+            DataView.__bfloat16_uint16[1] = 0;
+        }
+        return DataView.__bfloat16_float32[0];
+    };
+    DataView.__bfloat16_float32 = new Float32Array(1);
+    DataView.__bfloat16_uint16 = new Uint16Array(DataView.__bfloat16_float32.buffer, DataView.__bfloat16_float32.byteOffset, 2);
+}
+
 DataView.prototype.getInt64 = DataView.prototype.getInt64 || function(byteOffset, littleEndian) {
     return littleEndian ?
         new base.Int64(this.getUint32(byteOffset, true), this.getUint32(byteOffset + 4, true)) :

+ 8 - 0
source/onnx.js

@@ -847,6 +847,8 @@ onnx.Tensor = class {
                     case onnx.DataType.STRING:
                         data = tensor.string_data;
                         break;
+                    case onnx.DataType.BFLOAT16:
+                        break;
                     default:
                         throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'.");
                 }
@@ -1005,6 +1007,12 @@ onnx.Tensor = class {
                         data[i] = view.getUint64(i << 3, true);
                     }
                     break;
+                case onnx.DataType.BFLOAT16:
+                    data = new Array(buffer.length >> 1);
+                    for (let i = 0; i < data.length; i++) {
+                        data[i] = view.getBfloat16(i << 1, true);
+                    }
+                    break;
                 default:
                     throw new onnx.Error("Unsupported tensor data type '" + type + "'.");
             }

+ 8 - 0
test/models.json

@@ -3181,6 +3181,14 @@
     "producer": "tf2onnx 1.5.2",
     "link":     "https://github.com/onnx/models/blob/master/text/machine_comprehension/bert-squad/README.md"
   },
+  {
+    "type":     "onnx",
+    "target":   "mnist_bfloat16.onnx",
+    "source":   "https://github.com/lutzroeder/netron/files/8556259/mnist_bfloat16.onnx.zip[mnist_bfloat16.onnx]",
+    "format":   "ONNX v4",
+    "producer": "pytorch 1.12.0",
+    "link":     "https://github.com/lutzroeder/netron/issues/6"
+  },
   {
     "type":     "onnx",
     "target":   "bvlc_alexnet_opset_3.onnx.zip",