|
|
@@ -14626,6 +14626,12 @@ python.Execution = class {
|
|
|
}
|
|
|
return this._value;
|
|
|
}
|
|
|
+ type() {
|
|
|
+ if (this._value) {
|
|
|
+ return this._value.type();
|
|
|
+ }
|
|
|
+ return this._ivalue.type();
|
|
|
+ }
|
|
|
});
|
|
|
this.registerType('torch._C.SugaredValue', class {
|
|
|
kind() {
|
|
|
@@ -16564,6 +16570,8 @@ python.Execution = class {
|
|
|
return this.emitConst(new ast.Constant(-Infinity, 'float'));
|
|
|
} else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub) {
|
|
|
return this.emitUnaryOp(tree, '__neg__', 'aten::neg');
|
|
|
+ } else if (tree instanceof ast.BinOp) {
|
|
|
+ return this.emitBinaryOp(tree);
|
|
|
} else if (tree instanceof ast.Dict) {
|
|
|
return this.emitDictLiteral(tree, type_hint);
|
|
|
} else if (tree instanceof ast.Tuple) {
|
|
|
@@ -16572,6 +16580,90 @@ python.Execution = class {
|
|
|
}
|
|
|
throw new python.Error(`Simple expression '${tree.__class__.__name__}' not implemented.`);
|
|
|
}
|
|
|
+ getNodeKind(kind /*, ninputs */) {
|
|
|
+ if (kind instanceof ast.Add) {
|
|
|
+ return 'aten::add';
|
|
|
+ } else if (kind instanceof ast.Sub) {
|
|
|
+ return 'aten::sub';
|
|
|
+ } else if (kind instanceof ast.Mult) {
|
|
|
+ return 'aten::mul';
|
|
|
+ }
|
|
|
+ /*
|
|
|
+ case TK_UNARY_MINUS: return 'aten::neg';
|
|
|
+ case TK_POW: return 'aten::pow';
|
|
|
+ case '@': return 'aten::matmul';
|
|
|
+ case TK_STARRED: return 'prim::Starred';
|
|
|
+ case '/': return 'aten::div';
|
|
|
+ case '%': return 'aten::remainder';
|
|
|
+ case TK_NE: return 'aten::ne';
|
|
|
+ case TK_EQ: return 'aten::eq';
|
|
|
+ case '<': return 'aten::lt';
|
|
|
+ case '>': return 'aten::gt';
|
|
|
+ case TK_LE: return 'aten::le';
|
|
|
+ case TK_GE: return 'aten::ge';
|
|
|
+ case TK_AND: return 'aten::__and__';
|
|
|
+ case TK_OR: return 'aten::__or__';
|
|
|
+ case TK_IS: return 'aten::__is__';
|
|
|
+ case TK_ISNOT: return 'aten::__isnot__';
|
|
|
+ case TK_NOT: return 'aten::__not__';
|
|
|
+ case TK_FLOOR_DIV: return 'aten::floordiv';
|
|
|
+ case TK_LSHIFT: return 'aten::__lshift__';
|
|
|
+ case TK_RSHIFT: return 'aten::__rshift__';
|
|
|
+ case '&': return 'aten::__and__';
|
|
|
+ case '|': return 'aten::__or__';
|
|
|
+ case '^': return 'aten::__xor__';
|
|
|
+ case TK_IN: return 'aten::__contains__';
|
|
|
+ */
|
|
|
+ throw new python.Error(`Unknown kind '${kind.__class__.__name__}'.`);
|
|
|
+ }
|
|
|
+ getOperatorOverload(kind /*, ninputs */) {
|
|
|
+ if (kind instanceof ast.Add) {
|
|
|
+ return '__add__';
|
|
|
+ } else if (kind instanceof ast.Sub) {
|
|
|
+ return '__sub__';
|
|
|
+ } else if (kind instanceof ast.Mult) {
|
|
|
+ return '__mul__';
|
|
|
+ }
|
|
|
+ /*
|
|
|
+ case TK_UNARY_MINUS: return "__neg__";
|
|
|
+ case '~': return "__invert__";
|
|
|
+ case TK_POW: return "__pow__";
|
|
|
+ case '/': return "__truediv__";
|
|
|
+ case '%': return "__mod__";
|
|
|
+ case TK_NE: return "__ne__";
|
|
|
+ case TK_EQ: return "__eq__";
|
|
|
+ case '<': return "__lt__";
|
|
|
+ case '>': return "__gt__";
|
|
|
+ case TK_LE: return "__le__";
|
|
|
+ case TK_GE: return "__ge__";
|
|
|
+ case '&': return "__and__";
|
|
|
+ case '|': return "__or__";
|
|
|
+ case '^': return "__xor__";
|
|
|
+ case TK_IN: return "__contains__";
|
|
|
+ case TK_LSHIFT: return "__lshift__";
|
|
|
+ case TK_RSHIFT: return "__rshift__";
|
|
|
+ */
|
|
|
+ throw new python.Error(`Unknown kind '${kind.__class__.__name__}'.`);
|
|
|
+ }
|
|
|
+ emitBinaryOp(tree) {
|
|
|
+ const inputs = [tree.left, tree.right];
|
|
|
+ const kind = this.getNodeKind(tree.op, inputs.length);
|
|
|
+ const overload = this.getOperatorOverload(tree.op, inputs.length);
|
|
|
+ const named_values = this.getNamedValues(inputs, /*maybe_unpack=*/false);
|
|
|
+ if (tree.op instanceof ast.In) {
|
|
|
+ // std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
|
|
|
+ throw new python.Error('Not implemented.');
|
|
|
+ }
|
|
|
+ if (named_values[0].type() instanceof torch.TupleType &&
|
|
|
+ named_values[1].type() instanceof torch.TupleType &&
|
|
|
+ kind === 'aten::add') {
|
|
|
+ const first_tuple = torch._C.createTupleUnpack(named_values[0].value(this.graph)).vec();
|
|
|
+ const second_tuple = torch._C.createTupleUnpack(named_values[1].value(this.graph)).vec();
|
|
|
+ first_tuple.insert(first_tuple.end(), second_tuple.begin(), second_tuple.end());
|
|
|
+ return this.graph.insertNode(this.graph.createTuple(first_tuple)).output();
|
|
|
+ }
|
|
|
+ return torch._C.asSimple(torch._C.makeMagic(overload, new torch._C.BuiltinFunction(kind, null)).call(tree.range(), this.method, named_values, [], 0));
|
|
|
+ }
|
|
|
emitDictLiteral(dl, type_hint) {
|
|
|
const key_trees = dl.keys;
|
|
|
const value_trees = dl.values;
|