Selaa lähdekoodia

Add TensorFlow Memmapped test file (#836)

Lutz Roeder 3 vuotta sitten
vanhempi
sitoutus
eb32d6c256
2 muutettua tiedostoa jossa 39 lisäystä ja 17 poistoa
  1. 32 17
      source/tf.js
  2. 7 0
      test/models.json

+ 32 - 17
source/tf.js

@@ -208,8 +208,7 @@ tf.ModelFactory = class {
                 stream.seek(-8);
                 const buffer = stream.read(8);
                 stream.seek(0);
-                const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-                const offset = view.getUint64(0, true).toNumber();
+                const offset = new base.BinaryReader(buffer).uint64();
                 if (offset < stream.length) {
                     return 'tf.pb.mmap';
                 }
@@ -590,37 +589,53 @@ tf.ModelFactory = class {
             };
             const openMemmappedFileSystemDirectory = (context) => {
                 const stream = context.stream;
-                const readDirectoryBuffer = (stream) => {
+                const readDirectoryOffset = (stream) => {
                     stream.seek(-8);
-                    const end = stream.position;
                     const buffer = stream.read(8);
-                    const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-                    const offset = view.getUint64(0, true).toNumber();
-                    stream.seek(offset);
-                    return stream.read(end - offset);
+                    const reader = new base.BinaryReader(buffer);
+                    return reader.uint64();
                 };
-                const readDirectory = (stream) => {
-                    const buffer = readDirectoryBuffer(stream);
+                const readDirectory = (stream, offset) => {
+                    const end = stream.position - 8;
+                    stream.seek(offset);
+                    const buffer = stream.read(end - offset);
                     const reader = protobuf.BinaryReader.open(buffer);
                     return tf.proto.tensorflow.MemmappedFileSystemDirectory.decode(reader);
                 };
-                const directory = readDirectory(stream);
+                const offset = readDirectoryOffset(stream);
+                const directory = readDirectory(stream, offset);
                 const elements = new Map();
                 for (const element of directory.element) {
-                    const offset = element.offset ? element.offset.toNumber() : 0;
-                    const length = element.length.toNumber();
-                    stream.seek(offset);
-                    const buffer = stream.read(length);
                     const name = element.name;
                     if (elements.has(name)) {
                         throw new tf.Error("Memory mapped file directory contains duplicate '" + name + "'.");
                     }
-                    elements.set(name, buffer);
+                    elements.set(name, {
+                        offset: element.offset ? element.offset.toNumber() : 0,
+                        length: element.length ? element.length.toNumber() : 0
+                    });
+                }
+                const offsets = Array.from(elements).map((entry) => entry[1].offset);
+                offsets.push(offset);
+                for (const value of elements.values()) {
+                    if (value.length === 0) {
+                        const min = Math.min.apply(null, offsets.filter((offset) => offset > value.offset));
+                        if (Number.isInteger(min)) {
+                            value.length = min - value.offset;
+                        }
+                    }
+                }
+                for (const entry of elements) {
+                    const offset = entry[1].offset;
+                    const length = entry[1].length;
+                    stream.seek(offset);
+                    entry[1].buffer = stream.read(length);
                 }
                 if (!elements.has('memmapped_package://.')) {
                     throw new tf.Error('Memory mapped file directory does not contain tensorflow.GraphDef root.');
                 }
-                const buffer = elements.get('memmapped_package://.');
+                const element = elements.get('memmapped_package://.');
+                const buffer = element.buffer;
                 const reader = protobuf.BinaryReader.open(buffer);
                 const graph_def = tf.proto.tensorflow.GraphDef.decode(reader);
                 const format = 'TensorFlow GraphDef Memmapped';

+ 7 - 0
test/models.json

@@ -5889,6 +5889,13 @@
     "format":   "TensorFlow Graph",
     "link":     "https://github.com/lutzroeder/netron/issues/895"
   },
+  {
+    "type":     "tf",
+    "target":   "output_graph.pbmm",
+    "source":   "https://github.com/lutzroeder/netron/files/8645971/output_graph.pbmm.zip[output_graph.pbmm]",
+    "format":   "TensorFlow GraphDef Memmapped",
+    "link":     "https://github.com/lutzroeder/netron/issues/836"
+  },
   {
     "type":     "tf",
     "target":   "pose_estimation_for_mobile.pb",