Przeglądaj źródła

Add TorchScript test file (#647)

Lutz Roeder 4 lat temu
rodzic
commit
2526e30810
3 zmienionych plików z 17 dodań i 1 usunięć
  1. 1 1
      source/pytorch-metadata.json
  2. 9 0
      source/pytorch.js
  3. 7 0
      test/models.json

+ 1 - 1
source/pytorch-metadata.json

@@ -3154,7 +3154,7 @@
       { "name": "size", "type": "int64[]" },
       { "name": "dtype", "type": "ScalarType", "optional": true, "default": null },
       { "name": "layout", "type": "Layout", "optional": true, "default": null },
-      { "name": "device", "type": "Device", "optiona": true, "default": null },
+      { "name": "device", "type": "Device", "optional": true, "default": null },
       { "name": "pin_memory", "type": "boolean", "optional": true, "default": null }
     ],
     "outputs": [

+ 9 - 0
source/pytorch.js

@@ -3031,6 +3031,15 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 pytorch.Utility.isCall(statement.expression, 'torch.unbind', 2)) {
                 statement.expression.arguments[0].__tuple__ = statement.target.value.length;
             }
+            // x = torch.len(input)
+            if (statement.type === '=' &&
+                statement.target.type === 'id' &&
+                pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
+                const tensor = this.expression(statement.expression.arguments[0], context);
+                if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                    tensor.resize_([ NaN, NaN, NaN, NaN ]);
+                }
+            }
             const value = this.statement(statement, context);
             if (value !== undefined) {
                 return value;

+ 7 - 0
test/models.json

@@ -4317,6 +4317,13 @@
     "source": "https://github.com/lutzroeder/netron/files/6096623/mobilenet_v2_traced.pt.zip[mobilenet_v2_traced.pt]",
     "link":   "https://github.com/lutzroeder/netron/issues/281"
   },
+  {
+    "type":   "pytorch",
+    "target": "model_0_epochs.pt",
+    "source": "https://github.com/lutzroeder/netron/files/6765569/model_0_epochs.pt.zip[model_0_epochs.pt]",
+    "format": "TorchScript v1.6",
+    "link":   "https://github.com/lutzroeder/netron/issues/647"
+  },
   {
     "type":   "pytorch",
     "target": "model-reddit16-f140225004_2.pt1",