|
|
@@ -12,12 +12,6 @@ tf.ModelFactory = class {
|
|
|
match(context) {
|
|
|
const identifier = context.identifier;
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
- if (extension === 'meta') {
|
|
|
- const tags = context.tags('pb');
|
|
|
- if (tags.size !== 0) {
|
|
|
- return 'tf.pb';
|
|
|
- }
|
|
|
- }
|
|
|
if (extension === 'pbtxt' || extension === 'prototxt' || extension === 'pt') {
|
|
|
if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
|
|
|
identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
|
|
|
@@ -37,7 +31,7 @@ tf.ModelFactory = class {
|
|
|
return 'tf.pbtxt.GraphDef';
|
|
|
}
|
|
|
}
|
|
|
- if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef') {
|
|
|
+ if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef' || extension === 'meta') {
|
|
|
if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
|
|
|
return undefined;
|
|
|
}
|
|
|
@@ -51,30 +45,56 @@ tf.ModelFactory = class {
|
|
|
const tags = context.tags('pb');
|
|
|
if (tags.size > 0) {
|
|
|
if (!Array.from(tags).some((pair) => pair[0] >= 5 || pair[1] === 5)) {
|
|
|
- if (tags.size === 1 && tags.get(1) === 2) {
|
|
|
- const tags = context.tags('pb+');
|
|
|
- const match = (tags, schema) => {
|
|
|
- for (const pair of schema) {
|
|
|
- const key = pair[0];
|
|
|
- const inner = pair[1];
|
|
|
- if (!tags.has(key)) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- else if (inner === false) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- if (Array.isArray(inner)) {
|
|
|
- const value = tags.get(key);
|
|
|
- if (!(value instanceof Map) || !match(value, inner)) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- }
|
|
|
- else if (inner !== tags.get(key)) {
|
|
|
+ const match = (tags, schema) => {
|
|
|
+ for (const pair of schema) {
|
|
|
+ const key = pair[0];
|
|
|
+ const inner = pair[1];
|
|
|
+ if (tags[key] === undefined) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ else if (inner === false) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (Array.isArray(inner)) {
|
|
|
+ const value = tags[key];
|
|
|
+ if (typeof value !== 'object' || !match(value, inner)) {
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
- return true;
|
|
|
- };
|
|
|
+ else if (inner !== tags[key]) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ };
|
|
|
+ const signatureGraphDef = [
|
|
|
+ [1 /* node */, [
|
|
|
+ [1 /* name */, 2],
|
|
|
+ [2 /* op */, 2],
|
|
|
+ [3 /* input */, 2],
|
|
|
+ [4 /* device */,2],
|
|
|
+ [5 /* attr */, [
|
|
|
+ [1,2],
|
|
|
+ [2,[]]
|
|
|
+ ]],
|
|
|
+ [6 /* experimental_debug_info */, []]
|
|
|
+ ]],
|
|
|
+ [2 /* library */, []],
|
|
|
+ [3 /* version */, 0],
|
|
|
+ [4 /* versions */, [[1,0],[2,0]]]
|
|
|
+ ];
|
|
|
+ const signatureMetaGraphDef = [
|
|
|
+ [1 /* meta_info_def */, [[1,2],[2,[]],[3,[]],[4,2],[6,2],[7,0],[8,[]]]],
|
|
|
+ [2 /* graph_def */, signatureGraphDef],
|
|
|
+ [3 /* saver_def */, [[1,2],[2,2],[3,2],[4,0],[5,0],[6,5],[7,0]]],
|
|
|
+ [4 /* collection_def */,[]],
|
|
|
+ [5 /* signature_def */, []],
|
|
|
+ [6 /* asset_file_def */, []],
|
|
|
+ [7 /* object_graph_def */, []]
|
|
|
+ ];
|
|
|
+ const signatureSavedModel = [[1,0],[2,signatureMetaGraphDef]];
|
|
|
+ if (tags.size === 1 && tags.get(1) === 2) {
|
|
|
+ const tags = context.tags('pb+');
|
|
|
// mediapipe.BoxDetectorIndex
|
|
|
if (match(tags, [[1,[[1,[[1,[[1,5],[2,5],[3,5],[4,5],[6,0],[7,5],[8,5],[10,5],[11,0],[12,0]]],[2,5],[3,[]]]],[2,false],[3,false],[4,false],[5,false]]],[2,false],[3,false]] )) {
|
|
|
return undefined;
|
|
|
@@ -84,8 +104,26 @@ tf.ModelFactory = class {
|
|
|
return 'tf.pb.keras.SavedMetadata';
|
|
|
}
|
|
|
}
|
|
|
+ if ((!tags.has(1) || tags.get(1) === 0) && tags.get(2) === 2) {
|
|
|
+ const tags = context.tags('pb+');
|
|
|
+ if (match(tags, signatureSavedModel)) {
|
|
|
+ return 'tf.pb.SavedModel';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if ((!tags.has(1) || tags.get(1) === 2) &&
|
|
|
+ (!tags.has(2) || tags.get(2) === 2) &&
|
|
|
+ (!tags.has(3) || tags.get(3) === 2) &&
|
|
|
+ (!tags.has(4) || tags.get(4) === 2)) {
|
|
|
+ const tags = context.tags('pb+');
|
|
|
+ if (match(tags, signatureMetaGraphDef)) {
|
|
|
+ return 'tf.pb.MetaGraphDef';
|
|
|
+ }
|
|
|
+ }
|
|
|
if (tags.get(1) !== 2) {
|
|
|
- return 'tf.pb';
|
|
|
+ const tags = context.tags('pb+');
|
|
|
+ if (match(tags, signatureGraphDef)) {
|
|
|
+ return 'tf.pb.GraphDef';
|
|
|
+ }
|
|
|
}
|
|
|
const decode = (buffer, value) => {
|
|
|
const reader = protobuf.BinaryReader.open(buffer);
|
|
|
@@ -112,7 +150,7 @@ tf.ModelFactory = class {
|
|
|
const decoder = new TextDecoder('utf-8');
|
|
|
const name = decoder.decode(nameBuffer);
|
|
|
if (Array.from(name).filter((c) => c <= ' ').length < 256) {
|
|
|
- return 'tf.pb';
|
|
|
+ return 'tf.pb.GraphDef';
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -193,7 +231,9 @@ tf.ModelFactory = class {
|
|
|
}
|
|
|
return openModel(saved_model, format, producer, null);
|
|
|
};
|
|
|
- const openBundle = (context, stream, identifier) => {
|
|
|
+ const openBundle = (context) => {
|
|
|
+ const stream = context.stream;
|
|
|
+ const identifier = context.identifier;
|
|
|
return tf.TensorBundle.open(stream, identifier, context).then((bundle) => {
|
|
|
return openModel(null, 'TensorFlow Tensor Bundle v' + bundle.format.toString(), null, bundle);
|
|
|
}).catch((error) => {
|
|
|
@@ -401,54 +441,54 @@ tf.ModelFactory = class {
|
|
|
throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').');
|
|
|
}
|
|
|
};
|
|
|
- const openBinaryProto = (stream, identifier) => {
|
|
|
+ const openBinaryGraphDef = (context) => {
|
|
|
let saved_model = null;
|
|
|
- let format = null;
|
|
|
- const extension = identifier.split('.').pop().toLowerCase();
|
|
|
+ const format = 'TensorFlow Graph';
|
|
|
try {
|
|
|
- if (identifier.endsWith('saved_model.pb')) {
|
|
|
- const reader = protobuf.BinaryReader.open(stream);
|
|
|
- saved_model = tf.proto.tensorflow.SavedModel.decode(reader);
|
|
|
- format = 'TensorFlow Saved Model';
|
|
|
- if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
|
|
|
- format = format + ' v' + saved_model.saved_model_schema_version.toString();
|
|
|
- }
|
|
|
- }
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
+ const graph_def = tf.proto.tensorflow.GraphDef.decode(reader);
|
|
|
+ const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
|
|
|
+ meta_graph.graph_def = graph_def;
|
|
|
+ saved_model = new tf.proto.tensorflow.SavedModel();
|
|
|
+ saved_model.meta_graphs.push(meta_graph);
|
|
|
}
|
|
|
catch (error) {
|
|
|
- const signature = [ 0x08, 0x01, 0x12 ];
|
|
|
- if (signature.length < stream.length && stream.peek(3).every((value, index) => value === signature[index])) {
|
|
|
- const message = error && error.message ? error.message : error.toString();
|
|
|
- throw new tf.Error('File format is not tensorflow.SavedModel (' + message.replace(/\.$/, '') + ').');
|
|
|
- }
|
|
|
+ const message = error && error.message ? error.message : error.toString();
|
|
|
+ throw new tf.Error('File format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
+ return openSavedModel(saved_model, format, null);
|
|
|
+ };
|
|
|
+ const openBinaryMetaGraphDef = (context) => {
|
|
|
+ let saved_model = null;
|
|
|
+ const format = 'TensorFlow MetaGraph';
|
|
|
try {
|
|
|
- if (!saved_model && extension == 'meta') {
|
|
|
- const reader = protobuf.BinaryReader.open(stream);
|
|
|
- const meta_graph = tf.proto.tensorflow.MetaGraphDef.decode(reader);
|
|
|
- saved_model = new tf.proto.tensorflow.SavedModel();
|
|
|
- saved_model.meta_graphs.push(meta_graph);
|
|
|
- format = 'TensorFlow MetaGraph';
|
|
|
- }
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
+ const meta_graph = tf.proto.tensorflow.MetaGraphDef.decode(reader);
|
|
|
+ saved_model = new tf.proto.tensorflow.SavedModel();
|
|
|
+ saved_model.meta_graphs.push(meta_graph);
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new tf.Error('File format is not tensorflow.MetaGraphDef (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
+ return openSavedModel(saved_model, format, null);
|
|
|
+ };
|
|
|
+ const openBinarySavedModel = (context) => {
|
|
|
+ let saved_model = null;
|
|
|
+ let format = 'TensorFlow Saved Model';
|
|
|
try {
|
|
|
- if (!saved_model) {
|
|
|
- const reader = protobuf.BinaryReader.open(stream);
|
|
|
- const graph_def = tf.proto.tensorflow.GraphDef.decode(reader);
|
|
|
- const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
|
|
|
- meta_graph.graph_def = graph_def;
|
|
|
- saved_model = new tf.proto.tensorflow.SavedModel();
|
|
|
- saved_model.meta_graphs.push(meta_graph);
|
|
|
- format = 'TensorFlow Graph';
|
|
|
+ const stream = context.stream;
|
|
|
+ const reader = protobuf.BinaryReader.open(stream);
|
|
|
+ saved_model = tf.proto.tensorflow.SavedModel.decode(reader);
|
|
|
+ if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
|
|
|
+ format = format + ' v' + saved_model.saved_model_schema_version.toString();
|
|
|
}
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
- throw new tf.Error('File format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
|
|
|
+ throw new tf.Error('File format is not tensorflow.SavedModel (' + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
return openSavedModel(saved_model, format, null);
|
|
|
};
|
|
|
@@ -461,12 +501,12 @@ tf.ModelFactory = class {
|
|
|
*/
|
|
|
const identifier = 'saved_model.pb';
|
|
|
return context.request(identifier, null).then((stream) => {
|
|
|
- return openBinaryProto(stream, identifier);
|
|
|
+ return openBinarySavedModel({ stream: stream });
|
|
|
});
|
|
|
};
|
|
|
switch (match) {
|
|
|
case 'tf.bundle':
|
|
|
- return openBundle(context, context.stream, context.identifier);
|
|
|
+ return openBundle(context);
|
|
|
case 'tf.data':
|
|
|
return openData(context);
|
|
|
case 'tf.events':
|
|
|
@@ -479,8 +519,12 @@ tf.ModelFactory = class {
|
|
|
return openTextMetaGraphDef(context);
|
|
|
case 'tf.pbtxt.SavedModel':
|
|
|
return openTextSavedModel(context);
|
|
|
- case 'tf.pb':
|
|
|
- return openBinaryProto(context.stream, context.identifier);
|
|
|
+ case 'tf.pb.GraphDef':
|
|
|
+ return openBinaryGraphDef(context);
|
|
|
+ case 'tf.pb.MetaGraphDef':
|
|
|
+ return openBinaryMetaGraphDef(context);
|
|
|
+ case 'tf.pb.SavedModel':
|
|
|
+ return openBinarySavedModel(context);
|
|
|
case 'tf.pb.keras.SavedMetadata':
|
|
|
return openSavedMetadata(context);
|
|
|
default:
|