|
|
@@ -1385,28 +1385,23 @@ python.Execution = class {
|
|
|
}
|
|
|
const literal = this._parseLiteral();
|
|
|
if (literal) {
|
|
|
- if (stack.length > 0 &&
|
|
|
- (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') &&
|
|
|
- (literal.value.startsWith('-') || literal.value.startsWith('+'))) {
|
|
|
- const op = literal.value < 0 ? new ast.Sub() : new ast.Add();
|
|
|
- const left = stack.pop();
|
|
|
- const right = new ast.Constant(Math.abs(literal.value));
|
|
|
- node = new ast.BinOp(left, op, right);
|
|
|
- stack.push(node);
|
|
|
- } else if (stack.length === 1 && literal.type === 'str' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') {
|
|
|
+ if (stack.length === 1 && literal.type === 'str' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') {
|
|
|
stack[0].value += literal.value.substring(1, literal.value.length - 1);
|
|
|
} else {
|
|
|
let value = literal.value;
|
|
|
- if (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') {
|
|
|
- switch (value) {
|
|
|
- case 'inf': value = Infinity; break;
|
|
|
- case '-inf': value = -Infinity; break;
|
|
|
- default: value = Number(value); break;
|
|
|
- }
|
|
|
- } else if (literal.type === 'str') {
|
|
|
- value = literal.value.substring(1, literal.value.length - 1);
|
|
|
- } else {
|
|
|
- throw new python.Error(`Invalid literal ${this._location()}`);
|
|
|
+ switch (literal.type) {
|
|
|
+ case 'int':
|
|
|
+ case 'float':
|
|
|
+ value = value === 'inf' ? Infinity : Number(value);
|
|
|
+ break;
|
|
|
+ case 'complex':
|
|
|
+ value = new builtins.complex(0, Number(value.slice(0, -1)));
|
|
|
+ break;
|
|
|
+ case 'str':
|
|
|
+ value = value.substring(1, value.length - 1);
|
|
|
+ break;
|
|
|
+ default:
|
|
|
+ throw new python.Error(`Invalid literal type '${literal.type}' ${this._location()}`);
|
|
|
}
|
|
|
const node = new ast.Constant(value, literal.type);
|
|
|
this._mark(node, position);
|
|
|
@@ -2060,8 +2055,7 @@ python.Execution = class {
|
|
|
const decimal = (c) => c >= '0' && c <= '9' || c === '_';
|
|
|
const hex = (c) => decimal(c) || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || c === '_';
|
|
|
let c = this._get(this._position);
|
|
|
- const sign = (c === '-' || c === '+') ? 1 : 0;
|
|
|
- let i = this._position + sign;
|
|
|
+ let i = this._position;
|
|
|
c = this._get(i);
|
|
|
if (c === '0') {
|
|
|
let radix = 0;
|
|
|
@@ -2105,7 +2099,7 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- i = this._position + sign;
|
|
|
+ i = this._position;
|
|
|
let isDecimal = false;
|
|
|
if (this._get(i) >= '1' && this._get(i) <= '9') {
|
|
|
while (decimal(this._get(i))) {
|
|
|
@@ -2131,7 +2125,7 @@ python.Execution = class {
|
|
|
return { type: 'int', value: intText };
|
|
|
}
|
|
|
}
|
|
|
- i = this._position + sign;
|
|
|
+ i = this._position;
|
|
|
if ((this._get(i) >= '0' && this._get(i) <= '9') ||
|
|
|
(this._get(i) === '.' && this._get(i + 1) >= '0' && this._get(i + 1) <= '9')) {
|
|
|
while (decimal(this._get(i))) {
|
|
|
@@ -2143,7 +2137,7 @@ python.Execution = class {
|
|
|
while (decimal(this._get(i))) {
|
|
|
i++;
|
|
|
}
|
|
|
- if (i > (this._position + sign)) {
|
|
|
+ if (i > this._position) {
|
|
|
if (this._get(i) === 'e' || this._get(i) === 'E') {
|
|
|
i++;
|
|
|
if (this._get(i) === '-' || this._get(i) === '+') {
|
|
|
@@ -2162,7 +2156,7 @@ python.Execution = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- if (i > (this._position + sign)) {
|
|
|
+ if (i > this._position) {
|
|
|
if (this._get(i) === 'j' || this._get(i) === 'J') {
|
|
|
return { type: 'complex', value: this._text.substring(this._position, i + 1) };
|
|
|
}
|
|
|
@@ -2407,6 +2401,9 @@ python.Execution = class {
|
|
|
this.real = real;
|
|
|
this.imag = imaginary;
|
|
|
}
|
|
|
+ toString() {
|
|
|
+ return `${this.real}${this.imag < 0 ? '' : '+'}${this.imag}j`;
|
|
|
+ }
|
|
|
});
|
|
|
this.registerType('builtins.NoneType', class {});
|
|
|
this.registerType('builtins.object', class {
|
|
|
@@ -6604,6 +6601,14 @@ python.Execution = class {
|
|
|
}
|
|
|
break;
|
|
|
}
|
|
|
+ case 'c': {
|
|
|
+ const lc = lhs.c(name);
|
|
|
+ const rc = rhs.c(name);
|
|
|
+ if (lc.real !== rc.real || lc.imag !== rc.imag) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
case 'ival': {
|
|
|
if (lhs[kind](name) !== rhs[kind](name)) {
|
|
|
return false;
|
|
|
@@ -6628,6 +6633,8 @@ python.Execution = class {
|
|
|
for (const item of value) {
|
|
|
hash += torch._C.get_hash(item);
|
|
|
}
|
|
|
+ } else if (value instanceof builtins.complex) {
|
|
|
+ hash += (value.real | 0) + (value.imag | 0);
|
|
|
}
|
|
|
}
|
|
|
return hash;
|
|
|
@@ -9099,15 +9106,24 @@ python.Execution = class {
|
|
|
this.registerOperator('aten::replace', (value, oldvalue, newvalue /*, max */) => {
|
|
|
return value.replace(oldvalue, newvalue);
|
|
|
});
|
|
|
- this.registerOperator('aten::add', (left, right) => {
|
|
|
- if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
|
|
|
- return left + right;
|
|
|
+ this.registerOperator('aten::add', (a, b) => {
|
|
|
+ if ((typeof a === 'number' || a instanceof Number) && (typeof b === 'number' || b instanceof Number)) {
|
|
|
+ return a + b;
|
|
|
}
|
|
|
- if (Array.isArray(left) && Array.isArray(right)) {
|
|
|
- return left.concat(right);
|
|
|
+ if (typeof a === 'number' && b instanceof builtins.complex) {
|
|
|
+ return new builtins.complex(a + b.real, b.imag);
|
|
|
}
|
|
|
- if (typeof left === 'string' && typeof right === 'string') {
|
|
|
- return left + right;
|
|
|
+ if (a instanceof builtins.complex && typeof b === 'number') {
|
|
|
+ return new builtins.complex(a.real + b, a.imag);
|
|
|
+ }
|
|
|
+ if (a instanceof builtins.complex && b instanceof builtins.complex) {
|
|
|
+ return new builtins.complex(a.real + b.real, a.imag + b.imag);
|
|
|
+ }
|
|
|
+ if (Array.isArray(a) && Array.isArray(b)) {
|
|
|
+ return a.concat(b);
|
|
|
+ }
|
|
|
+ if (typeof a === 'string' && typeof b === 'string') {
|
|
|
+ return a + b;
|
|
|
}
|
|
|
throw new python.Error('Unsupported aten::add expression type.');
|
|
|
});
|
|
|
@@ -9461,6 +9477,15 @@ python.Execution = class {
|
|
|
}
|
|
|
return _legacy_load(f);
|
|
|
});
|
|
|
+ this.registerOperator('prim::abs', (a) => {
|
|
|
+ if (typeof a === 'number' || a instanceof Number) {
|
|
|
+ return Math.abs(a);
|
|
|
+ }
|
|
|
+ if (a instanceof builtins.complex) {
|
|
|
+ return Math.hypot(a.real, a.imag);
|
|
|
+ }
|
|
|
+ throw new python.Error('Unsupported prim::abs expression type.');
|
|
|
+ });
|
|
|
this.registerOperator('prim::unchecked_cast', (type, value) => {
|
|
|
return value;
|
|
|
});
|
|
|
@@ -9744,6 +9769,15 @@ python.Execution = class {
|
|
|
if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
|
|
|
return left - right;
|
|
|
}
|
|
|
+ if (typeof left === 'number' && right instanceof builtins.complex) {
|
|
|
+ return new builtins.complex(left - right.real, right.imag);
|
|
|
+ }
|
|
|
+ if (left instanceof builtins.complex && typeof right === 'number') {
|
|
|
+ return new builtins.complex(left.real - right, left.imag);
|
|
|
+ }
|
|
|
+ if (left instanceof builtins.complex && right instanceof builtins.complex) {
|
|
|
+ return new builtins.complex(left.real - right.real, left.imag - right.imag);
|
|
|
+ }
|
|
|
throw new python.Error("Unsupported 'torch.sub' expression type.");
|
|
|
});
|
|
|
this.registerFunction('torch.sym_int');
|
|
|
@@ -10741,7 +10775,7 @@ python.Execution = class {
|
|
|
return this.kind() === rhs.kind();
|
|
|
}
|
|
|
isSubtypeOf(rhs) {
|
|
|
- return this.kind() === 'NumberType' || super.isSubtypeOf(rhs);
|
|
|
+ return rhs.kind() === 'NumberType' || super.isSubtypeOf(rhs);
|
|
|
}
|
|
|
str() {
|
|
|
return 'complex';
|
|
|
@@ -12439,6 +12473,13 @@ python.Execution = class {
|
|
|
f(name) {
|
|
|
return this._values.get(name)[0];
|
|
|
}
|
|
|
+ c_(name, value) {
|
|
|
+ this._values.set(name, [value, 'c']);
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+ c(name) {
|
|
|
+ return this._values.get(name)[0];
|
|
|
+ }
|
|
|
t_(name, value) {
|
|
|
this._values.set(name, [value, 't']);
|
|
|
return this;
|
|
|
@@ -12762,6 +12803,8 @@ python.Execution = class {
|
|
|
this.tag = 'Int';
|
|
|
} else if (typeof value === 'number') {
|
|
|
this.tag = 'Double';
|
|
|
+ } else if (value instanceof builtins.complex) {
|
|
|
+ this.tag = 'ComplexDouble';
|
|
|
} else if (value instanceof torch._C.EnumHolder) {
|
|
|
this.tag = 'Enum';
|
|
|
} else {
|
|
|
@@ -12796,7 +12839,10 @@ python.Execution = class {
|
|
|
return this.value;
|
|
|
}
|
|
|
isComplexDouble() {
|
|
|
- return this.tag === 'ComplexDouble';
|
|
|
+ return this.tag === 'ComplexDouble' || this.tag === 'Complex';
|
|
|
+ }
|
|
|
+ toComplexDouble() {
|
|
|
+ return this.value;
|
|
|
}
|
|
|
isInt() {
|
|
|
return this.tag === 'Int';
|
|
|
@@ -14975,7 +15021,7 @@ python.Execution = class {
|
|
|
return new torch._C.IValue(node.i('value'), 'Int');
|
|
|
} else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'f') {
|
|
|
return new torch._C.IValue(node.f('value'), 'Double');
|
|
|
- } else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'c') {
|
|
|
+ } else if (type.isSubtypeOf(torch.ComplexType.get()) && node.kindOf('value') === 'c') {
|
|
|
return new torch._C.IValue(node.c('value'), 'Complex');
|
|
|
} else if (type instanceof torch.ListType && node.kindOf('value') === 'ival') {
|
|
|
let list = node.ival('value');
|
|
|
@@ -16634,7 +16680,7 @@ python.Execution = class {
|
|
|
if (val.node().kind() !== opSymbol) {
|
|
|
return val;
|
|
|
}
|
|
|
- const maybe_out_stack = this.runNodeIfInputsAreConstant(val.node());
|
|
|
+ const maybe_out_stack = torch._C.runNodeIfInputsAreConstant(val.node());
|
|
|
if (!maybe_out_stack) {
|
|
|
return val;
|
|
|
}
|
|
|
@@ -16990,6 +17036,12 @@ python.Execution = class {
|
|
|
return this.emitListLiteral(tree, type_hint);
|
|
|
} else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub && tree.operand instanceof ast.Name && tree.operand.id === 'inf') {
|
|
|
return this.emitConst(new ast.Constant(-Infinity, 'float'));
|
|
|
+ } else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub && tree.operand instanceof ast.Constant) {
|
|
|
+ const c = tree.operand;
|
|
|
+ if (c.type === 'complex') {
|
|
|
+ return this.emitConst(new ast.Constant(new builtins.complex(-c.value.real, -c.value.imag), 'complex'));
|
|
|
+ }
|
|
|
+ return this.emitConst(new ast.Constant(-c.value, c.type));
|
|
|
} else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub) {
|
|
|
return this.emitUnaryOp(tree, '__neg__', 'aten::neg');
|
|
|
} else if (tree instanceof ast.BinOp) {
|