Quellcode durchsuchen

Add PyTorch test file (#842)

Lutz Roeder vor 4 Jahren
Ursprung
Commit
d74d5db586
2 geänderte Dateien mit 21 neuen und 14 gelöschten Zeilen
  1. 14 14
      source/pytorch.js
  2. 7 0
      test/models.json

+ 14 - 14
source/pytorch.js

@@ -2979,7 +2979,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                             break;
                                         }
                                         case 'torch.slice': {
-                                            const input = this.expression(args[0], context);
+                                            const input = evalArgs[0];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 const size = input.size();
                                                 parameter.resize_(size);
@@ -2987,7 +2987,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                             break;
                                         }
                                         case 'torch.to': {
-                                            const input = this.expression(args[0], context);
+                                            const input = evalArgs[0];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 const size = input.size();
                                                 parameter.resize_(size);
@@ -3006,7 +3006,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                         case 'torch.relu':
                                         case 'torch.clamp_':
                                         case 'torch.hardswish_': {
-                                            const input = this.expression(args[0], context);
+                                            const input = evalArgs[0];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 parameter.resize_(input.size());
                                             }
@@ -3014,12 +3014,12 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                         }
                                         case 'torch.add':
                                         case 'torch.sub': {
-                                            const input = this.expression(args[0], context);
+                                            const input = evalArgs[0];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 parameter.resize_(input.size());
                                             }
                                             else {
-                                                const other = this.expression(args[1], context);
+                                                const other = evalArgs[1];
                                                 if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) {
                                                     parameter.resize_(other.size());
                                                 }
@@ -3027,15 +3027,15 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                             break;
                                         }
                                         case 'torch.select': {
-                                            const input = this.expression(args[0], context);
+                                            const input = evalArgs[0];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 parameter.resize_(Array(input.size().length - 1).fill(NaN));
                                             }
                                             break;
                                         }
                                         case 'torch.layer_norm': {
-                                            const input = this.expression(args[0], context);
-                                            const normalized_shape = this.expression(args[1], context);
+                                            const input = evalArgs[0];
+                                            const normalized_shape = evalArgs[1];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 const shape = input.size();
                                                 if (Array.isArray(normalized_shape) && normalized_shape.length === 1) {
@@ -3048,19 +3048,19 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                         case 'torch.ones':
                                         case 'torch.zeros':
                                         case 'torch.zeros_like': {
-                                            parameter.resize_(this.expression(args[0], context));
+                                            parameter.resize_(evalArgs[0]);
                                             break;
                                         }
                                         case 'torch.view':
                                         case 'torch.reshape':
                                         case 'torch.new_full': {
-                                            parameter.resize_(this.expression(args[1], context));
+                                            parameter.resize_(evalArgs[1]);
                                             break;
                                         }
                                         case 'torch.transpose': {
-                                            const input = this.expression(args[0], context);
-                                            let dim0 = this.expression(args[1], context);
-                                            let dim1 = this.expression(args[2], context);
+                                            const input = evalArgs[0];
+                                            let dim0 = evalArgs[1];
+                                            let dim1 = evalArgs[2];
                                             if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                                                 const size = input.size().slice();
                                                 dim0 = dim0 > 0 ? dim0 : size.length + dim0;
@@ -3083,7 +3083,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                             parameter.__quantized__ = true;
                                             break;
                                         case 'torch.contiguous':
-                                            parameter.__source__ = this.expression(args[0], context);
+                                            parameter.__source__ = evalArgs[0];
                                             break;
                                     }
                                 }

+ 7 - 0
test/models.json

@@ -4383,6 +4383,13 @@
     "format": "TorchScript v1.0",
     "link":   "https://github.com/ApolloAuto/apollo"
   },
+  {
+    "type":   "pytorch",
+    "target": "LMModel1.pt",
+    "source": "https://github.com/lutzroeder/netron/files/7726055/LMModel1.zip[LMModel1.pt]",
+    "format": "TorchScript v1.6",
+    "link":   "https://github.com/lutzroeder/netron/issues/842"
+  },
   {
     "type":   "pytorch",
     "target": "mask_depthwise_conv.pt",