Browse Source

Detect invalid content (#458)

Lutz Roeder 6 năm trước cách đây
mục cha
commit
1a3b467030
3 tập tin đã thay đổi với 107 bổ sung71 xóa
  1. 4 8
      src/tflite.js
  2. 78 59
      src/view.js
  3. 25 4
      test/models.json

+ 4 - 8
src/tflite.js

@@ -10,13 +10,10 @@ tflite.ModelFactory = class {
 
     match(context) {
         const extension = context.identifier.split('.').pop().toLowerCase();
-        if (extension == 'tflite' || extension == 'lite') {
-            return true;
-        }
-        if (extension == 'tfl' || extension == 'bin') {
+        if (extension === 'tflite' || extension === 'lite' || extension === 'tfl' || extension === 'bin') {
             const buffer = context.buffer;
-            const signature = [ 0x54, 0x46, 0x4c, 0x33 ]; // TFL3
-            if (buffer && buffer.length > 8 && signature.every((x, i) => x == buffer[i + 4])) {
+            const signature = 'TFL3'
+            if (buffer && buffer.length > 8 && buffer.subarray(4, 8).every((x, i) => x === signature.charCodeAt(i))) {
                 return true;
             }
         }
@@ -32,8 +29,7 @@ tflite.ModelFactory = class {
                 const byteBuffer = new flatbuffers.ByteBuffer(buffer);
                 tflite.schema = tflite_schema;
                 if (!tflite.schema.Model.bufferHasIdentifier(byteBuffer)) {
-                    const signature = Array.from(buffer.subarray(0, Math.min(8, buffer.length))).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
-                    throw new tflite.Error("File format is not tflite.Model (" + signature + ").");
+                    throw new tflite.Error("File format is not tflite.Model.");
                 }
                 model = tflite.schema.Model.getRootAsModel(byteBuffer);
             }

+ 78 - 59
src/view.js

@@ -1160,70 +1160,72 @@ view.ModelFactoryService = class {
     }
  
     open(context) {
-        return this._openArchive(context).then((context) => {
-            context = new ModelContext(context);
-            const identifier = context.identifier;
-            const extension = identifier.split('.').pop().toLowerCase();
-            let modules = this._filter(context);
-            if (modules.length == 0) {
-                throw new ModelError("Unsupported file extension '." + extension + "'.");
-            }
-            let errors = [];
-            let match = false;
-            let nextModule = () => {
-                if (modules.length > 0) {
-                    let id = modules.shift();
-                    return this._host.require(id).then((module) => {
-                        if (!module.ModelFactory) {
-                            throw new ModelError("Failed to load module '" + id + "'.");
-                        }
-                        const modelFactory = new module.ModelFactory(); 
-                        if (!modelFactory.match(context)) {
-                            return nextModule();
-                        }
-                        match++;
-                        return modelFactory.open(context, this._host).then((model) => {
-                            return model;
-                        }).catch((error) => {
-                            errors.push(error);
-                            return nextModule();
-                        });
-                    });
+        return this._openSignature(context).then((context) => {
+            return this._openArchive(context).then((context) => {
+                context = new ModelContext(context);
+                const identifier = context.identifier;
+                const extension = identifier.split('.').pop().toLowerCase();
+                let modules = this._filter(context);
+                if (modules.length == 0) {
+                    throw new ModelError("Unsupported file extension '." + extension + "'.");
                 }
-                else {
-                    if (match) {
-                        if (errors.length == 1) {
-                            throw errors[0];
-                        }
-                        throw new ModelError(errors.map((err) => err.message).join('\n'));
+                let errors = [];
+                let match = false;
+                let nextModule = () => {
+                    if (modules.length > 0) {
+                        let id = modules.shift();
+                        return this._host.require(id).then((module) => {
+                            if (!module.ModelFactory) {
+                                throw new ModelError("Failed to load module '" + id + "'.");
+                            }
+                            const modelFactory = new module.ModelFactory(); 
+                            if (!modelFactory.match(context)) {
+                                return nextModule();
+                            }
+                            match++;
+                            return modelFactory.open(context, this._host).then((model) => {
+                                return model;
+                            }).catch((error) => {
+                                errors.push(error);
+                                return nextModule();
+                            });
+                        });
                     }
-                    const knownUnsupportedIdentifiers = new Set([
-                        'natives_blob.bin', 
-                        'v8_context_snapshot.bin',
-                        'snapshot_blob.bin',
-                        'image_net_labels.json',
-                        'package.json',
-                        'models.json',
-                        'LICENSE.meta',
-                        'input_0.pb', 
-                        'output_0.pb',
-                        'object-detection.pbtxt',
-                    ]);
-                    let skip = knownUnsupportedIdentifiers.has(identifier);
-                    if (!skip && (extension === 'pbtxt' || extension === 'prototxt')) {
-                        if (identifier.includes('label_map') || identifier.includes('labels_map') || identifier.includes('labelmap')) {
-                            const tags = context.tags('pbtxt');
-                            if (tags.size === 1 && (tags.has('item') || tags.has('entry'))) {
-                                skip = true;
+                    else {
+                        if (match) {
+                            if (errors.length == 1) {
+                                throw errors[0];
+                            }
+                            throw new ModelError(errors.map((err) => err.message).join('\n'));
+                        }
+                        const knownUnsupportedIdentifiers = new Set([
+                            'natives_blob.bin', 
+                            'v8_context_snapshot.bin',
+                            'snapshot_blob.bin',
+                            'image_net_labels.json',
+                            'package.json',
+                            'models.json',
+                            'LICENSE.meta',
+                            'input_0.pb', 
+                            'output_0.pb',
+                            'object-detection.pbtxt',
+                        ]);
+                        let skip = knownUnsupportedIdentifiers.has(identifier);
+                        if (!skip && (extension === 'pbtxt' || extension === 'prototxt')) {
+                            if (identifier.includes('label_map') || identifier.includes('labels_map') || identifier.includes('labelmap')) {
+                                const tags = context.tags('pbtxt');
+                                if (tags.size === 1 && (tags.has('item') || tags.has('entry'))) {
+                                    skip = true;
+                                }
                             }
                         }
+                        const buffer = context.buffer;
+                        const content = Array.from(buffer.subarray(0, Math.min(16, buffer.length))).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
+                        throw new ModelError("Unsupported file content (" + content + ") for extension '." + extension + "' in '" + identifier + "'.", !skip);
                     }
-                    const buffer = context.buffer;
-                    const content = Array.from(buffer.subarray(0, Math.min(8, buffer.length))).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
-                    throw new ModelError("Unsupported file content (" + content + ") for extension '." + extension + "' in '" + identifier + "'.", !skip);
-                }
-            };
-            return nextModule();
+                };
+                return nextModule();
+            });
         });
     }
 
@@ -1390,6 +1392,23 @@ view.ModelFactoryService = class {
         }
         return moduleList;
     }
+
+    _openSignature(context) {
+        const buffer = context.buffer;
+        const identifier = context.identifier;
+        const list = [
+            { name: 'Git LFS', value: 'version https://git-lfs.github.com/spec/v1\n' },
+            { name: 'HTML', value: '<html>' },
+            { name: 'HTML', value: '\n\n\n\n\n\n<!DOCTYPE html>' }
+        ];
+        for (const item of list) {
+            if (buffer.length >= item.value.length &&
+                buffer.subarray(0, item.value.length).every((v, i) => v === item.value.charCodeAt(i))) {
+                return Promise.reject(new ModelError("Invalid " + item.name + " content in '" + identifier + "'.", true));
+            }
+        }
+        return Promise.resolve(context);
+    }
 };
 
 if (typeof module !== 'undefined' && typeof module.exports === 'object') {

+ 25 - 4
test/models.json

@@ -864,7 +864,7 @@
     "type":   "caffe2",
     "target": "ops/ops.pbtxt",
     "source": "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/core/ops/ops.pbtxt",
-    "error":  "Unsupported file content (6f70207b0a20206e) for extension '.pbtxt' in 'ops.pbtxt'.",
+    "error":  "Unsupported file content (6f70207b0a20206e616d653a20224162) for extension '.pbtxt' in 'ops.pbtxt'.",
     "link":   "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/ops.pbtxt"
   },
   {
@@ -1076,7 +1076,7 @@
     "type":   "cntk",
     "target": "v2/mobilenetv2-1.0.dnn",
     "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-mobilenet/mobilenetv2-1.0.model",
-    "error":  "Unsupported file content (504b030414000000) for extension '.dnn' in 'mobilenetv2-1.0.dnn'.",
+    "error":  "Unsupported file content (504b030414000000000092a2c64c1c0a) for extension '.dnn' in 'mobilenetv2-1.0.dnn'.",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
   {
@@ -1508,6 +1508,20 @@
     "target": "Inceptionv3.mlmodel",
     "source": "https://docs-assets.developer.apple.com/coreml/models/Inceptionv3.mlmodel"
   },
+  {
+    "type":   "coreml",
+    "target": "invalid_git_lfs.mlmodel",
+    "source": "https://github.com/lutzroeder/netron/files/4432767/invalid_git_lfs.mlmodel.zip[invalid_git_lfs.mlmodel]",
+    "error":  "Invalid Git LFS content in 'invalid_git_lfs.mlmodel'.",
+    "link":   "https://github.com/lutzroeder/netron/issues/458"
+  },
+  {
+    "type":   "coreml",
+    "target": "invalid_html.mlmodel.zip",
+    "source": "https://github.com/lutzroeder/netron/files/4432768/invalid_html.mlmodel.zip",
+    "error":  "Invalid HTML content in 'invalid_html.mlmodel.zip'.",
+    "link":   "https://github.com/lutzroeder/netron/issues/458"
+  },
   {
     "type":   "coreml",
     "target": "iris.mlmodel",
@@ -2118,7 +2132,7 @@
     "type":   "keras",
     "target": "keras_invalid_file.h5",
     "source": "https://github.com/lutzroeder/netron/files/3364286/keras_invalid_file.zip[keras_invalid_file.h5]",
-    "error":  "Unsupported file content (0000000000000000) for extension '.h5' in 'keras_invalid_file.h5'.",
+    "error":  "Unsupported file content (00000000000000000000000000000000) for extension '.h5' in 'keras_invalid_file.h5'.",
     "link":   "https://github.com/lutzroeder/netron/issues/57"
   },
   {
@@ -3095,7 +3109,7 @@
     "type":   "onnx",
     "target": "input_0.pb",
     "source": "https://s3.amazonaws.com/download.onnx/models/opset_9/shufflenet.tar.gz[shufflenet/test_data_set_0/input_0.pb]",
-    "error":  "Unsupported file content (0801080308e00108) for extension '.pb' in 'input_0.pb'.",
+    "error":  "Unsupported file content (0801080308e00108e00110014a80e024) for extension '.pb' in 'input_0.pb'.",
     "link":   "https://github.com/onnx/models/tree/master/bvlc_alexnet"
   },
   {
@@ -5337,6 +5351,13 @@
     "source": "https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip[inceptionv3_slim_2016.tflite]",
     "format": "TensorFlow Lite v3"
   },
+  {
+    "type":   "tflite",
+    "target": "invalid_html.tflite",
+    "source": "https://github.com/lutzroeder/netron/files/4432789/invalid_html.tflite.zip[invalid_html.tflite]",
+    "error":  "Invalid HTML content in 'invalid_html.tflite'.",
+    "link":   "https://github.com/lutzroeder/netron/issues/458"
+  },
   {
     "type":   "tflite",
     "target": "mobilenet_v1_1.0_224.tflite",