Procházet zdrojové kódy

Add Caffe2 test files (#437)

Lutz Roeder před 6 roky
rodič
revize
9992f8b771
2 změnil soubory, kde provedl 98 přidání a 64 odebrání
  1. 84 64
      src/caffe2.js
  2. 14 0
      test/models.json

+ 84 - 64
src/caffe2.js

@@ -11,7 +11,8 @@ caffe2.ModelFactory = class {
         const identifier = context.identifier.toLowerCase();
         const extension = identifier.split('.').pop().toLowerCase();
         if (extension == 'pb') {
-            if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb')) {
+            if (identifier.endsWith('predict_net.pb') || identifier.endsWith('init_net.pb') ||
+                identifier.startsWith('predict_net') || identifier.startsWith('init_net')) {
                 return true;
             }
             const tags = context.tags('pb');
@@ -44,7 +45,7 @@ caffe2.ModelFactory = class {
             }
         }
         if (extension == 'pbtxt' || extension == 'prototxt') {
-            if (identifier.endsWith('predict_net.pbtxt') || identifier.endsWith('predict_net.prototxt')) {
+            if (identifier.endsWith('predict_net')) {
                 return true;
             }
             const tags = context.tags('pbtxt');
@@ -61,8 +62,10 @@ caffe2.ModelFactory = class {
     open(context, host) {
         return host.require('./caffe2-proto').then(() => {
             return caffe2.Metadata.open(host).then((metadata) => {
-                const identifier = context.identifier; 
-                const extension = identifier.split('.').pop().toLowerCase();
+                const identifier = context.identifier;
+                const parts = identifier.split('.');
+                const extension = parts.pop().toLowerCase();
+                const base = parts.join('.');
                 if (extension == 'pbtxt' || extension == 'prototxt') {
                     const open_text = (predict, init) => {
                         let predict_net = null;
@@ -83,8 +86,10 @@ caffe2.ModelFactory = class {
                             throw new caffe2.Error("File text format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
                         }
                         try {
-                            caffe2.proto = protobuf.roots.caffe2.caffe2;
-                            init_net = caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init));
+                            if (init) {
+                                caffe2.proto = protobuf.roots.caffe2.caffe2;
+                                init_net = caffe2.proto.NetDef.decodeText(prototxt.TextReader.create(init));
+                            }
                         }
                         catch (error) {
                             // continue regardless of error
@@ -99,20 +104,23 @@ caffe2.ModelFactory = class {
                             throw new caffe2.Error(message + " in '" + identifier + "'.");
                         }
                     };
-                    if (identifier.toLowerCase().startsWith('init_net.')) {
-                        return context.request('predict_net.' + extension, 'utf-8').then((text) => {
+                    if (base.toLowerCase().endsWith('init_net') || base.toLowerCase().startsWith('init_net')) {
+                        return context.request(identifier.replace('init_net', 'predict_net'), 'utf-8').then((text) => {
                             return open_text(text, context.text);
                         }).catch(() => {
                             return open_text(context.text, null);
                         });
                     }
-                    else {
-                        return context.request('init_net.' + extension, 'utf-8').then((text) => {
+                    else if (base.toLowerCase().endsWith('predict_net') || base.toLowerCase().startsWith('predict_net')) {
+                        return context.request(identifier.replace('predict_net', 'init_net'), 'utf-8').then((text) => {
                             return open_text(context.text, text);
                         }).catch(() => {
                             return open_text(context.text, null);
                         });
                     }
+                    else {
+                        return open_text(context.text, null);
+                    }
                 }
                 else {
                     const open_binary = (predict, init) => {
@@ -126,8 +134,10 @@ caffe2.ModelFactory = class {
                             throw new caffe2.Error("File format is not caffe2.NetDef (" + error.message + ") in '" + identifier + "'.");
                         }
                         try {
-                            caffe2.proto = protobuf.roots.caffe2.caffe2;
-                            init_net = caffe2.proto.NetDef.decode(init);
+                            if (init) {
+                                caffe2.proto = protobuf.roots.caffe2.caffe2;
+                                init_net = caffe2.proto.NetDef.decode(init);
+                            }
                         }
                         catch (error) {
                             // continue regardless of error
@@ -142,15 +152,15 @@ caffe2.ModelFactory = class {
                             throw new caffe2.Error(message + " in '" + identifier + "'.");
                         }
                     };
-                    if (identifier.toLowerCase().startsWith('init_net.')) {
-                        return context.request('predict_net.' + extension, null).then((buffer) => {
+                    if (base.toLowerCase().endsWith('init_net')) {
+                        return context.request(base.substring(0, base.length - 8) + 'predict_net.' + extension, null).then((buffer) => {
                             return open_binary(buffer, context.buffer);
                         }).catch(() => {
                             return open_binary(context.buffer, null);
                         });
                     }
                     else {
-                        return context.request('init_net.' + extension, null).then((buffer) => {
+                        return context.request(base.substring(0, base.length - 11) + 'init_net.' + extension, null).then((buffer) => {
                             return open_binary(context.buffer, buffer);
                         }).catch(() => {
                             return open_binary(context.buffer, null);
@@ -190,47 +200,59 @@ caffe2.Graph = class {
         this._type = netDef.type || '';
         this._nodes = [];
 
-        let initializers = {};
-        for (let external_input of netDef.external_input) {
-            initializers[external_input] = {};
+        let inputs = new Map();
+        for (const input of netDef.external_input) {
+            inputs.set(input, {});
         }
         if (init) {
-            for (let op of init.op) {
+            for (const op of init.op) {
                 if (op.output && op.output.length == 1) {
                     const name = op.output[0];
-                    let dataType = null;
+                    if (!inputs.has(name)) {
+                        inputs.set(name, {});
+                    }
+                    let initializer = inputs.get(name);
+                    for (const arg of op.arg) {
+                        initializer[arg.name] = arg;
+                    }
                     switch (op.type) {
                         case 'GivenTensorFill':
-                            dataType = 'float32';
+                            initializer.dataType = 'float32';
+                            break;
+                        case 'GivenTensorDoubleFill':
+                            initializer.dataType = 'float64';
                             break;
                         case 'GivenTensorBoolFill':
-                            dataType = 'boolean';
+                            initializer.dataType = 'boolean';
                             break;
                         case 'GivenTensorByteStringToUInt8Fill':
-                            dataType = 'uint8';
+                            initializer.dataType = 'uint8';
                             break;
                         case 'GivenTensorIntFill':
-                            dataType = 'int32';
+                            initializer.dataType = 'int32';
                             break;
                         case 'GivenTensorInt64Fill':
-                            dataType = 'int64';
+                            initializer.dataType = 'int64';
                             break;
                         case 'GivenTensorStringFill':
-                            dataType = 'string';
+                            initializer.dataType = 'string';
                             break;
                         case 'Int8GivenIntTensorFill':
-                            dataType = 'int32';
+                            initializer.dataType = 'int32';
                             break;
                         case 'Int8GivenTensorFill':
-                            dataType = 'int8';
+                            initializer.dataType = 'int8';
                             break;
-                        default:
+                        case 'XavierFill':
                             break;
+                        case 'ConstantFill':
+                            break;
+                        default:
+                            throw new caffe2.Error("Unknown init op '" + op.type + "'.");
+                    }
+                    if (initializer.values && (initializer.values.floats.length !== 1 || initializer.values.floats[0] !== 0)) {
+                        initializer.input = false;
                     }
-                    if (dataType) {
-                        op.dataType = dataType;
-                        initializers[name] = op;
-                    }    
                 }
             }
         }
@@ -241,7 +263,7 @@ caffe2.Graph = class {
             op.input = op.input.map((input) => scope[input] ? scope[input] : input);
             op.output = op.output.map((output) => {
                 if (scope[output]) {
-                    let next = output + '\n' + index.toString(); // custom argument id
+                    const next = output + '\n' + index.toString(); // custom argument id
                     scope[output] = next;
                     return next;
                 }
@@ -254,7 +276,7 @@ caffe2.Graph = class {
         let lastNode = null;
         let lastOutput = null;
         for (let op of netDef.op) {
-            let node = new caffe2.Node(metadata, op, initializers);
+            let node = new caffe2.Node(metadata, op, inputs);
             if (op.input.length == 1 &&
                 op.output.length >= 1 && 
                 op.input[0].split('\n').shift() == op.output[0].split('\n').shift() && 
@@ -274,11 +296,14 @@ caffe2.Graph = class {
         }
 
         this._inputs = [];
-        let inputs = Object.keys(initializers);
-        for (let input of inputs) {
-            if (inputs.length == 1 || !input.startsWith('caffe.')) {
-                this._inputs.push(new caffe2.Parameter(input, [ new caffe2.Argument(input, null, null) ]));
+        for (let input of netDef.external_input) {
+            if (netDef.external_input.length > 1) {
+                const initializer = inputs.get(input);
+                if (initializer && initializer.input === false) {
+                    continue;
+                }
             }
+            this._inputs.push(new caffe2.Parameter(input, [ new caffe2.Argument(input, null, null) ]));
         }
 
         this._outputs = [];
@@ -379,16 +404,25 @@ caffe2.Node = class {
 
         const schema = metadata.type(this._operator);
 
-        let inputs = op.input;
+        const inputs = op.input;
+        const outputs = op.output;
+
         let tensors = {};
         let index = 0;
-        for (let input of inputs) {
-            if (index > 0 && initializers[input]) {
-                tensors[input] = new caffe2.Tensor(input, initializers[input], 'Initializer');
-                delete initializers[input];
+        for (const input of inputs) {
+            if (index > 0 && initializers.has(input)) {
+                const initializer = initializers.get(input);
+                tensors[input] = new caffe2.Tensor(input, initializer);
+                initializer.input = false;
             }
             index++;
         }
+        for (const output of outputs) {
+            if (initializers.has(output)) {
+                const initializer = initializers.get(output);
+                initializer.input = false;
+            }
+        }
         this._inputs = [];
         let inputIndex = 0;
         if (schema && schema.inputs) {
@@ -412,7 +446,6 @@ caffe2.Node = class {
             }));
         }
 
-        let outputs = op.output;
         this._outputs = [];
         let outputIndex = 0;
         if (schema && schema.outputs) {
@@ -540,26 +573,13 @@ caffe2.Attribute = class {
 
 caffe2.Tensor = class {
 
-    constructor(name, tensor, kind) {
+    constructor(name, tensor) {
         this._name = name;
-        this._kind = kind;
-
-        let args = {};
-        if (tensor && tensor.arg) {
-            for (let arg of tensor.arg) {
-                args[arg.name] = arg;
-            }
-        }
-        let shape = null;
-        if (args.shape && args.shape.ints) {
-            shape = args.shape.ints;
-        }
-        if (args.values) {
-            this._values = args.values;
-        }
-        this._scale = Object.prototype.hasOwnProperty.call(args, 'Y_scale') ? args.Y_scale.f : 0;
-        this._zeroPoint = Object.prototype.hasOwnProperty.call(args, 'Y_zero_point') ? args.Y_zero_point.i : 0;
+        const shape = tensor.shape && tensor.shape.ints ? tensor.shape.ints : null;
         this._type = new caffe2.TensorType(tensor.dataType, new caffe2.TensorShape(shape));
+        this._values = tensor.values || null;
+        this._scale = tensor.Y_scale ? tensor.Y_scale.f : 0;
+        this._zeroPoint = tensor.Y_zero_point ? tensor.Y_zero_point.i : 0;
     }
 
     get name() {
@@ -571,7 +591,7 @@ caffe2.Tensor = class {
     }
 
     get kind() {
-        return this._kind;
+        return 'Initializer';
     }
 
     get quantization() {

+ 14 - 0
test/models.json

@@ -818,6 +818,13 @@
     "format": "Caffe2",
     "link":   "https://github.com/lutzroeder/netron/issues/168"
   },
+  {
+    "type":   "caffe2",
+    "target": "mobilenet/predict_net_int8.pbtxt,mobilenet/init_net_int8.pbtxt",
+    "source": "https://raw.githubusercontent.com/cuiyanx/dnnl-models/master/Image%20Classification/mobilenet/predict_net_int8.pbtxt,https://raw.githubusercontent.com/cuiyanx/dnnl-models/master/Image%20Classification/mobilenet/init_net_int8.pbtxt",
+    "format": "Caffe2",
+    "link":   "https://github.com/lutzroeder/netron/issues/437"
+  },
   {
     "type":   "caffe2",
     "target": "mobilenet_v2/predict_net.pb,mobilenet_v2/init_net.pb",
@@ -839,6 +846,13 @@
     "format": "Caffe2",
     "link":   "https://github.com/caffe2/models"
   },
+  {
+    "type":   "caffe2",
+    "target": "resnet50_quantized/resnet50_quantized_predict_net.pb,resnet50_quantized/resnet50_quantized_init_net.pb",
+    "source": "https://s3.amazonaws.com/download.caffe2.ai/models/resnet50_quantized/resnet50_quantized_predict_net.pb,https://s3.amazonaws.com/download.caffe2.ai/models/resnet50_quantized/resnet50_quantized_init_net.pb",
+    "format": "Caffe2",
+    "link":   "https://github.com/caffe2/models/tree/master/resnet50_quantized"
+  },
   {
     "type":   "caffe2",
     "target": "onnx_while/predict_net.pb,onnx_while/inits_net.pb",