瀏覽代碼

Update espresso.js

Lutz Roeder 1 年之前
父節點
當前提交
529deb793e
共有 2 個文件被更改,包括 16 次插入2 次删除
  1. 9 1
      source/espresso-metadata.json
  2. 7 1
      source/espresso.js

+ 9 - 1
source/espresso-metadata.json

@@ -31,6 +31,10 @@
     "name": "batch_norm",
     "category": "Normalization"
   },
+  {
+    "name": "batchnorm",
+    "category": "Normalization"
+  },
   {
     "name": "reshape",
     "category": "Shape"
@@ -73,6 +77,10 @@
   },
   {
     "name": "rnn_arch",
-    "category": "layer"
+    "category": "Layer"
+  },
+  {
+    "name": "flatten",
+    "category": "Shape"
   }
 ]

+ 7 - 1
source/espresso.js

@@ -320,7 +320,8 @@ espresso.Reader = class {
                 obj.outputs = [{ name: 'outputs', value: top }];
                 obj.chain = [];
                 switch (type) {
-                    case 'convolution': {
+                    case 'convolution':
+                    case 'deconvolution': {
                         this._weights(obj, data, [data.C, data.K, data.Nx, data.Ny]);
                         if (data.has_biases) {
                             obj.inputs.push(this._initializer('biases', data.blob_biases, 'float32', [data.C]));
@@ -332,6 +333,11 @@ espresso.Reader = class {
                         }
                         break;
                     }
+                    case 'batchnorm': {
+                        obj.inputs.push(this._initializer('params', data.blob_batchnorm_params, 'float32', [4, data.C]));
+                        delete data.blob_batchnorm_params;
+                        break;
+                    }
                     case 'inner_product': {
                         this._weights(obj, data, [data.nC, data.nB]);
                         if (data.has_biases) {