Lutz Roeder преди 5 години
родител
ревизия
d009e1529b
променени са 3 файла, в които са добавени 12 реда и са изтрити 23 реда
  1. 0 12
      source/caffe.js
  2. 0 5
      source/cntk.js
  3. 12 6
      source/view.js

+ 0 - 12
source/caffe.js

@@ -16,18 +16,6 @@ caffe.ModelFactory = class {
             identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
             return false;
         }
-        if (extension == 'pt') {
-            const stream = context.stream;
-            const signatures = [
-                // Reject PyTorch models
-                [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ],
-                // Reject TorchScript models
-                [ 0x50, 0x4b ]
-            ];
-            if (signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) {
-                return false;
-            }
-        }
         const tags = context.tags('pbtxt');
         if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
             return true;

+ 0 - 5
source/cntk.js

@@ -10,11 +10,6 @@ cntk.ModelFactory = class {
 
     match(context) {
         const stream = context.stream;
-        // Reject PyTorch models with .model file extension.
-        const torch = [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
-        if (torch.length <= stream.length && stream.peek(torch.length).every((value, index) => torch[index] === undefined || torch[index] === value)) {
-            return false;
-        }
         // CNTK v1
         const signature = [ 0x42, 0x00, 0x43, 0x00, 0x4e, 0x00, 0x00, 0x00 ];
         if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {

+ 12 - 6
source/view.js

@@ -1235,13 +1235,19 @@ view.ModelContext = class {
         if (!tags) {
             tags = new Map();
             let reset = false;
-            const signature = [ 0x50, 0x4B, 0x03, 0x04 ];
-            if (this.stream.length < 4 || !this.stream.peek(4).every((value, index) => value === signature[index])) {
+            const signatures = [
+                // Reject PyTorch models
+                [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ],
+                // Reject TorchScript models
+                [ 0x50, 0x4b ]
+            ];
+            const stream = this.stream;
+            if (!signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) {
                 try {
                     switch (type) {
                         case 'pbtxt': {
                             reset = true;
-                            const decoder = base.TextDecoder.create(this.stream.peek());
+                            const decoder = base.TextDecoder.create(stream.peek());
                             let count = 0;
                             for (let i = 0; i < 0x100; i++) {
                                 const c = decoder.decode();
@@ -1252,7 +1258,7 @@ view.ModelContext = class {
                                 }
                             }
                             if (count < 4) {
-                                const reader = protobuf.TextReader.create(this.stream.peek());
+                                const reader = protobuf.TextReader.create(stream.peek());
                                 reader.start(false);
                                 while (!reader.end(false)) {
                                     const tag = reader.tag();
@@ -1274,7 +1280,7 @@ view.ModelContext = class {
                         }
                         case 'pb': {
                             reset = true;
-                            const reader = protobuf.Reader.create(this.stream.peek());
+                            const reader = protobuf.Reader.create(stream.peek());
                             const length = reader.length;
                             while (reader.position < length) {
                                 const tag = reader.uint32();
@@ -1371,7 +1377,7 @@ view.ModelFactoryService = class {
         this._host = host;
         this._extensions = [];
         this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn' ]);
-        this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model' ]);
+        this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl' ]);
         this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
         this.register('./coreml', [ '.mlmodel' ]);
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);