Browse Source

Add PyTorch test file

Lutz Roeder 6 năm trước cách đây
mục cha
commit
6bcbbe1005
4 tập tin đã thay đổi với 81 bổ sung1 xóa
  1. 45 0
      src/pytorch-metadata.json
  2. 10 1
      src/pytorch.js
  3. 25 0
      test/models.json
  4. 1 0
      tools/pytorch-script.py

+ 45 - 0
src/pytorch-metadata.json

@@ -845,6 +845,51 @@
       ]
     }
   },
+  {
+    "name": "torch.conv3d",
+    "schema": {
+      "attributes": [
+        {
+          "default": 1,
+          "name": "stride",
+          "type": "int64[]"
+        },
+        {
+          "default": 0,
+          "name": "padding",
+          "type": "int64[]"
+        },
+        {
+          "default": 1,
+          "name": "dilation",
+          "type": "int64[]"
+        },
+        {
+          "default": 1,
+          "name": "groups",
+          "type": "int64"
+        }
+      ],
+      "category": "Layer",
+      "inputs": [
+        {
+          "name": "input"
+        },
+        {
+          "name": "weight"
+        },
+        {
+          "name": "bias",
+          "option": "optional"
+        }
+      ],
+      "outputs": [
+        {
+          "name": "output"
+        }
+      ]
+    }
+  },
   {
     "name": "ops.quantized.add",
     "schema": {

+ 10 - 1
src/pytorch.js

@@ -1154,6 +1154,11 @@ pytorch.Execution = class {
         this._registerConstructor('torchvision.models.squeezenet.SqueezeNet', function() {});
         this._registerConstructor('torchvision.models.resnet.ResNet', function() {});
         this._registerConstructor('torchvision.models.vgg.VGG', function() {});
+        this._registerConstructor('torchvision.models.video.resnet.BasicBlock', function() {});
+        this._registerConstructor('torchvision.models.video.resnet.BasicStem', function() {});
+        this._registerConstructor('torchvision.models.video.resnet.Conv3DNoTemporal', function() {});
+        this._registerConstructor('torchvision.models.video.resnet.Conv3DSimple', function() {});
+        this._registerConstructor('torchvision.models.video.resnet.VideoResNet', function() {});
         this._registerConstructor('torchvision.models._utils.IntermediateLayerGetter', function() {});
         this._registerConstructor('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', function() {});
         this._registerConstructor('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', function() {});
@@ -2970,7 +2975,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
         let callArgs = Array.prototype.slice.call(args);
         if (callTarget) {
             const type = callTarget + '.' + name;
-            // ./aten/src/ATen/native/native_functions.yaml
+            // ./third_party/src/pytorch/aten/src/ATen/native/native_functions.yaml
             let schemas = this._metadata.type(type);
             if (schemas) {
                 if (!Array.isArray(schemas)) {
@@ -3051,6 +3056,10 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                 parameter.size = [ NaN, NaN, NaN, NaN ];
                                 break;
                             }
+                            case 'torch.conv3d': {
+                                parameter.size = [ NaN, NaN, NaN, NaN, NaN ];
+                                break;
+                            }
                             case 'torch.embedding': {
                                 parameter.size = [ NaN, NaN, NaN ];
                                 break;

+ 25 - 0
test/models.json

@@ -4216,6 +4216,31 @@
     "error":  "Could not find end of line in 'pytorch_invalid_file.pth'.",
     "link":   "https://github.com/lutzroeder/netron/issues/133"
   },
+  {
+    "type":   "pytorch",
+    "target": "r3d_18.pkl.pth",
+    "script": "./tools/pytorch sync install zoo",
+    "format": "PyTorch v0.1.10",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
+  {
+    "type":   "pytorch",
+    "target": "r3d_18.zip.pth",
+    "script": "./tools/pytorch sync install zoo",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
+  {
+    "type":   "pytorch",
+    "target": "r3d_18.pt",
+    "script": "./tools/pytorch sync install zoo",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
+  {
+    "type":   "pytorch",
+    "target": "r3d_18_traced.pt",
+    "script": "./tools/pytorch sync install zoo",
+    "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
+  },
   {
     "type":   "pytorch",
     "target": "resnet18.pkl.pth",

+ 1 - 0
tools/pytorch-script.py

@@ -69,6 +69,7 @@ def zoo():
     download_torchvision_model('torchvision.models.resnet101', [ 1, 3, 224, 224 ])
     download_torchvision_model('torchvision.models.shufflenet_v2_x1_0', [ 1, 3, 224, 224 ])
     download_torchvision_model('torchvision.models.squeezenet1_1', [ 1, 3, 224, 224 ])
+    download_torchvision_model('torchvision.models.video.r3d_18', [ 1, 3, 4, 112, 112 ])
     download_torchvision_model('torchvision.models.vgg11_bn', [ 1, 3, 224, 224 ])
     download_torchvision_model('torchvision.models.vgg16', [ 1, 3, 224, 224 ])