Browse Source

Add torch.export test file (#1211)

Lutz Roeder 4 months ago
parent
commit
bd6063da40
3 changed files with 19 additions and 4 deletions
  1. 12 1
      source/python.js
  2. 0 3
      source/pytorch.js
  3. 7 0
      test/models.json

+ 12 - 1
source/python.js

@@ -7544,6 +7544,13 @@ python.Execution = class {
                 super(lhs, rhs, '==');
             }
         });
+        this.registerType('sympy.functions.elementary.miscellaneous.MinMaxBase', class extends sympy.core.expr.Expr {
+        });
+        this.registerType('sympy.functions.elementary.miscellaneous.Max', class extends sympy.functions.elementary.miscellaneous.MinMaxBase {
+            __str__() {
+                return `Max(${this._args.map((a) => a.__str__()).join(', ')})`;
+            }
+        });
         this.registerFunction('sympy.core.sympify.sympify', (a /*, locals */) => {
             if (a instanceof sympy.core.expr.Expr) {
                 return a;
@@ -7556,7 +7563,11 @@ python.Execution = class {
                         case 'Mul': return new sympy.core.mul.Mul(...node.args.map((arg) => sympify(arg)));
                         case 'Add': return new sympy.core.add.Add(...node.args.map((arg) => sympify(arg)));
                         case 'Pow': return new sympy.core.power.Pow(...node.args.map((arg) => sympify(arg)));
+                        case 'Max': return new sympy.functions.elementary.miscellaneous.Max(...node.args.map((arg) => sympify(arg)));
                         case 'Integer': return new sympy.core.numbers.Integer(node.args[0].value);
+                        case 'GreaterThan': return new sympy.core.relational.GreaterThan(sympify(node.args[0]), sympify(node.args[1]));
+                        case 'LessThan': return new sympy.core.relational.LessThan(sympify(node.args[0]), sympify(node.args[1]));
+                        case 'Equality': return new sympy.core.relational.Equality(sympify(node.args[0]), sympify(node.args[1]));
                         default: throw new python.Error(`Unsupported SymPy function '${node.func.id}'.`);
                     }
                 }
@@ -19233,7 +19244,7 @@ python.Execution = class {
                     if (sym_arg.type === 'as_bool') {
                         return sym_arg.as_bool;
                     } else if (sym_arg.type === 'as_name') {
-                        return self.serialized_name_to_node.get(sym_arg.as_name);
+                        return this.serialized_name_to_node.get(sym_arg.as_name);
                     }
                 }
                 throw new python.Error(`Unsupported symbolic argument type '${sym_arg.type}`);

+ 0 - 3
source/pytorch.js

@@ -273,9 +273,6 @@ pytorch.Graph = class {
                     if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
                         continue;
                     }
-                    if (obj.users.size === 0) {
-                        continue;
-                    }
                 }
                 if (obj.op === 'output') {
                     for (const output of obj.args) {

+ 7 - 0
test/models.json

@@ -5731,6 +5731,13 @@
     "format":   "PyTorch v0.1.10",
     "link":     "https://github.com/lutzroeder/netron/issues/286"
   },
+  {
+    "type":     "pytorch",
+    "target":   "draft_export.pt2",
+    "source":   "https://github.com/user-attachments/files/22877643/draft_export.pt2.zip[draft_export.pt2]",
+    "format":   "PyTorch Export v8.14",
+    "link":     "https://github.com/lutzroeder/netron/issues/1211"
+  },
   {
     "type":     "pytorch",
     "target":   "DRNL4x_dual_model",