浏览代码

Add PyTorch LeakyReLU

Lutz Roeder 7 年之前
父节点
当前提交
5f9e5413d4
共有 5 个文件被更改,包括 93 次插入23 次删除
  1. 2 0
      electron-builder.yml
  2. 82 19
      src/pytorch-metadata.json
  3. 4 1
      src/pytorch.js
  4. 2 2
      tools/metadata/pytorch-metadata.py
  5. 3 1
      tools/metadata/pytorch-update

+ 2 - 0
electron-builder.yml

@@ -50,6 +50,8 @@ fileAssociations:
     ext: prototxt
   - name: "Caffe Model"
     ext: prototxt
+  - name: "PyTorch Model"
+    ext: pth
 publish:
   - provider: github
     releaseType: release

+ 82 - 19
src/pytorch-metadata.json

@@ -153,6 +153,19 @@
       "package": "torch.nn.modules.activation"
     }
   },
+  {
+    "name": "LeakyReLU",
+    "schema": {
+      "attributes": [
+        {
+          "name": "training",
+          "visible": false
+        }
+      ],
+      "category": "Activation",
+      "package": "torch.nn.modules.activation"
+    }
+  },
   {
     "name": "MaxPool2d",
     "schema": {
@@ -208,7 +221,7 @@
     }
   },
   {
-    "name": "BatchNorm1",
+    "name": "BatchNorm1d",
     "schema": {
       "attributes": [
         {
@@ -249,6 +262,13 @@
       "package": "torch.nn.modules.batchnorm"
     }
   },
+  {
+    "name": "GroupNorm",
+    "schema": {
+      "category": "Normalization",
+      "package": "torch.nn.modules.normalization"
+    }
+  },
   {
     "name": "Dropout2d",
     "schema": {
@@ -294,31 +314,74 @@
   {
     "name": "LSTM",
     "schema": {
-      "category": "Layer",
-      "package": "torch.nn.modules.rnn",
       "attributes": [
-        { "name": "training", "visible": false },
-        { "name": "dropout", "default": 0 },
-        { "name": "dropout_state", "default": {} },
-        { "name": "num_layers", "default": 1 },
-        { "name": "batch_first", "visible": false },
-        { "name": "bidirectional", "visible": false }
-      ]
+        {
+          "name": "training",
+          "visible": false
+        },
+        {
+          "default": 0,
+          "name": "dropout"
+        },
+        {
+          "default": {},
+          "name": "dropout_state"
+        },
+        {
+          "default": 1,
+          "name": "num_layers"
+        },
+        {
+          "name": "batch_first",
+          "visible": false
+        },
+        {
+          "name": "bidirectional",
+          "visible": false
+        }
+      ],
+      "category": "Layer",
+      "package": "torch.nn.modules.rnn"
     }
   },
   {
     "name": "Embedding",
     "schema": {
-      "category": "Transform",
-      "package": "torch.nn.modules.sparse",
       "attributes": [
-        { "name": "training", "visible": false },
-        { "name": "norm_type", "default": 2 },
-        { "name": "scale_grad_by_freq", "default": false },
-        { "name": "sparse", "default": false },
-        { "name": "max_norm", "default": null },
-        { "name": "padding_idx", "default": null }
-      ]
+        {
+          "name": "training",
+          "visible": false
+        },
+        {
+          "default": 2,
+          "name": "norm_type"
+        },
+        {
+          "default": false,
+          "name": "scale_grad_by_freq"
+        },
+        {
+          "default": false,
+          "name": "sparse"
+        },
+        {
+          "default": null,
+          "name": "max_norm"
+        },
+        {
+          "default": null,
+          "name": "padding_idx"
+        }
+      ],
+      "category": "Transform",
+      "package": "torch.nn.modules.sparse"
+    }
+  },
+  {
+    "name": "ReflectionPad2d",
+    "schema": {
+      "category": "Tensor",
+      "package": "torch.nn.modules.padding"
     }
   }
 ]

+ 4 - 1
src/pytorch.js

@@ -71,10 +71,11 @@ pytorch.ModelFactory = class {
             var functionTable = {};
 
             constructorTable['argparse.Namespace'] = function (args) { this.args = args; };
+            constructorTable['torch.nn.modules.activation.LeakyReLU'] = function () {};
             constructorTable['torch.nn.modules.activation.ReLU'] = function () {};
             constructorTable['torch.nn.modules.activation.PReLU'] = function () {};
-            constructorTable['torch.nn.modules.activation.Tanh'] = function () {};
             constructorTable['torch.nn.modules.activation.Sigmoid'] = function () {};
+            constructorTable['torch.nn.modules.activation.Tanh'] = function () {};
             constructorTable['torch.nn.modules.batchnorm.BatchNorm1d'] = function () {};
             constructorTable['torch.nn.modules.batchnorm.BatchNorm2d'] = function () {};
             constructorTable['torch.nn.modules.batchnorm.BatchNorm3d'] = function () {};
@@ -89,6 +90,7 @@ pytorch.ModelFactory = class {
             constructorTable['torch.nn.modules.dropout.Dropout2d'] = function () {};
             constructorTable['torch.nn.modules.dropout.Dropout3d'] = function () {};
             constructorTable['torch.nn.modules.linear.Linear'] = function () {};
+            constructorTable['torch.nn.modules.normalization.GroupNorm'] = function () {};
             constructorTable['torch.nn.modules.pooling.AvgPool1d'] = function () {};
             constructorTable['torch.nn.modules.pooling.AvgPool2d'] = function () {};
             constructorTable['torch.nn.modules.pooling.AvgPool3d'] = function () {};
@@ -110,6 +112,7 @@ pytorch.ModelFactory = class {
             constructorTable['torchvision.models.inception.InceptionC'] = function () {};
             constructorTable['torchvision.models.inception.InceptionD'] = function () {};
             constructorTable['torchvision.models.inception.InceptionE'] = function () {};
+            constructorTable['torch.nn.modules.padding.ReflectionPad2d'] = function () {};
             constructorTable['torchvision.models.resnet.Bottleneck'] = function () {};
             constructorTable['torchvision.models.resnet.BasicBlock'] = function() {};
             constructorTable['torchvision.models.resnet.ResNet'] = function () {};

+ 2 - 2
tools/metadata/pytorch-metadata.py

@@ -25,14 +25,14 @@ for entry in json_root:
     schema = entry['schema']
     if 'package' in schema:
         class_name = schema['package'] + '.' + name
-        print(class_name)
+        # print(class_name)
         class_definition = pydoc.locate(class_name)
         if not class_definition:
             raise Exception('\'' + class_name + '\' not found.')
         docstring = class_definition.__doc__
         if not docstring:
             raise Exception('\'' + class_name + '\' missing __doc__.')
-        print(docstring)
+        # print(docstring)
 
 with io.open(json_file, 'w', newline='') as fout:
     json_data = json.dumps(json_root, sort_keys=True, indent=2)

+ 3 - 1
tools/metadata/pytorch-update

@@ -77,9 +77,11 @@ if [ ${__build} ]; then
 fi
 
 if [ ${__update} ]; then
-    echo "Generate 'caffe2-metadata.json'"
     pushd ${tools}/metadata > /dev/null
+    echo "Generate 'caffe2-metadata.json'"
     ${python} caffe2-metadata.py
+    echo "Generate 'pytorch-metadata.json'"
+    ${python} pytorch-metadata.py
     popd > /dev/null
 fi