2
0
Эх сурвалжийг харах

Core ML Model Package .bin detection (#751)

Lutz Roeder 4 жил өмнө
parent
commit
88938fadd4
3 өөрчлөгдсөн 37 нэмэгдсэн , 8 устгасан
  1. 26 6
      source/coreml.js
  2. 10 1
      source/openvino.js
  3. 1 1
      source/view.js

+ 26 - 6
source/coreml.js

@@ -11,7 +11,9 @@ coreml.ModelFactory = class {
         if (tags.get(1) === 0 && tags.get(2) === 2) {
             return true;
         }
+        const stream = context.stream;
         const identifier = context.identifier.toLowerCase();
+        const extension = identifier.split('.').pop().toLowerCase();
         switch (identifier) {
             case 'manifest.json': {
                 const obj = context.open('json');
@@ -38,6 +40,15 @@ coreml.ModelFactory = class {
                 break;
             }
         }
+        if (extension === 'bin' && stream.length > 16) {
+            const buffer = stream.peek(Math.min(256, stream.length));
+            for (let i = 0; i < buffer.length - 4; i++) {
+                const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer [i + 3] << 24) >>> 0;
+                if (signature === 0xdeadbeef) {
+                    return true;
+                }
+            }
+        }
         return false;
     }
 
@@ -126,6 +137,14 @@ coreml.ModelFactory = class {
                         return openModel(stream, context, file, 'Core ML Package');
                     });
                 };
+                const openManifestStream = (context, path) => {
+                    return context.request(path + 'Manifest.json', null).then((stream) => {
+                        const buffer = stream.peek();
+                        const reader = json.TextReader.create(buffer);
+                        const obj = reader.read();
+                        return openManifest(obj, context, path);
+                    });
+                };
                 const tags = context.tags('pb');
                 if (tags.get(1) === 0 && tags.get(2) === 2) {
                     return openModel(context.stream, context, context.identifier);
@@ -138,12 +157,13 @@ coreml.ModelFactory = class {
                     }
                     case 'featuredescriptions.json':
                     case 'metadata.json': {
-                        return context.request('../../Manifest.json', null).then((stream) => {
-                            const buffer = stream.peek();
-                            const reader = json.TextReader.create(buffer);
-                            const obj = reader.read();
-                            return openManifest(obj, context, '../../');
-                        });
+                        return openManifestStream(context, '../../');
+                    }
+                    default: {
+                        const extension = identifier.split('.').pop().toLowerCase();
+                        if (extension === 'bin') {
+                            return openManifestStream(context, '../../../');
+                        }
                     }
                 }
             });

+ 10 - 1
source/openvino.js

@@ -39,12 +39,21 @@ openvino.ModelFactory = class {
             }
             if (stream.length > 4) {
                 const buffer = stream.peek(4);
-                const signature = buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24;
+                const signature = (buffer[0] | buffer[1] << 8 | buffer[2] << 16 | buffer [3] << 24) >>> 0;
                 if (signature === 0x00000000 || signature === 0x00000001 ||
                     signature === 0x01306B47 || signature === 0x000D4B38 || signature === 0x0002C056) {
                     return false;
                 }
             }
+            if (stream.length > 4) {
+                const buffer = stream.peek(Math.min(256, stream.length));
+                for (let i = 0; i < buffer.length - 4; i++) {
+                    const signature = (buffer[i] | buffer[i + 1] << 8 | buffer[i + 2] << 16 | buffer [i + 3] << 24) >>> 0;
+                    if (signature === 0xdeadbeef) {
+                        return false;
+                    }
+                }
+            }
             return true;
         }
         return false;

+ 1 - 1
source/view.js

@@ -1436,7 +1436,7 @@ view.ModelFactoryService = class {
         this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel' ]);
         this.register('./onnx', [ '.onnx', '.onn', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl' ]);
         this.register('./mxnet', [ '.json', '.params' ]);
-        this.register('./coreml', [ '.mlmodel', 'manifest.json', 'metadata.json', 'featuredescriptions.json' ]);
+        this.register('./coreml', [ '.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json' ]);
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);
         this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);
         this.register('./torch', [ '.t7' ]);