ソースを参照

Add TorchScript test file (#842)

Lutz Roeder 5 ヶ月 前
コミット
3995f95c0d
2 ファイル変更99 行追加0 行削除
  1. 92 0
      source/python.js
  2. 7 0
      test/models.json

+ 92 - 0
source/python.js

@@ -14626,6 +14626,12 @@ python.Execution = class {
                 }
                 return this._value;
             }
+            type() {
+                if (this._value) {
+                    return this._value.type();
+                }
+                return this._ivalue.type();
+            }
         });
         this.registerType('torch._C.SugaredValue', class {
             kind() {
@@ -16564,6 +16570,8 @@ python.Execution = class {
                     return this.emitConst(new ast.Constant(-Infinity, 'float'));
                 } else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub) {
                     return this.emitUnaryOp(tree, '__neg__', 'aten::neg');
+                } else if (tree instanceof ast.BinOp) {
+                    return this.emitBinaryOp(tree);
                 } else if (tree instanceof ast.Dict) {
                     return this.emitDictLiteral(tree, type_hint);
                 } else if (tree instanceof ast.Tuple) {
@@ -16572,6 +16580,90 @@ python.Execution = class {
                 }
                 throw new python.Error(`Simple expression '${tree.__class__.__name__}' not implemented.`);
             }
+            getNodeKind(kind /*, ninputs */) {
+                if (kind instanceof ast.Add) {
+                    return 'aten::add';
+                } else if (kind instanceof ast.Sub) {
+                    return 'aten::sub';
+                } else if (kind instanceof ast.Mult) {
+                    return 'aten::mul';
+                }
+                /*
+                case TK_UNARY_MINUS: return 'aten::neg';
+                case TK_POW: return 'aten::pow';
+                case '@': return 'aten::matmul';
+                case TK_STARRED: return 'prim::Starred';
+                case '/': return 'aten::div';
+                case '%': return 'aten::remainder';
+                case TK_NE: return 'aten::ne';
+                case TK_EQ: return 'aten::eq';
+                case '<': return 'aten::lt';
+                case '>': return 'aten::gt';
+                case TK_LE: return 'aten::le';
+                case TK_GE: return 'aten::ge';
+                case TK_AND: return 'aten::__and__';
+                case TK_OR: return 'aten::__or__';
+                case TK_IS: return 'aten::__is__';
+                case TK_ISNOT: return 'aten::__isnot__';
+                case TK_NOT: return 'aten::__not__';
+                case TK_FLOOR_DIV: return 'aten::floordiv';
+                case TK_LSHIFT: return 'aten::__lshift__';
+                case TK_RSHIFT: return 'aten::__rshift__';
+                case '&': return 'aten::__and__';
+                case '|': return 'aten::__or__';
+                case '^': return 'aten::__xor__';
+                case TK_IN: return 'aten::__contains__';
+                */
+                throw new python.Error(`Unknown kind '${kind.__class__.__name__}'.`);
+            }
+            getOperatorOverload(kind /*, ninputs */) {
+                if (kind instanceof ast.Add) {
+                    return '__add__';
+                } else if (kind instanceof ast.Sub) {
+                    return '__sub__';
+                } else if (kind instanceof ast.Mult) {
+                    return '__mul__';
+                }
+                /*
+                case TK_UNARY_MINUS: return "__neg__";
+                case '~': return "__invert__";
+                case TK_POW: return "__pow__";
+                case '/': return "__truediv__";
+                case '%': return "__mod__";
+                case TK_NE: return "__ne__";
+                case TK_EQ: return "__eq__";
+                case '<': return "__lt__";
+                case '>': return "__gt__";
+                case TK_LE: return "__le__";
+                case TK_GE: return "__ge__";
+                case '&': return "__and__";
+                case '|': return "__or__";
+                case '^': return "__xor__";
+                case TK_IN: return "__contains__";
+                case TK_LSHIFT: return "__lshift__";
+                case TK_RSHIFT: return "__rshift__";
+                */
+                throw new python.Error(`Unknown kind '${kind.__class__.__name__}'.`);
+            }
+            emitBinaryOp(tree) {
+                const inputs = [tree.left, tree.right];
+                const kind = this.getNodeKind(tree.op, inputs.length);
+                const overload = this.getOperatorOverload(tree.op, inputs.length);
+                const named_values = this.getNamedValues(inputs, /*maybe_unpack=*/false);
+                if (tree.op instanceof ast.In) {
+                    // std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
+                    throw new python.Error('Not implemented.');
+                }
+                if (named_values[0].type() instanceof torch.TupleType &&
+                    named_values[1].type() instanceof torch.TupleType &&
+                    kind === 'aten::add') {
+                    const first_tuple = torch._C.createTupleUnpack(named_values[0].value(this.graph)).vec();
+                    const second_tuple = torch._C.createTupleUnpack(named_values[1].value(this.graph)).vec();
+                    first_tuple.insert(first_tuple.end(), second_tuple.begin(), second_tuple.end());
+                    return this.graph.insertNode(this.graph.createTuple(first_tuple)).output();
+                }
+                return torch._C.asSimple(torch._C.makeMagic(overload, new torch._C.BuiltinFunction(kind, null)).call(tree.range(), this.method, named_values, [], 0));
+            }
             emitDictLiteral(dl, type_hint) {
                 const key_trees = dl.keys;
                 const value_trees = dl.values;

+ 7 - 0
test/models.json

@@ -5504,6 +5504,13 @@
     "format":   "PyTorch v1.6",
     "link":     "https://github.com/lutzroeder/netron/issues/543"
   },
+  {
+    "type":     "pytorch",
+    "target":   "binop.pt",
+    "source":   "https://github.com/user-attachments/files/22703318/binop.pt.zip[binop.pt]",
+    "format":   "TorchScript v1.6",
+    "link":     "https://github.com/lutzroeder/netron/issues/842"
+  },
   {
     "type":     "pytorch",
     "target":   "cloudpickle.pth",