瀏覽代碼

Update TensorFlow content detection (#782)

Lutz Roeder 4 年之前
父節點
當前提交
4b3942cbf7
共有 5 個文件被更改,包括 260 次插入109 次删除
  1. 4 4
      source/onnx.js
  2. 133 33
      source/protobuf.js
  3. 111 67
      source/tf.js
  4. 5 5
      source/view.js
  5. 7 0
      test/models.json

+ 4 - 4
source/onnx.js

@@ -24,19 +24,19 @@ onnx.ModelFactory = class {
                     for (const pair of schema) {
                         const key = pair[0];
                         const inner = pair[1];
-                        if (!tags.has(key)) {
+                        if (tags[key] === undefined) {
                             continue;
                         }
                         else if (inner === false) {
                             return false;
                         }
                         if (Array.isArray(inner)) {
-                            const value = tags.get(key);
-                            if (!(value instanceof Map) || !match(value, inner)) {
+                            const value = tags[key];
+                            if (typeof value !== 'object' || !match(value, inner)) {
                                 return false;
                             }
                         }
-                        else if (inner !== tags.get(key)) {
+                        else if (inner !== tags[key]) {
                             return false;
                         }
                     }

+ 133 - 33
source/protobuf.js

@@ -57,16 +57,26 @@ protobuf.BinaryReader = class {
     }
 
     decode() {
-        const tags = new Map();
+        let tags = {};
         this._position = 0;
         try {
-            const decodeMessage = () => {
-                const length = this.uint32();
+            const decodeMessage = (max) => {
+                const length = this._uint32();
+                if (length === undefined) {
+                    return undefined;
+                }
                 const end = this.position + length;
+                if (end > max) {
+                    return undefined;
+                }
                 try {
-                    const tags = new Map();
+                    const tags = {};
                     while (this.position < end) {
-                        const tag = this.uint32();
+                        const tag = this._uint32();
+                        if (tag === undefined) {
+                            this.seek(end);
+                            return 2;
+                        }
                         const field = tag >>> 3;
                         const type = tag & 7;
                         if (type > 5 || field === 0) {
@@ -74,24 +84,35 @@ protobuf.BinaryReader = class {
                             return 2;
                         }
                         if (type === 2) {
-                            const type = tags.get(field);
+                            const type = tags[field];
                             if (type !== 2) {
-                                const inner = decodeMessage(this);
+                                const inner = decodeMessage(end);
+                                if (this.position > end) {
+                                    this.seek(end);
+                                    return 2;
+                                }
+                                if (inner === undefined) {
+                                    this.seek(end);
+                                    return 2;
+                                }
                                 if (inner === 2) {
-                                    tags.set(field, inner);
+                                    tags[field] = inner;
                                 }
                                 else if (!type) {
-                                    tags.set(field, inner);
+                                    tags[field] = inner;
                                 }
                                 else {
-                                    for (const pair of inner) {
-                                        type.set(pair[0], pair[1]);
+                                    for (const pair of Object.entries(inner)) {
+                                        if (type[pair[0]] === 2 && pair[1] !== 2) {
+                                            continue;
+                                        }
+                                        type[pair[0]] = pair[1];
                                     }
                                 }
                                 continue;
                             }
                         }
-                        tags.set(field, type);
+                        tags[field] = type;
                         if (!this._skipType(type)) {
                             this.seek(end);
                             return 2;
@@ -116,30 +137,37 @@ protobuf.BinaryReader = class {
                         const field = tag >>> 3;
                         const type = tag & 7;
                         if (type > 5 || field === 0) {
-                            tags.clear();
+                            tags = {};
                             break;
                         }
                         if (type === 2) {
-                            const type = tags.get(field);
+                            const type = tags[field];
                             if (type !== 2) {
-                                const inner = decodeMessage(this);
+                                const inner = decodeMessage(length);
+                                if (inner === undefined) {
+                                    tags = {};
+                                    break;
+                                }
                                 if (inner === 2) {
-                                    tags.set(field, inner);
+                                    tags[field] = inner;
                                 }
                                 else if (!type) {
-                                    tags.set(field, inner);
+                                    tags[field] = inner;
                                 }
                                 else {
-                                    for (const pair of inner) {
-                                        type.set(pair[0], pair[1]);
+                                    for (const pair of Object.entries(inner)) {
+                                        if (type[pair[0]] === 2 && pair[1] !== 2) {
+                                            continue;
+                                        }
+                                        type[pair[0]] = pair[1];
                                     }
                                 }
                                 continue;
                             }
                         }
-                        tags.set(field, type);
+                        tags[field] = type;
                         if (!this._skipType(type)) {
-                            tags.clear();
+                            tags = {};
                             break;
                         }
                     }
@@ -147,7 +175,7 @@ protobuf.BinaryReader = class {
             }
         }
         catch (err) {
-            tags.clear();
+            tags = {};
         }
         this._position = 0;
         return tags;
@@ -368,34 +396,106 @@ protobuf.BinaryReader = class {
         while (this._buffer[this._position++] & 128);
     }
 
+    _uint32() {
+        let c;
+        if (this._position < this._length) {
+            c = this._buffer[this._position++];
+            let value = (c & 127) >>> 0;
+            if (c < 128) {
+                return value;
+            }
+            if (this._position < this._length) {
+                c = this._buffer[this._position++];
+                value = (value | (c & 127) <<  7) >>> 0;
+                if (c < 128) {
+                    return value;
+                }
+                if (this._position < this._length) {
+                    c = this._buffer[this._position++];
+                    value = (value | (c & 127) << 14) >>> 0;
+                    if (c < 128) {
+                        return value;
+                    }
+                    if (this._position < this._length) {
+                        c = this._buffer[this._position++];
+                        value = (value | (c & 127) << 21) >>> 0;
+                        if (c < 128) {
+                            return value;
+                        }
+                        if (this._position < this._length) {
+                            c = this._buffer[this._position++];
+                            value = (value | (c & 15) << 28) >>> 0;
+                            if (c < 128) {
+                                return value;
+                            }
+                            if (this.byte() !== 255 || this.byte() !== 255 || this.byte() !== 255 || this.byte() !== 255 || this.byte() !== 1) {
+                                return undefined;
+                            }
+                            return value;
+                        }
+                    }
+                }
+            }
+        }
+        return undefined;
+    }
+
     _skipType(wireType) {
         switch (wireType) {
-            case 0:
+            case 0: {
+                // const max = this._position + 9;
                 do {
-                    if (this._position >= this._length) {
+                    if (this._position >= this._length /* || this._position > max */) {
                         return false;
                     }
                 }
                 while (this._buffer[this._position++] & 128);
                 break;
-            case 1:
-                this.skip(8);
+            }
+            case 1: {
+                if (this._position + 8 >= this._length) {
+                    return false;
+                }
+                this._position += 8;
                 break;
-            case 2:
-                this.skip(this.uint32());
+            }
+            case 2: {
+                const length = this._uint32();
+                if (length === undefined) {
+                    return false;
+                }
+                if (this._position + length > this._end) {
+                    return false;
+                }
+                this._position += length;
                 break;
-            case 3:
-                while ((wireType = this.uint32() & 7) !== 4) {
+            }
+            case 3: {
+                for (;;) {
+                    const tag = this._uint32();
+                    if (tag === undefined) {
+                        return false;
+                    }
+                    const wireType = tag & 7;
+                    if (wireType === 4) {
+                        break;
+                    }
                     if (!this._skipType(wireType)) {
                         return false;
                     }
                 }
                 break;
-            case 5:
-                this.skip(4);
+            }
+            case 5: {
+                if (this._position + 4 >= this._length) {
+                    return false;
+                }
+                this._position += 4;
                 break;
-            default:
+            }
+            default: {
                 return false;
+            }
         }
         return true;
     }

+ 111 - 67
source/tf.js

@@ -12,12 +12,6 @@ tf.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (extension === 'meta') {
-            const tags = context.tags('pb');
-            if (tags.size !== 0) {
-                return 'tf.pb';
-            }
-        }
         if (extension === 'pbtxt' || extension === 'prototxt' || extension === 'pt') {
             if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
                 identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
@@ -37,7 +31,7 @@ tf.ModelFactory = class {
                 return 'tf.pbtxt.GraphDef';
             }
         }
-        if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef') {
+        if (extension === 'pb' || extension === 'pbtxt' || extension === 'prototxt' || extension === 'graphdef' || extension === 'meta') {
             if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
                 return undefined;
             }
@@ -51,30 +45,56 @@ tf.ModelFactory = class {
             const tags = context.tags('pb');
             if (tags.size > 0) {
                 if (!Array.from(tags).some((pair) => pair[0] >= 5 || pair[1] === 5)) {
-                    if (tags.size === 1 && tags.get(1) === 2) {
-                        const tags = context.tags('pb+');
-                        const match = (tags, schema) => {
-                            for (const pair of schema) {
-                                const key = pair[0];
-                                const inner = pair[1];
-                                if (!tags.has(key)) {
-                                    continue;
-                                }
-                                else if (inner === false) {
-                                    return false;
-                                }
-                                if (Array.isArray(inner)) {
-                                    const value = tags.get(key);
-                                    if (!(value instanceof Map) || !match(value, inner)) {
-                                        return false;
-                                    }
-                                }
-                                else if (inner !== tags.get(key)) {
+                    const match = (tags, schema) => {
+                        for (const pair of schema) {
+                            const key = pair[0];
+                            const inner = pair[1];
+                            if (tags[key] === undefined) {
+                                continue;
+                            }
+                            else if (inner === false) {
+                                return false;
+                            }
+                            if (Array.isArray(inner)) {
+                                const value = tags[key];
+                                if (typeof value !== 'object' || !match(value, inner)) {
                                     return false;
                                 }
                             }
-                            return true;
-                        };
+                            else if (inner !== tags[key]) {
+                                return false;
+                            }
+                        }
+                        return true;
+                    };
+                    const signatureGraphDef = [
+                        [1 /* node */, [
+                            [1 /* name */, 2],
+                            [2 /* op */, 2],
+                            [3 /* input */, 2],
+                            [4 /* device */,2],
+                            [5 /* attr */, [
+                                [1,2],
+                                [2,[]]
+                            ]],
+                            [6 /* experimental_debug_info */, []]
+                        ]],
+                        [2 /* library */, []],
+                        [3 /* version */, 0],
+                        [4 /* versions */, [[1,0],[2,0]]]
+                    ];
+                    const signatureMetaGraphDef = [
+                        [1 /* meta_info_def */, [[1,2],[2,[]],[3,[]],[4,2],[6,2],[7,0],[8,[]]]],
+                        [2 /* graph_def */, signatureGraphDef],
+                        [3 /* saver_def */, [[1,2],[2,2],[3,2],[4,0],[5,0],[6,5],[7,0]]],
+                        [4 /* collection_def */,[]],
+                        [5 /* signature_def */, []],
+                        [6 /* asset_file_def */, []],
+                        [7 /* object_graph_def */, []]
+                    ];
+                    const signatureSavedModel = [[1,0],[2,signatureMetaGraphDef]];
+                    if (tags.size === 1 && tags.get(1) === 2) {
+                        const tags = context.tags('pb+');
                         // mediapipe.BoxDetectorIndex
                         if (match(tags, [[1,[[1,[[1,[[1,5],[2,5],[3,5],[4,5],[6,0],[7,5],[8,5],[10,5],[11,0],[12,0]]],[2,5],[3,[]]]],[2,false],[3,false],[4,false],[5,false]]],[2,false],[3,false]] )) {
                             return undefined;
@@ -84,8 +104,26 @@ tf.ModelFactory = class {
                             return 'tf.pb.keras.SavedMetadata';
                         }
                     }
+                    if ((!tags.has(1) || tags.get(1) === 0) && tags.get(2) === 2) {
+                        const tags = context.tags('pb+');
+                        if (match(tags, signatureSavedModel)) {
+                            return 'tf.pb.SavedModel';
+                        }
+                    }
+                    if ((!tags.has(1) || tags.get(1) === 2) &&
+                        (!tags.has(2) || tags.get(2) === 2) &&
+                        (!tags.has(3) || tags.get(3) === 2) &&
+                        (!tags.has(4) || tags.get(4) === 2)) {
+                        const tags = context.tags('pb+');
+                        if (match(tags, signatureMetaGraphDef)) {
+                            return 'tf.pb.MetaGraphDef';
+                        }
+                    }
                     if (tags.get(1) !== 2) {
-                        return 'tf.pb';
+                        const tags = context.tags('pb+');
+                        if (match(tags, signatureGraphDef)) {
+                            return 'tf.pb.GraphDef';
+                        }
                     }
                     const decode = (buffer, value) => {
                         const reader = protobuf.BinaryReader.open(buffer);
@@ -112,7 +150,7 @@ tf.ModelFactory = class {
                             const decoder = new TextDecoder('utf-8');
                             const name = decoder.decode(nameBuffer);
                             if (Array.from(name).filter((c) => c <= ' ').length < 256) {
-                                return 'tf.pb';
+                                return 'tf.pb.GraphDef';
                             }
                         }
                     }
@@ -193,7 +231,9 @@ tf.ModelFactory = class {
                 }
                 return openModel(saved_model, format, producer, null);
             };
-            const openBundle = (context, stream, identifier) => {
+            const openBundle = (context) => {
+                const stream = context.stream;
+                const identifier = context.identifier;
                 return tf.TensorBundle.open(stream, identifier, context).then((bundle) => {
                     return openModel(null, 'TensorFlow Tensor Bundle v' + bundle.format.toString(), null, bundle);
                 }).catch((error) => {
@@ -401,54 +441,54 @@ tf.ModelFactory = class {
                     throw new tf.Error('File text format is not tensorflow.SavedModel (' + error.message + ').');
                 }
             };
-            const openBinaryProto = (stream, identifier) => {
+            const openBinaryGraphDef = (context) => {
                 let saved_model = null;
-                let format = null;
-                const extension = identifier.split('.').pop().toLowerCase();
+                const format = 'TensorFlow Graph';
                 try {
-                    if (identifier.endsWith('saved_model.pb')) {
-                        const reader = protobuf.BinaryReader.open(stream);
-                        saved_model = tf.proto.tensorflow.SavedModel.decode(reader);
-                        format = 'TensorFlow Saved Model';
-                        if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
-                            format = format + ' v' + saved_model.saved_model_schema_version.toString();
-                        }
-                    }
+                    const stream = context.stream;
+                    const reader = protobuf.BinaryReader.open(stream);
+                    const graph_def = tf.proto.tensorflow.GraphDef.decode(reader);
+                    const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
+                    meta_graph.graph_def = graph_def;
+                    saved_model = new tf.proto.tensorflow.SavedModel();
+                    saved_model.meta_graphs.push(meta_graph);
                 }
                 catch (error) {
-                    const signature = [ 0x08, 0x01, 0x12 ];
-                    if (signature.length < stream.length && stream.peek(3).every((value, index) => value === signature[index])) {
-                        const message = error && error.message ? error.message : error.toString();
-                        throw new tf.Error('File format is not tensorflow.SavedModel (' + message.replace(/\.$/, '') + ').');
-                    }
+                    const message = error && error.message ? error.message : error.toString();
+                    throw new tf.Error('File format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
                 }
+                return openSavedModel(saved_model, format, null);
+            };
+            const openBinaryMetaGraphDef = (context) => {
+                let saved_model = null;
+                const format = 'TensorFlow MetaGraph';
                 try {
-                    if (!saved_model && extension == 'meta') {
-                        const reader = protobuf.BinaryReader.open(stream);
-                        const meta_graph = tf.proto.tensorflow.MetaGraphDef.decode(reader);
-                        saved_model = new tf.proto.tensorflow.SavedModel();
-                        saved_model.meta_graphs.push(meta_graph);
-                        format = 'TensorFlow MetaGraph';
-                    }
+                    const stream = context.stream;
+                    const reader = protobuf.BinaryReader.open(stream);
+                    const meta_graph = tf.proto.tensorflow.MetaGraphDef.decode(reader);
+                    saved_model = new tf.proto.tensorflow.SavedModel();
+                    saved_model.meta_graphs.push(meta_graph);
                 }
                 catch (error) {
                     const message = error && error.message ? error.message : error.toString();
                     throw new tf.Error('File format is not tensorflow.MetaGraphDef (' + message.replace(/\.$/, '') + ').');
                 }
+                return openSavedModel(saved_model, format, null);
+            };
+            const openBinarySavedModel = (context) => {
+                let saved_model = null;
+                let format = 'TensorFlow Saved Model';
                 try {
-                    if (!saved_model) {
-                        const reader = protobuf.BinaryReader.open(stream);
-                        const graph_def = tf.proto.tensorflow.GraphDef.decode(reader);
-                        const meta_graph = new tf.proto.tensorflow.MetaGraphDef();
-                        meta_graph.graph_def = graph_def;
-                        saved_model = new tf.proto.tensorflow.SavedModel();
-                        saved_model.meta_graphs.push(meta_graph);
-                        format = 'TensorFlow Graph';
+                    const stream = context.stream;
+                    const reader = protobuf.BinaryReader.open(stream);
+                    saved_model = tf.proto.tensorflow.SavedModel.decode(reader);
+                    if (saved_model && Object.prototype.hasOwnProperty.call(saved_model, 'saved_model_schema_version')) {
+                        format = format + ' v' + saved_model.saved_model_schema_version.toString();
                     }
                 }
                 catch (error) {
                     const message = error && error.message ? error.message : error.toString();
-                    throw new tf.Error('File format is not tensorflow.GraphDef (' + message.replace(/\.$/, '') + ').');
+                    throw new tf.Error('File format is not tensorflow.SavedModel (' + message.replace(/\.$/, '') + ').');
                 }
                 return openSavedModel(saved_model, format, null);
             };
@@ -461,12 +501,12 @@ tf.ModelFactory = class {
                 */
                 const identifier = 'saved_model.pb';
                 return context.request(identifier, null).then((stream) => {
-                    return openBinaryProto(stream, identifier);
+                    return openBinarySavedModel({ stream: stream });
                 });
             };
             switch (match) {
                 case 'tf.bundle':
-                    return openBundle(context, context.stream, context.identifier);
+                    return openBundle(context);
                 case 'tf.data':
                     return openData(context);
                 case 'tf.events':
@@ -479,8 +519,12 @@ tf.ModelFactory = class {
                     return openTextMetaGraphDef(context);
                 case 'tf.pbtxt.SavedModel':
                     return openTextSavedModel(context);
-                case 'tf.pb':
-                    return openBinaryProto(context.stream, context.identifier);
+                case 'tf.pb.GraphDef':
+                    return openBinaryGraphDef(context);
+                case 'tf.pb.MetaGraphDef':
+                    return openBinaryMetaGraphDef(context);
+                case 'tf.pb.SavedModel':
+                    return openBinarySavedModel(context);
                 case 'tf.pb.keras.SavedMetadata':
                     return openSavedMetadata(context);
                 default:

+ 5 - 5
source/view.js

@@ -1666,7 +1666,7 @@ view.ModelFactoryService = class {
         };
         const pb = () => {
             const tags = context.tags('pb+');
-            if (tags.size > 0) {
+            if (Object.keys(tags).length > 0) {
                 const formats = [
                     { name: 'mediapipe.BoxDetectorIndex data', tags: [[1,[[1,[[1,[[1,5],[2,5],[3,5],[4,5],[6,0],[7,5],[8,5],[10,5],[11,0],[12,0]]],[2,5],[3,[]]]],[2,false],[3,false],[4,false],[5,false]]],[2,false],[3,false]] },
                     { name: 'sentencepiece.ModelProto data', tags: [[1,[[1,2],[2,5],[3,0]]],[2,[[1,2],[2,2],[3,0],[4,0],[5,2],[6,0],[7,2],[10,5],[16,0],[40,0],[41,0],[42,0],[43,0]]],[3,[]],[4,[]],[5,[]]] },
@@ -1676,19 +1676,19 @@ view.ModelFactoryService = class {
                     for (const pair of schema) {
                         const key = pair[0];
                         const inner = pair[1];
-                        if (!tags.has(key)) {
+                        if (tags[key] === undefined) {
                             continue;
                         }
                         else if (inner === false) {
                             return false;
                         }
                         if (Array.isArray(inner)) {
-                            const value = tags.get(key);
-                            if (!(value instanceof Map) || !match(value, inner)) {
+                            const value = tags[key];
+                            if (typeof value !== 'object' || !match(value, inner)) {
                                 return false;
                             }
                         }
-                        else if (inner !== tags.get(key)) {
+                        else if (inner !== tags[key]) {
                             return false;
                         }
                     }

+ 7 - 0
test/models.json

@@ -5447,6 +5447,13 @@
     "action": "skip-render",
     "link":   "https://github.com/lutzroeder/netron/issues/235"
   },
+  {
+    "type":   "tf",
+    "target": "saved_model_xxx.pb",
+    "source": "https://github.com/lutzroeder/netron/files/6965264/saved_model_xxx.pb.zip[saved_model_xxx.pb]",
+    "format": "TensorFlow Saved Model v1",
+    "link":   "https://github.com/lutzroeder/netron/issues/782"
+  },
   {
     "type":   "tf",
     "target": "speech_commands_v0.pb",