Forráskód Böngészése

Workaround for large float arrays in Chromium (#131)

Lutz Roeder 5 éve
szülő
commit
0e8858abca
4 módosított fájl, 75 hozzáadás és 27 törlés
  1. +30 −4
      src/tf-proto.js
  2. +37 −23
      src/tf.js
  3. +6 −0
      test/models.json
  4. +2 −0
      tools/tf

+ 30 - 4
src/tf-proto.js

@@ -2994,8 +2994,21 @@
                             message.float_val = [];
                         if ((tag & 7) === 2) {
                             var end2 = reader.uint32() + reader.pos;
-                            while (reader.pos < end2)
-                                message.float_val.push(reader.float());
+                            if (message.float_val.length == 0 && (end2 - reader.pos) > 1048576) {
+                                var float_valLength = end2 - reader.pos;
+                                var float_valView = new DataView(reader.buf.buffer, reader.buf.byteOffset + reader.pos, float_valLength);
+                                float_valLength = float_valLength >>> 2;
+                                var float_val = new Float32Array(float_valLength);
+                                for (var i = 0; i < float_valLength; i++) {
+                                    float_val[i] = float_valView.getFloat32(i << 2, true);
+                                }
+                                message.float_val = float_val;
+                                reader.pos = end2;
+                            }
+                            else {
+                                while (reader.pos < end2)
+                                    message.float_val.push(reader.float());
+                            }
                         } else
                             message.float_val.push(reader.float());
                         break;
@@ -3004,8 +3017,21 @@
                             message.double_val = [];
                         if ((tag & 7) === 2) {
                             var end2 = reader.uint32() + reader.pos;
-                            while (reader.pos < end2)
-                                message.double_val.push(reader.double());
+                            if (message.double_val.length == 0 && (end2 - reader.pos) > 1048576) {
+                                var double_valLength = end2 - reader.pos;
+                                var double_valView = new DataView(reader.buf.buffer, reader.buf.byteOffset + reader.pos, double_valLength);
+                                double_valLength = double_valLength >>> 3;
+                                var double_val = new Float64Array(double_valLength);
+                                for (var i = 0; i < double_valLength; i++) {
+                                    double_val[i] = double_valView.getFloat64(i << 3, true);
+                                }
+                                message.double_val = double_val;
+                                reader.pos = end2;
+                            }
+                            else {
+                                while (reader.pos < end2)
+                                    message.double_val.push(reader.double());
+                            }
                         } else
                             message.double_val.push(reader.double());
                         break;

+ 37 - 23
src/tf.js

@@ -1304,30 +1304,29 @@ tf.TensorBundle = class {
         const indexOffset = reader.varint64();
         const indexSize = reader.varint64();
         reader.seek(indexOffset);
-        let indexData = reader.bytes(indexSize);
+        const indexReader = reader.clone(indexSize);
         let indexCompression = reader.byte();
         if (indexCompression !== 0) { // kNoCompression
             throw new tf.Error("Unsupported block compression '" + indexCompression + "'.");
         }
-        let indexReader = new tf.TensorBundle.BinaryReader(indexData);
         indexReader.seek(-4);
         const numRestarts = indexReader.int32();
         indexReader.seek(-4 - (4 * numRestarts));
-        let restartOffsets = [];
+        const restartOffsets = [];
         for (let i = 0; i < numRestarts; i++) {
             restartOffsets.push(indexReader.int32());
         }
         const textDecoder = new TextDecoder();
-        let entries = new Map();
+        const entries = new Map();
         for (let i = 0; i < numRestarts; i++) {
             indexReader.seek(restartOffsets[i]);
             indexReader.varint32(); // index shared size
             const indexNonSharedSize = indexReader.varint32();
             const indexValueSize = indexReader.varint32();
             indexReader.skip(indexNonSharedSize);
-            let indexValueReader = new tf.TensorBundle.BinaryReader(indexReader.bytes(indexValueSize));
+            const indexValueReader = indexReader.clone(indexValueSize);
             reader.seek(indexValueReader.varint64());
-            let blockReader = new tf.TensorBundle.BinaryReader(reader.bytes(indexValueReader.varint64()));
+            const blockReader = reader.clone(indexValueReader.varint64());
             let key = '';
             while (!blockReader.end()) {
                 const sharedSize = blockReader.varint32();
@@ -1350,7 +1349,7 @@ tf.TensorBundle = class {
         }
         const header = tf.proto.BundleHeaderProto.decode(entries.get(''));
         const numShards = header.num_shards;
-        let promises = [];
+        const promises = [];
         for (let i = 0; i < numShards; i++) {
             const shardIndex = ('0000' + i).slice(-5);
             const shardCount = ('0000' + numShards).slice(-5);
@@ -1374,7 +1373,7 @@ tf.TensorBundle = class {
         switch (format) {
             case 1: {
                 const header = tf.proto.SavedTensorSlices.decode(entries.get(''));
-                let data = new Map();
+                const data = new Map();
                 for (const pair of entries) {
                     if (pair[0] !== '' && pair[0] !== 'global_step') {
                         const slices = tf.proto.SavedTensorSlices.decode(pair[1]);
@@ -1390,7 +1389,7 @@ tf.TensorBundle = class {
                             }
                         }
                         else {
-                            let item = data.get(name);
+                            const item = data.get(name);
                             if (item !== null) {
                                 if (tensor[item.key] && tensor[item.key].length > 0) {
                                     item.value = item.value.concat(tensor[item.key]);
@@ -1404,7 +1403,7 @@ tf.TensorBundle = class {
                 }
                 for (const meta of header.meta.tensor) {
                     if (meta.name !== 'global_step') {
-                        let tensor = new tf.proto.TensorProto();
+                        const tensor = new tf.proto.TensorProto();
                         tensor.dtype = meta.type;
                         tensor.tensor_shape = meta.shape;
                         const item = data.get(meta.name);
@@ -1420,7 +1419,7 @@ tf.TensorBundle = class {
                 entries.forEach((value, name) => {
                     if (name !== '') {
                         const entry = tf.proto.BundleEntryProto.decode(value);
-                        let tensor = new tf.proto.TensorProto();
+                        const tensor = new tf.proto.TensorProto();
                         tensor.dtype = entry.dtype;
                         tensor.tensor_shape = entry.shape;
                         const offset = (entry.offset instanceof long.Long) ? entry.offset.toNumber() : entry.offset;
@@ -1448,32 +1447,47 @@ tf.TensorBundle = class {
 tf.TensorBundle.BinaryReader = class {
 
     constructor(buffer) {
-        this._buffer = buffer;
-        this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-        this._position = 0;
+        if (buffer) {
+            this._buffer = buffer;
+            this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+            this._position = 0;
+            this._start = 0;
+            this._end = this._buffer.length;
+        }
     }
 
     seek(position) {
-        this._position = position >= 0 ? position : this._buffer.length + position;
-        if (this._position > this._buffer.length) {
-            throw new tf.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        this._position = position >= 0 ? this._start + position : this._end + position;
+        if (this._position > this._end) {
+            throw new tf.Error('Expected ' + (this._position - this._end) + ' more bytes. The file might be corrupted. Unexpected end of file.');
         }
     }
 
     skip(offset) {
         this._position += offset;
-        if (this._position > this._buffer.length) {
-            throw new tf.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        if (this._position > this._end) {
+            throw new tf.Error('Expected ' + (this._position - this._end) + ' more bytes. The file might be corrupted. Unexpected end of file.');
         }
     }
 
     end() {
-        return this._position >= this._buffer.length;
+        return this._position >= this._end;
+    }
+
+    clone(size) {
+        const reader = new tf.TensorBundle.BinaryReader();
+        reader._buffer = this._buffer;
+        reader._dataView = this._dataView;
+        reader._start = this._position;
+        reader._position = this._position;
+        this.skip(size);
+        reader._end = this._position;
+        return reader;
     }
 
-    bytes(length) {
+    bytes(size) {
         const position = this._position;
-        this.skip(length);
+        this.skip(size);
         return this._buffer.subarray(position, this._position);
     }
 
@@ -1496,7 +1510,7 @@ tf.TensorBundle.BinaryReader = class {
     varint64() {
         let result = 0;
         for (let shift = 0; shift <= 63; shift += 7) {
-            let byte = this.byte();
+            const byte = this.byte();
             if (byte & 128) {
                 result |= (byte & 127) << shift;
             }

+ 6 - 0
test/models.json

@@ -4966,6 +4966,12 @@
     "format": "TensorFlow Saved Model v1",
     "link":   "https://github.com/tensorflow/tensorflow/issues/9169"
   },
+  {
+    "type":   "tf",
+    "target": "vgg_16.ckpt",
+    "source": "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz[vgg_16.ckpt]",
+    "format": "TensorFlow Tensor Bundle v1"
+  },
   {
     "type":   "tf",
     "target": "vgg_19.pb",

+ 2 - 0
tools/tf

@@ -69,6 +69,8 @@ schema() {
         ./third_party/src/tensorflow/tensorflow/core/protobuf/tensor_bundle.proto \
         ./third_party/src/tensorflow/tensorflow/core/framework/tensor_slice.proto \
         ./third_party/src/tensorflow/tensorflow/core/util/saved_tensor_slice.proto
+    node ./tools/update_pbjs.js array ./src/tf-proto.js float_val float 1
+    node ./tools/update_pbjs.js array ./src/tf-proto.js double_val double 1
     if [[ -n ${crlf} ]]; then
         unix2dos --quiet --newfile ./src/tf-proto.js ./src/tf-proto.js
     fi