Explorar o código

Add Caffe test file (#276)

Lutz Roeder %!s(int64=5) %!d(string=hai) anos
pai
achega
be0973de89
Modificáronse 3 ficheiros con 45 adicións e 44 borrados
  1. 37 43
      source/caffe.js
  2. 1 1
      source/view.js
  3. 7 0
      test/models.json

+ 37 - 43
source/caffe.js

@@ -11,16 +11,10 @@ caffe.ModelFactory = class {
         if (extension == 'caffemodel') {
             return true;
         }
-        if (extension == 'pbtxt' || extension == 'prototxt') {
-            if (identifier == 'saved_model.pbtxt' || identifier == 'saved_model.prototxt' ||
-                identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
-                identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
-                return false;
-            }
-            const tags = context.tags('pbtxt');
-            if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
-                return true;
-            }
+        if (identifier == 'saved_model.pbtxt' || identifier == 'saved_model.prototxt' ||
+            identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt') ||
+            identifier.endsWith('init_net.pbtxt') || identifier.endsWith('init_net.prototxt')) {
+            return false;
         }
         if (extension == 'pt') {
             // Reject PyTorch models
@@ -33,10 +27,10 @@ caffe.ModelFactory = class {
             if (buffer && buffer.length > 2 && buffer[0] == 0x50 && buffer[1] == 0x4B) {
                 return false;
             }
-            const tags = context.tags('pbtxt');
-            if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
-                return true;
-            }
+        }
+        const tags = context.tags('pbtxt');
+        if (tags.has('layer') || tags.has('layers') || tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
+            return true;
         }
         return false;
     }
@@ -46,39 +40,39 @@ caffe.ModelFactory = class {
             caffe.proto = protobuf.get('caffe').caffe;
             return caffe.Metadata.open(host).then((metadata) => {
                 const extension = context.identifier.split('.').pop();
-                if (extension == 'pbtxt' || extension == 'prototxt' || extension == 'pt') {
-                    const tags = context.tags('pbtxt');
-                    if (tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
-                        try {
-                            const reader = protobuf.TextReader.create(context.buffer);
-                            reader.field = function(tag, message) {
-                                if (message instanceof caffe.proto.SolverParameter) {
-                                    message[tag] = this.read();
-                                    return;
-                                }
-                                throw new Error("Unknown field '" + tag + "'" + this.location());
-                            };
-                            const solver = caffe.proto.SolverParameter.decodeText(reader);
-                            if (solver.net_param) {
-                                return new caffe.Model(metadata, solver.net_param);
-                            }
-                            else if (solver.net || solver.train_net) {
-                                let file = solver.net || solver.train_net;
-                                file = file.split('/').pop();
-                                return context.request(file, null).then((buffer) => {
-                                    return this._openNetParameterText(metadata, context.identifier, buffer, host);
-                                }).catch((error) => {
-                                    if (error) {
-                                        const message = error && error.message ? error.message : error.toString();
-                                        throw new caffe.Error("Failed to load '" + file + "' (" + message.replace(/\.$/, '') + ').');
-                                    }
-                                });
+                const tags = context.tags('pbtxt');
+                if (tags.has('net') || tags.has('train_net') || tags.has('net_param')) {
+                    try {
+                        const reader = protobuf.TextReader.create(context.buffer);
+                        reader.field = function(tag, message) {
+                            if (message instanceof caffe.proto.SolverParameter) {
+                                message[tag] = this.read();
+                                return;
                             }
+                            throw new Error("Unknown field '" + tag + "'" + this.location());
+                        };
+                        const solver = caffe.proto.SolverParameter.decodeText(reader);
+                        if (solver.net_param) {
+                            return new caffe.Model(metadata, solver.net_param);
                         }
-                        catch (error) {
-                            // continue regardless of error
+                        else if (solver.net || solver.train_net) {
+                            let file = solver.net || solver.train_net;
+                            file = file.split('/').pop();
+                            return context.request(file, null).then((buffer) => {
+                                return this._openNetParameterText(metadata, context.identifier, buffer, host);
+                            }).catch((error) => {
+                                if (error) {
+                                    const message = error && error.message ? error.message : error.toString();
+                                    throw new caffe.Error("Failed to load '" + file + "' (" + message.replace(/\.$/, '') + ').');
+                                }
+                            });
                         }
                     }
+                    catch (error) {
+                        // continue regardless of error
+                    }
+                }
+                else if (tags.has('layer') || tags.has('layers')) {
                     return this._openNetParameterText(metadata, context.identifier, context.buffer, host);
                 }
                 else {

+ 1 - 1
source/view.js

@@ -1185,7 +1185,7 @@ view.ModelFactoryService = class {
         this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
         this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.bin', '.pb', '.zip' ]);
         this.register('./coreml', [ '.mlmodel' ]);
-        this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt' ]);
+        this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);
         this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);
         this.register('./torch', [ '.t7' ]);
         this.register('./tflite', [ '.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json' ]);

+ 7 - 0
test/models.json

@@ -740,6 +740,13 @@
     "format": "Caffe v2",
     "link":   "https://github.com/cwlacewe/netscope"
   },
+  {
+    "type":   "caffe",
+    "target": "yolov3_gen.prototext.txt",
+    "source": "https://github.com/lutzroeder/netron/files/5615031/yolov3_gen.prototext.txt.zip[yolov3_gen.prototext.txt]",
+    "format": "Caffe v2",
+    "link":   "https://github.com/lutzroeder/netron/issues/276"
+  },
   {
     "type":   "caffe",
     "target": "yolov3-tiny.prototxt",