|
|
@@ -60,7 +60,7 @@ tf.ModelFactory = class {
|
|
|
return true;
|
|
|
}
|
|
|
const decode = (buffer, value) => {
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const length = reader.length;
|
|
|
while (reader.position < length) {
|
|
|
const tag = reader.uint32();
|
|
|
@@ -112,7 +112,8 @@ tf.ModelFactory = class {
|
|
|
return true;
|
|
|
}
|
|
|
if (/^events.out.tfevents./.exec(identifier)) {
|
|
|
- if (tf.EventFileReader.open(context.stream)) {
|
|
|
+ const stream = context.stream;
|
|
|
+ if (tf.EventFileReader.open(stream)) {
|
|
|
return true;
|
|
|
}
|
|
|
}
|
|
|
@@ -177,7 +178,8 @@ tf.ModelFactory = class {
|
|
|
const openEventFile = () => {
|
|
|
let format = 'TensorFlow Event File';
|
|
|
let producer = null;
|
|
|
- const eventFileReader = tf.EventFileReader.open(context.stream);
|
|
|
+ const stream = context.stream;
|
|
|
+ const eventFileReader = tf.EventFileReader.open(stream);
|
|
|
const saved_model = new tf.proto.SavedModel();
|
|
|
for (;;) {
|
|
|
const event = eventFileReader.read();
|
|
|
@@ -203,7 +205,7 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
case 'graph_def': {
|
|
|
const buffer = event.graph_def;
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const graph_def = tf.proto.GraphDef.decode(reader);
|
|
|
const meta_graph = new tf.proto.MetaGraphDef();
|
|
|
meta_graph.meta_info_def = new tf.proto.MetaGraphDef.MetaInfoDef();
|
|
|
@@ -317,8 +319,8 @@ tf.ModelFactory = class {
|
|
|
let saved_model = null;
|
|
|
if (tags.has('saved_model_schema_version') || tags.has('meta_graphs')) {
|
|
|
try {
|
|
|
- const buffer = context.stream.peek();
|
|
|
- const reader = protobuf.TextReader.create(buffer);
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.TextReader.open(stream);
|
|
|
saved_model = tf.proto.SavedModel.decodeText(reader);
|
|
|
format = 'TensorFlow Saved Model';
|
|
|
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
|
|
|
@@ -331,8 +333,8 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
else if (tags.has('graph_def')) {
|
|
|
try {
|
|
|
- const buffer = context.stream.peek();
|
|
|
- const reader = protobuf.TextReader.create(buffer);
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.TextReader.open(stream);
|
|
|
const meta_graph = tf.proto.MetaGraphDef.decodeText(reader);
|
|
|
saved_model = new tf.proto.SavedModel();
|
|
|
saved_model.meta_graphs.push(meta_graph);
|
|
|
@@ -344,8 +346,8 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
else if (tags.has('node')) {
|
|
|
try {
|
|
|
- const buffer = context.stream.peek();
|
|
|
- const reader = protobuf.TextReader.create(buffer);
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.TextReader.open(stream);
|
|
|
const graph_def = tf.proto.GraphDef.decodeText(reader);
|
|
|
const meta_graph = new tf.proto.MetaGraphDef();
|
|
|
meta_graph.graph_def = graph_def;
|
|
|
@@ -365,8 +367,7 @@ tf.ModelFactory = class {
|
|
|
let format = null;
|
|
|
try {
|
|
|
if (identifier.endsWith('saved_model.pb')) {
|
|
|
- const buffer = stream.peek();
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
saved_model = tf.proto.SavedModel.decode(reader);
|
|
|
format = 'TensorFlow Saved Model';
|
|
|
if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
|
|
|
@@ -383,8 +384,7 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
try {
|
|
|
if (!saved_model && extension == 'meta') {
|
|
|
- const buffer = stream.peek();
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
const meta_graph = tf.proto.MetaGraphDef.decode(reader);
|
|
|
saved_model = new tf.proto.SavedModel();
|
|
|
saved_model.meta_graphs.push(meta_graph);
|
|
|
@@ -397,8 +397,7 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
try {
|
|
|
if (!saved_model) {
|
|
|
- const buffer = stream.peek();
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
const graph_def = tf.proto.GraphDef.decode(reader);
|
|
|
const meta_graph = new tf.proto.MetaGraphDef();
|
|
|
meta_graph.graph_def = graph_def;
|
|
|
@@ -1829,8 +1828,8 @@ tf.TensorBundle = class {
|
|
|
if (format === 1) {
|
|
|
return Promise.resolve(new tf.TensorBundle(format, table.entries, []));
|
|
|
}
|
|
|
- const entry = table.entries.get('');
|
|
|
- const reader = protobuf.Reader.create(entry);
|
|
|
+ const buffer = table.entries.get('');
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const header = tf.proto.BundleHeaderProto.decode(reader);
|
|
|
const numShards = header.num_shards;
|
|
|
const promises = [];
|
|
|
@@ -1857,12 +1856,13 @@ tf.TensorBundle = class {
|
|
|
switch (format) {
|
|
|
case 1: {
|
|
|
const buffer = entries.get('');
|
|
|
- const reader = protobuf.Reader.create(buffer);
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const header = tf.proto.SavedTensorSlices.decode(reader);
|
|
|
const data = new Map();
|
|
|
for (const pair of entries) {
|
|
|
if (pair[0] !== '' && pair[0] !== 'global_step') {
|
|
|
- const reader = protobuf.Reader.create(pair[1]);
|
|
|
+ const buffer = pair[1];
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const slices = tf.proto.SavedTensorSlices.decode(reader);
|
|
|
const name = slices.data.name;
|
|
|
const tensor = slices.data.data;
|
|
|
@@ -1903,9 +1903,9 @@ tf.TensorBundle = class {
|
|
|
break;
|
|
|
}
|
|
|
case 2: {
|
|
|
- entries.forEach((value, name) => {
|
|
|
+ entries.forEach((buffer, name) => {
|
|
|
if (name !== '') {
|
|
|
- const reader = protobuf.Reader.create(value);
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
const entry = tf.proto.BundleEntryProto.decode(reader);
|
|
|
const tensor = new tf.proto.TensorProto();
|
|
|
tensor.dtype = entry.dtype;
|
|
|
@@ -2178,14 +2178,6 @@ tf.EventFileReader = class {
|
|
|
if (masked_crc32c(length_bytes) !== length_crc) {
|
|
|
return null;
|
|
|
}
|
|
|
- // reader.skip(-12);
|
|
|
- // const length = reader.uint64().toNumber();
|
|
|
- // reader.uint32(); // masked crc of length
|
|
|
- // const data = reader.read(length);
|
|
|
- // const data_crc = reader.uint32();
|
|
|
- // if (masked_crc32c(data) !== data_crc) {
|
|
|
- // return null;
|
|
|
- // }
|
|
|
return new tf.EventFileReader(stream);
|
|
|
}
|
|
|
|
|
|
@@ -2195,12 +2187,16 @@ tf.EventFileReader = class {
|
|
|
|
|
|
read() {
|
|
|
if (this._stream.position < this._stream.length) {
|
|
|
- const buffer = this._stream.read(12);
|
|
|
- const reader = new tf.BinaryReader(buffer);
|
|
|
- const length = reader.uint64().toNumber();
|
|
|
- reader.uint32(); // masked crc of length
|
|
|
- const data = this._stream.read(length);
|
|
|
- const event = tf.proto.Event.decode(protobuf.Reader.create(data));
|
|
|
+ const uint64 = (stream) => {
|
|
|
+ const buffer = stream.read(8);
|
|
|
+ const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
+ return view.getUint64(0, true).toNumber();
|
|
|
+ };
|
|
|
+ const length = uint64(this._stream);
|
|
|
+ this._stream.skip(4); // masked crc of length
|
|
|
+ const buffer = this._stream.read(length);
|
|
|
+ const reader = protobuf.BinaryReader.open(buffer);
|
|
|
+ const event = tf.proto.Event.decode(reader);
|
|
|
this._stream.skip(4); // masked crc of data
|
|
|
return event;
|
|
|
}
|