|
|
@@ -1304,30 +1304,29 @@ tf.TensorBundle = class {
|
|
|
const indexOffset = reader.varint64();
|
|
|
const indexSize = reader.varint64();
|
|
|
reader.seek(indexOffset);
|
|
|
- let indexData = reader.bytes(indexSize);
|
|
|
+ const indexReader = reader.clone(indexSize);
|
|
|
let indexCompression = reader.byte();
|
|
|
if (indexCompression !== 0) { // kNoCompression
|
|
|
throw new tf.Error("Unsupported block compression '" + indexCompression + "'.");
|
|
|
}
|
|
|
- let indexReader = new tf.TensorBundle.BinaryReader(indexData);
|
|
|
indexReader.seek(-4);
|
|
|
const numRestarts = indexReader.int32();
|
|
|
indexReader.seek(-4 - (4 * numRestarts));
|
|
|
- let restartOffsets = [];
|
|
|
+ const restartOffsets = [];
|
|
|
for (let i = 0; i < numRestarts; i++) {
|
|
|
restartOffsets.push(indexReader.int32());
|
|
|
}
|
|
|
const textDecoder = new TextDecoder();
|
|
|
- let entries = new Map();
|
|
|
+ const entries = new Map();
|
|
|
for (let i = 0; i < numRestarts; i++) {
|
|
|
indexReader.seek(restartOffsets[i]);
|
|
|
indexReader.varint32(); // index shared size
|
|
|
const indexNonSharedSize = indexReader.varint32();
|
|
|
const indexValueSize = indexReader.varint32();
|
|
|
indexReader.skip(indexNonSharedSize);
|
|
|
- let indexValueReader = new tf.TensorBundle.BinaryReader(indexReader.bytes(indexValueSize));
|
|
|
+ const indexValueReader = indexReader.clone(indexValueSize);
|
|
|
reader.seek(indexValueReader.varint64());
|
|
|
- let blockReader = new tf.TensorBundle.BinaryReader(reader.bytes(indexValueReader.varint64()));
|
|
|
+ const blockReader = reader.clone(indexValueReader.varint64());
|
|
|
let key = '';
|
|
|
while (!blockReader.end()) {
|
|
|
const sharedSize = blockReader.varint32();
|
|
|
@@ -1350,7 +1349,7 @@ tf.TensorBundle = class {
|
|
|
}
|
|
|
const header = tf.proto.BundleHeaderProto.decode(entries.get(''));
|
|
|
const numShards = header.num_shards;
|
|
|
- let promises = [];
|
|
|
+ const promises = [];
|
|
|
for (let i = 0; i < numShards; i++) {
|
|
|
const shardIndex = ('0000' + i).slice(-5);
|
|
|
const shardCount = ('0000' + numShards).slice(-5);
|
|
|
@@ -1374,7 +1373,7 @@ tf.TensorBundle = class {
|
|
|
switch (format) {
|
|
|
case 1: {
|
|
|
const header = tf.proto.SavedTensorSlices.decode(entries.get(''));
|
|
|
- let data = new Map();
|
|
|
+ const data = new Map();
|
|
|
for (const pair of entries) {
|
|
|
if (pair[0] !== '' && pair[0] !== 'global_step') {
|
|
|
const slices = tf.proto.SavedTensorSlices.decode(pair[1]);
|
|
|
@@ -1390,7 +1389,7 @@ tf.TensorBundle = class {
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
- let item = data.get(name);
|
|
|
+ const item = data.get(name);
|
|
|
if (item !== null) {
|
|
|
if (tensor[item.key] && tensor[item.key].length > 0) {
|
|
|
item.value = item.value.concat(tensor[item.key]);
|
|
|
@@ -1404,7 +1403,7 @@ tf.TensorBundle = class {
|
|
|
}
|
|
|
for (const meta of header.meta.tensor) {
|
|
|
if (meta.name !== 'global_step') {
|
|
|
- let tensor = new tf.proto.TensorProto();
|
|
|
+ const tensor = new tf.proto.TensorProto();
|
|
|
tensor.dtype = meta.type;
|
|
|
tensor.tensor_shape = meta.shape;
|
|
|
const item = data.get(meta.name);
|
|
|
@@ -1420,7 +1419,7 @@ tf.TensorBundle = class {
|
|
|
entries.forEach((value, name) => {
|
|
|
if (name !== '') {
|
|
|
const entry = tf.proto.BundleEntryProto.decode(value);
|
|
|
- let tensor = new tf.proto.TensorProto();
|
|
|
+ const tensor = new tf.proto.TensorProto();
|
|
|
tensor.dtype = entry.dtype;
|
|
|
tensor.tensor_shape = entry.shape;
|
|
|
const offset = (entry.offset instanceof long.Long) ? entry.offset.toNumber() : entry.offset;
|
|
|
@@ -1448,32 +1447,47 @@ tf.TensorBundle = class {
|
|
|
tf.TensorBundle.BinaryReader = class {
|
|
|
|
|
|
constructor(buffer) {
|
|
|
- this._buffer = buffer;
|
|
|
- this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
- this._position = 0;
|
|
|
+ if (buffer) {
|
|
|
+ this._buffer = buffer;
|
|
|
+ this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
+ this._position = 0;
|
|
|
+ this._start = 0;
|
|
|
+ this._end = this._buffer.length;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
seek(position) {
|
|
|
- this._position = position >= 0 ? position : this._buffer.length + position;
|
|
|
- if (this._position > this._buffer.length) {
|
|
|
- throw new tf.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
+ this._position = position >= 0 ? this._start + position : this._end + position;
|
|
|
+ if (this._position > this._end) {
|
|
|
+ throw new tf.Error('Expected ' + (this._position - this._end) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
}
|
|
|
}
|
|
|
|
|
|
skip(offset) {
|
|
|
this._position += offset;
|
|
|
- if (this._position > this._buffer.length) {
|
|
|
- throw new tf.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
+ if (this._position > this._end) {
|
|
|
+ throw new tf.Error('Expected ' + (this._position - this._end) + ' more bytes. The file might be corrupted. Unexpected end of file.');
|
|
|
}
|
|
|
}
|
|
|
|
|
|
end() {
|
|
|
- return this._position >= this._buffer.length;
|
|
|
+ return this._position >= this._end;
|
|
|
+ }
|
|
|
+
|
|
|
+ clone(size) {
|
|
|
+ const reader = new tf.TensorBundle.BinaryReader();
|
|
|
+ reader._buffer = this._buffer;
|
|
|
+ reader._dataView = this._dataView;
|
|
|
+ reader._start = this._position;
|
|
|
+ reader._position = this._position;
|
|
|
+ this.skip(size);
|
|
|
+ reader._end = this._position;
|
|
|
+ return reader;
|
|
|
}
|
|
|
|
|
|
- bytes(length) {
|
|
|
+ bytes(size) {
|
|
|
const position = this._position;
|
|
|
- this.skip(length);
|
|
|
+ this.skip(size);
|
|
|
return this._buffer.subarray(position, this._position);
|
|
|
}
|
|
|
|
|
|
@@ -1496,7 +1510,7 @@ tf.TensorBundle.BinaryReader = class {
|
|
|
varint64() {
|
|
|
let result = 0;
|
|
|
for (let shift = 0; shift <= 63; shift += 7) {
|
|
|
- let byte = this.byte();
|
|
|
+ const byte = this.byte();
|
|
|
if (byte & 128) {
|
|
|
result |= (byte & 127) << shift;
|
|
|
}
|