瀏覽代碼

Add HDF5 Weights test file (#467)

Lutz Roeder 3 年之前
父節點
當前提交
564ed10ce8
共有 2 個文件被更改,包括 101 次插入和 63 次刪除
  1. source/hdf5.js（+94 −63）
  2. test/models.json（+7 −0）

+ 94 - 63
source/hdf5.js

@@ -212,28 +212,68 @@ hdf5.Variable = class {
                 break;
             case 2: { // Chunked
                 const dimensionality = this._dataLayout.dimensionality;
-                if (dimensionality === 2) {
-                    const tree = new hdf5.Tree(this._reader.at(this._dataLayout.address), dimensionality);
-                    const itemsize = this._dataLayout.datasetElementSize;
-                    const shape = this._dataspace.shape;
-                    const size = shape.reduce((a, b) => a * b, 1) * itemsize;
-                    const data = new Uint8Array(size);
-                    for (const node of tree.nodes) {
-                        if (node.filterMask !== 0) {
-                            return null;
+                const tree = new hdf5.Tree(this._reader.at(this._dataLayout.address), dimensionality);
+                const item_size = this._dataLayout.datasetElementSize;
+                const chunk_shape = this._dataLayout.dimensionSizes;
+                const data_shape = this._dataspace.shape;
+                const chunk_size = chunk_shape.reduce((a, b) => a * b, 1);
+                const data_size = data_shape.reduce((a, b) => a * b, 1);
+                const max_dim = data_shape.length - 1;
+                let data_stride = 1;
+                const data_strides = data_shape.slice().reverse().map((d2) => {
+                    const s = data_stride;
+                    data_stride *= d2;
+                    return s;
+                }).reverse();
+                const data = new Uint8Array(data_size * item_size);
+                for (const node of tree.nodes) {
+                    if (node.filterMask !== 0) {
+                        return null;
+                    }
+                    let chunk = node.data;
+                    if (this._filterPipeline) {
+                        for (const filter of this._filterPipeline.filters) {
+                            chunk = filter.decode(chunk);
+                        }
+                    }
+                    const chunk_offset = node.fields;
+                    var data_pos = chunk_offset.slice();
+                    var chunk_pos = data_pos.map(() => 0);
+                    for (let chunk_index = 0; chunk_index < chunk_size; chunk_index++) {
+                        for (let i = max_dim; i >= 0; i--) {
+                            if (chunk_pos[i] >= chunk_shape[i]) {
+                                chunk_pos[i] = 0;
+                                data_pos[i] = chunk_offset[i];
+                                if (i > 0) {
+                                    chunk_pos[i - 1]++;
+                                    data_pos[i - 1]++;
+                                }
+                            }
+                            else {
+                                break;
+                            }
+                        }
+                        let index = 0;
+                        let inbounds = true;
+                        const length = data_pos.length - 1;
+                        for (let i = 0; i < length; i++) {
+                            const pos = data_pos[i];
+                            inbounds = inbounds && pos < data_shape[i];
+                            index += pos * data_strides[i];
                         }
-                        const start = node.fields.slice(0, 1).reduce((a, b) => a * b, 1) * itemsize;
-                        let chunk = node.data;
-                        if (this._filterPipeline) {
-                            for (const filter of this._filterPipeline.filters) {
-                                chunk = filter.decode(chunk);
+                        if (inbounds) {
+                            let chunk_offset = chunk_index * item_size;
+                            let target_offset = index * item_size;
+                            const target_end = target_offset + item_size;
+                            while (target_offset < target_end) {
+                                data[target_offset++] = chunk[chunk_offset++];
                             }
                         }
-                        data.set(chunk, start);
+                        chunk_pos[max_dim]++;
+                        data_pos[max_dim]++;
                     }
-                    return data;
                 }
-                break;
+                return data;
             }
             default: {
                 throw new hdf5.Error("Unsupported data layout class '" + this.layoutClass + "'.");
@@ -245,13 +285,13 @@ hdf5.Variable = class {
 
 hdf5.BinaryReader = class {
 
-    constructor(buffer) {
-        if (buffer) {
-            this._buffer = buffer;
-            this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-            this._position = 0;
-            this._offset = 0;
-        }
+    constructor(buffer, view, position, offset, offsetSize, lengthSize) {
+        this._buffer = buffer;
+        this._view = view || new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+        this._position = position || 0;
+        this._offset = offset || 0;
+        this._offsetSize = offsetSize;
+        this._lengthSize = lengthSize;
     }
 
     initialize() {
@@ -275,49 +315,49 @@ hdf5.BinaryReader = class {
     int8() {
         const offset = this._offset;
         this.skip(1);
-        return this._dataView.getInt8(this._position + offset);
+        return this._view.getInt8(this._position + offset);
     }
 
     byte() {
         const offset = this._offset;
         this.skip(1);
-        return this._dataView.getUint8(this._position + offset);
+        return this._view.getUint8(this._position + offset);
     }
 
     int16() {
         const offset = this._position + this._offset;
         this.skip(2);
-        return this._dataView.getInt16(offset, true);
+        return this._view.getInt16(offset, true);
     }
 
     uint16() {
         const offset = this._position + this._offset;
         this.skip(2);
-        return this._dataView.getUint16(offset, true);
+        return this._view.getUint16(offset, true);
     }
 
     int32() {
         const offset = this._position + this._offset;
         this.skip(4);
-        return this._dataView.getInt32(offset, true);
+        return this._view.getInt32(offset, true);
     }
 
     uint32() {
         const offset = this._position + this._offset;
         this.skip(4);
-        return this._dataView.getUint32(offset, true);
+        return this._view.getUint32(offset, true);
     }
 
     int64() {
         const offset = this._position + this._offset;
         this.skip(8);
-        return this._dataView.getInt64(offset, true).toNumber();
+        return this._view.getInt64(offset, true).toNumber();
     }
 
     uint64() {
         const offset = this._position + this._offset;
         this.skip(8);
-        return this._dataView.getUint64(offset, true).toNumber();
+        return this._view.getUint64(offset, true).toNumber();
     }
 
     uint(size) {
@@ -333,7 +373,7 @@ hdf5.BinaryReader = class {
     float16() {
         const offset = this._offset;
         this.skip(2);
-        const value = this._dataView.getUint16(this._position + offset, true);
+        const value = this._view.getUint16(this._position + offset, true);
         // decode float16 value
         const s = (value & 0x8000) >> 15;
         const e = (value & 0x7C00) >> 10;
@@ -350,13 +390,13 @@ hdf5.BinaryReader = class {
     float32() {
         const offset = this._position + this._offset;
         this.skip(4);
-        return this._dataView.getFloat32(offset, true);
+        return this._view.getFloat32(offset, true);
     }
 
     float64() {
         const offset = this._position + this._offset;
         this.skip(8);
-        return this._dataView.getFloat64(offset, true);
+        return this._view.getFloat64(offset, true);
     }
 
     string(size, encoding) {
@@ -393,7 +433,7 @@ hdf5.BinaryReader = class {
             case 8: {
                 const position = this._position + this._offset;
                 this.skip(8);
-                const value = this._dataView.getUint64(position, true);
+                const value = this._view.getUint64(position, true);
                 if (value.low === -1 && value.high === -1) {
                     return undefined;
                 }
@@ -417,7 +457,7 @@ hdf5.BinaryReader = class {
             case 8: {
                 const position = this._position + this._offset;
                 this.skip(8);
-                const value = this._dataView.getUint64(position, true);
+                const value = this._view.getUint64(position, true);
                 if (value.low === -1 && value.high === -1) {
                     return undefined;
                 }
@@ -437,25 +477,11 @@ hdf5.BinaryReader = class {
     }
 
     at(position) {
-        const reader = new hdf5.BinaryReader(null);
-        reader._buffer = this._buffer;
-        reader._dataView = this._dataView;
-        reader._position = position;
-        reader._offset = 0;
-        reader._offsetSize = this._offsetSize;
-        reader._lengthSize = this._lengthSize;
-        return reader;
+        return new hdf5.BinaryReader(this._buffer, this._view, position, 0, this._offsetSize, this._lengthSize);
     }
 
     clone() {
-        const reader =  new hdf5.BinaryReader(this._buffer, this._position);
-        reader._buffer = this._buffer;
-        reader._dataView = this._dataView;
-        reader._position = this._position;
-        reader._offset = this._offset;
-        reader._offsetSize = this._offsetSize;
-        reader._lengthSize = this._lengthSize;
-        return reader;
+        return new hdf5.BinaryReader(this._buffer, this._view, this._position, this._offset, this._offsetSize, this._lengthSize);
     }
 
     align(mod) {
@@ -540,16 +566,17 @@ hdf5.DataObjectHeader = class {
         this.attributes = [];
         this.links = [];
         this.continuations = [];
-        const version = reader.match('OHDR') ? reader.byte() : reader.byte();
+        reader.match('OHDR');
+        const version = reader.byte();
         switch (version) {
             case 1: {
                 reader.skip(1);
-                const messageCount = reader.uint16();
+                const count = reader.uint16();
                 reader.uint32();
                 const objectHeaderSize = reader.uint32();
                 reader.align(8);
                 let end = reader.position + objectHeaderSize;
-                for (let i = 0; i < messageCount; i++) {
+                for (let i = 0; i < count; i++) {
                     const type = reader.uint16();
                     const size = reader.uint16();
                     const flags = reader.byte();
@@ -1124,8 +1151,9 @@ hdf5.DataLayout = class {
                 }
                 break;
             }
-            default:
+            default: {
                 throw new hdf5.Error('Unsupported data layout version \'' + version + '\'.');
+            }
         }
     }
 };
@@ -1330,13 +1358,13 @@ hdf5.Tree = class {
         }
         this.type = reader.byte();
         this.level = reader.byte();
-        const entriesUsed = reader.uint16();
+        const entries = reader.uint16();
         reader.offset(); // address of left sibling
         reader.offset(); // address of right sibling
         this.nodes = [];
         switch (this.type) {
-            case 0: // Group nodes
-                for (let i = 0; i < entriesUsed; i++) {
+            case 0: { // Group nodes
+                for (let i = 0; i < entries; i++) {
                     reader.length();
                     const childPointer = reader.offset();
                     if (this.level == 0) {
@@ -1349,8 +1377,9 @@ hdf5.Tree = class {
                     }
                 }
                 break;
-            case 1: // Raw data chunk nodes
-                for (let i = 0; i < entriesUsed; i++) {
+            }
+            case 1: { // Raw data chunk nodes
+                for (let i = 0; i < entries; i++) {
                     const size = reader.int32();
                     const filterMask = reader.int32();
                     const fields = [];
@@ -1368,8 +1397,10 @@ hdf5.Tree = class {
                     }
                 }
                 break;
-            default:
+            }
+            default: {
                 throw new hdf5.Error('Unsupported B-Tree node type \'' + this.type + '\'.');
+            }
         }
     }
 };

+ 7 - 0
test/models.json

@@ -2027,6 +2027,13 @@
     "format":   "Keras v2.2.3",
     "link":     "https://github.com/lutzroeder/netron/issues/157"
   },
+  {
+    "type":     "keras",
+    "target":   "data_prediction.hdf5",
+    "source":   "https://github.com/lutzroeder/netron/files/8694214/data_prediction.hdf5.zip[data_prediction.hdf5]",
+    "format":   "HDF5 Weights",
+    "link":     "https://github.com/lutzroeder/netron/issues/467"
+  },
   {
     "type":     "keras",
     "target":   "DenseNet121.h5.zip",