Explorar o código

Caffe inplace nodes

Lutz Roeder %!s(int64=8) %!d(string=hai) anos
pai
achega
85a30f8579
Modificáronse 2 ficheiros con 177 adicións e 33 borrados
  1. 107 33
      src/caffe-model.js
  2. 70 0
      src/caffe-operator.json

+ 107 - 33
src/caffe-model.js

@@ -24,9 +24,9 @@ class CaffeModel {
         try {
             var netParameter = caffe.NetParameter.decode(buffer);
             var model = new CaffeModel(netParameter);
-            // CaffeOperatorMetadata.open(host, (err, metadata) => {
+            CaffeOperatorMetadata.open(host, (err, metadata) => {
                 callback(null, model);
-            // });
+            });
         }
         catch (err) {
             callback(err, null);
@@ -62,20 +62,53 @@ class CaffeGraph {
         this._name = netParameter.name;
         this._nodes = [];
 
+        var layers = [];
         switch (version) {
             case 1:
-                netParameter.layers.forEach((layer) => {
-                    this._nodes.push(new CaffeNode(layer, version));
-                });
+                layers = netParameter.layers;
                 break;
             case 2:
-                netParameter.layer.forEach((layer) => {
-                    this._nodes.push(new CaffeNode(layer, version));
-                });
+                layers = netParameter.layer;
                 break;
         }
 
+        var nonInplaceLayers = [];
+        var inplaceMap = {};
+        layers.forEach((layer) => {
+            if (layer.top.length == 1 && layer.bottom.length == 1 && layer.top[0] == layer.bottom[0]) {
+                var key = layer.top[0];
+                if (!inplaceMap[key]) {
+                    inplaceMap[key] = [];
+                }
+                inplaceMap[key].push(layer);
+            }
+            else {
+                nonInplaceLayers.push(layer);
+            }
+        });
 
+        Object.keys(inplaceMap).forEach((key) => {
+            var nodes = inplaceMap[key];
+            nodes.forEach((node, index) => {
+                if (index > 0) {
+                    node.bottom[0] = node.bottom[0] + ':' + index.toString();
+                }
+                node.top[0] = node.top[0] + ':' + (index + 1).toString();
+            });
+        });
+
+        nonInplaceLayers.forEach((layer) => {
+            layer.bottom = layer.bottom.map((bottom) => {
+                if (inplaceMap[bottom]) {
+                    return bottom + ':' + inplaceMap[bottom].length.toString();
+                }
+                return bottom;
+            });
+        });
+
+        layers.forEach((layer) => {
+            this._nodes.push(new CaffeNode(layer, version));
+        });
     }
 
     get name() {
@@ -109,48 +142,44 @@ class CaffeNode {
                 var table = {};
                 table[caffe.V1LayerParameter.LayerType.NONE] = 'None';
                 table[caffe.V1LayerParameter.LayerType.ACCURACY] = 'Accuracy';
-                // BNLL = 2;
+                table[caffe.V1LayerParameter.LayerType.BNLL] = 'BNLL';
                 table[caffe.V1LayerParameter.LayerType.CONCAT] = 'Concat'; 
                 table[caffe.V1LayerParameter.LayerType.CONVOLUTION] = 'Convolution';
                 table[caffe.V1LayerParameter.LayerType.DATA] = 'Data';
                 table[caffe.V1LayerParameter.LayerType.DROPOUT] = 'Dropout';
-                // EUCLIDEAN_LOSS = 7;
+                table[caffe.V1LayerParameter.LayerType.EUCLIDEAN_LOSS] = 'EuclideanLoss';
                 table[caffe.V1LayerParameter.LayerType.FLATTEN] = 'Flatten';
                 table[caffe.V1LayerParameter.LayerType.HDF5_DATA] = 'HDF5Data';
                 table[caffe.V1LayerParameter.LayerType.HDF5_OUTPUT] = 'HDF5Output';
-                // IM2COL = 11;
+                table[caffe.V1LayerParameter.LayerType.IM2COL] = 'Im2col';
                 table[caffe.V1LayerParameter.LayerType.IMAGE_DATA] = 'ImageData';
-                // INFOGAIN_LOSS = 13;
+                table[caffe.V1LayerParameter.LayerType.INFOGAIN_LOSS] = 'InfogainLoss';
                 table[caffe.V1LayerParameter.LayerType.INNER_PRODUCT] = 'InnerProduct';
                 table[caffe.V1LayerParameter.LayerType.LRN] = 'LRN';
-                // MULTINOMIAL_LOGISTIC_LOSS = 16;
+                table[caffe.V1LayerParameter.LayerType.MULTINOMIAL_LOGISTIC_LOSS] = 'MultinomialLogisticLoss';
                 table[caffe.V1LayerParameter.LayerType.POOLING] = 'Pooling';
                 table[caffe.V1LayerParameter.LayerType.RELU] = 'ReLU';
                 table[caffe.V1LayerParameter.LayerType.SIGMOID] = 'Sigmoid';
                 table[caffe.V1LayerParameter.LayerType.SOFTMAX] = 'Softmax';
                 table[caffe.V1LayerParameter.LayerType.SOFTMAX_LOSS] = 'SoftmaxLoss';
                 table[caffe.V1LayerParameter.LayerType.SPLIT] = 'Split';
-                /*
-                    TANH = 23;
-                    WINDOW_DATA = 24;
-                    ELTWISE = 25;
-                    POWER = 26;
-                    SIGMOID_CROSS_ENTROPY_LOSS = 27;
-                    HINGE_LOSS = 28;
-                    MEMORY_DATA = 29;
-                    ARGMAX = 30;
-                    THRESHOLD = 31;
-                    DUMMY_DATA = 32;
-                */
+                table[caffe.V1LayerParameter.LayerType.TANH] = 'TanH';
+                table[caffe.V1LayerParameter.LayerType.WINDOW_DATA] = 'WindowData';
+                table[caffe.V1LayerParameter.LayerType.ELTWISE] = 'Eltwise';
+                table[caffe.V1LayerParameter.LayerType.POWER] = 'Power';
+                table[caffe.V1LayerParameter.LayerType.SIGMOID_CROSS_ENTROPY_LOSS] = 'SigmoidCrossEntropyLoss';
+                table[caffe.V1LayerParameter.LayerType.HINGE_LOSS] = 'HingeLoss';
+                table[caffe.V1LayerParameter.LayerType.MEMORY_DATA] = 'HingeLoss';
+                table[caffe.V1LayerParameter.LayerType.ARGMAX] = 'ArgMax';
+                table[caffe.V1LayerParameter.LayerType.THRESHOLD] = 'Threshold';
+                table[caffe.V1LayerParameter.LayerType.DUMMY_DATA] = 'DummyData';
                 table[caffe.V1LayerParameter.LayerType.SLICE] = 'Slice';
-                /*
-                    MVN = 34;
-                    ABSVAL = 35;
-                    SILENCE = 36;
-                    CONTRASTIVE_LOSS = 37;
-                    EXP = 38;
-                    DECONVOLUTION = 39;
-                */
+                table[caffe.V1LayerParameter.LayerType.MVN] = 'MVN';
+                table[caffe.V1LayerParameter.LayerType.ABSVAL] = 'AbsVal';
+                table[caffe.V1LayerParameter.LayerType.SILENCE] = 'Silence';
+                table[caffe.V1LayerParameter.LayerType.CONTRASTIVE_LOSS] = 'ContrastiveLoss';
+                table[caffe.V1LayerParameter.LayerType.EXP] = 'Exp';
+                table[caffe.V1LayerParameter.LayerType.DECONVOLUTION] = 'Deconvolution';
                 CaffeNode._operatorTable = table;
             }
             this._type = CaffeNode._operatorTable[layer.type];
@@ -198,6 +227,10 @@ class CaffeNode {
         return this._type;
     }
 
+    get category() {
+        return CaffeOperatorMetadata.operatorMetadata.getOperatorCategory(this._type);
+    }
+
     get name() { 
         return this._name;
     }
@@ -229,4 +262,45 @@ class CaffeAttribute {
     get value() { 
         return JSON.stringify(this._value);
     }
+}
+
+class CaffeOperatorMetadata 
+{
+
+    static open(host, callback) {
+        if (CaffeOperatorMetadata.operatorMetadata) {
+            callback(null, CaffeOperatorMetadata.operatorMetadata);
+        }
+        else {
+            host.request('/caffe-operator.json', (err, data) => {
+                CaffeOperatorMetadata.operatorMetadata = new CaffeOperatorMetadata(data);
+                callback(null, CaffeOperatorMetadata.operatorMetadata);
+            });
+        }    
+    }
+
+    constructor(data) {
+        this._map = {};
+        if (data) {
+            var items = JSON.parse(data);
+            if (items) {
+                items.forEach((item) => {
+                    if (item.name && item.schema)
+                    {
+                        var name = item.name;
+                        var schema = item.schema;
+                        this._map[name] = schema;
+                    }
+                });
+            }
+        }
+    }
+
+    getOperatorCategory(operator) {
+        var schema = this._map[operator];
+        if (schema && schema.category) {
+            return schema.category;
+        }
+        return null;
+    }
 }

+ 70 - 0
src/caffe-operator.json

@@ -0,0 +1,70 @@
+[
+  {
+    "name": "Convolution",
+    "schema": {
+      "category": "Layer",
+      "attributes": [
+      ]
+    }
+  },
+  {
+    "name": "InnerProduct",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "Scale",
+    "schema": {
+      "category": "Layer"
+    }
+  },
+  {
+    "name": "Dropout",
+    "schema": {
+      "category": "Dropout"
+    }
+  },
+  {
+    "name": "LRN",
+    "schema": {
+      "category": "Normalization"
+    }
+  },
+  {
+    "name": "BatchNorm",
+    "schema": {
+      "category": "Normalization"
+    }
+  },
+  {
+    "name": "Softmax",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "ReLU",
+    "schema": {
+      "category": "Activation"
+    }
+  },
+  {
+    "name": "Concat",
+    "schema": {
+      "category": "Tensor"
+    }
+  },
+  {
+    "name": "Split",
+    "schema": {
+      "category": "Tensor"
+    }
+  },
+  {
+    "name": "Pooling",
+    "schema": {
+      "category": "Pool"
+    }
+  }
+]