소스 검색

Add PyTorch test file (#720)

Lutz Roeder 1 개월 전
부모
커밋
ead8d24019
5개의 변경된 파일102개의 추가작업 그리고 37개의 파일을 삭제
  1. 88 36
      source/python.js
  2. 3 0
      source/pytorch.js
  3. 2 0
      source/view.js
  4. 8 0
      test/models.json
  5. 1 1
      test/worker.js

+ 88 - 36
source/python.js

@@ -1385,28 +1385,23 @@ python.Execution = class {
                     }
                     const literal = this._parseLiteral();
                     if (literal) {
-                        if (stack.length > 0 &&
-                            (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') &&
-                            (literal.value.startsWith('-') || literal.value.startsWith('+'))) {
-                            const op = literal.value < 0 ? new ast.Sub() : new ast.Add();
-                            const left = stack.pop();
-                            const right = new ast.Constant(Math.abs(literal.value));
-                            node = new ast.BinOp(left, op, right);
-                            stack.push(node);
-                        } else if (stack.length === 1 && literal.type === 'str' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') {
+                        if (stack.length === 1 && literal.type === 'str' && stack[0] instanceof ast.Constant && typeof stack[0].value === 'string') {
                             stack[0].value += literal.value.substring(1, literal.value.length - 1);
                         } else {
                             let value = literal.value;
-                            if (literal.type === 'int' || literal.type === 'float' || literal.type === 'complex') {
-                                switch (value) {
-                                    case 'inf': value = Infinity; break;
-                                    case '-inf': value = -Infinity; break;
-                                    default: value = Number(value); break;
-                                }
-                            } else if (literal.type === 'str') {
-                                value = literal.value.substring(1, literal.value.length - 1);
-                            } else {
-                                throw new python.Error(`Invalid literal ${this._location()}`);
+                            switch (literal.type) {
+                                case 'int':
+                                case 'float':
+                                    value = value === 'inf' ? Infinity : Number(value);
+                                    break;
+                                case 'complex':
+                                    value = new builtins.complex(0, Number(value.slice(0, -1)));
+                                    break;
+                                case 'str':
+                                    value = value.substring(1, value.length - 1);
+                                    break;
+                                default:
+                                    throw new python.Error(`Invalid literal type '${literal.type}' ${this._location()}`);
                             }
                             const node = new ast.Constant(value, literal.type);
                             this._mark(node, position);
@@ -2060,8 +2055,7 @@ python.Execution = class {
                 const decimal = (c) => c >= '0' && c <= '9' || c === '_';
                 const hex = (c) => decimal(c) || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || c === '_';
                 let c = this._get(this._position);
-                const sign = (c === '-' || c === '+') ? 1 : 0;
-                let i = this._position + sign;
+                let i = this._position;
                 c = this._get(i);
                 if (c === '0') {
                     let radix = 0;
@@ -2105,7 +2099,7 @@ python.Execution = class {
                         }
                     }
                 }
-                i = this._position + sign;
+                i = this._position;
                 let isDecimal = false;
                 if (this._get(i) >= '1' && this._get(i) <= '9') {
                     while (decimal(this._get(i))) {
@@ -2131,7 +2125,7 @@ python.Execution = class {
                         return { type: 'int', value: intText };
                     }
                 }
-                i = this._position + sign;
+                i = this._position;
                 if ((this._get(i) >= '0' && this._get(i) <= '9') ||
                     (this._get(i) === '.' && this._get(i + 1) >= '0' && this._get(i + 1) <= '9')) {
                     while (decimal(this._get(i))) {
@@ -2143,7 +2137,7 @@ python.Execution = class {
                     while (decimal(this._get(i))) {
                         i++;
                     }
-                    if (i > (this._position + sign)) {
+                    if (i > this._position) {
                         if (this._get(i) === 'e' || this._get(i) === 'E') {
                             i++;
                             if (this._get(i) === '-' || this._get(i) === '+') {
@@ -2162,7 +2156,7 @@ python.Execution = class {
                             }
                         }
                     }
-                    if (i > (this._position + sign)) {
+                    if (i > this._position) {
                         if (this._get(i) === 'j' || this._get(i) === 'J') {
                             return { type: 'complex', value: this._text.substring(this._position, i + 1) };
                         }
@@ -2407,6 +2401,9 @@ python.Execution = class {
                 this.real = real;
                 this.imag = imaginary;
             }
+            toString() {
+                return `${this.real}${this.imag < 0 ? '' : '+'}${this.imag}j`;
+            }
         });
         this.registerType('builtins.NoneType', class {});
         this.registerType('builtins.object', class {
@@ -6604,6 +6601,14 @@ python.Execution = class {
                         }
                         break;
                     }
+                    case 'c': {
+                        const lc = lhs.c(name);
+                        const rc = rhs.c(name);
+                        if (lc.real !== rc.real || lc.imag !== rc.imag) {
+                            return false;
+                        }
+                        break;
+                    }
                     case 'ival': {
                         if (lhs[kind](name) !== rhs[kind](name)) {
                             return false;
@@ -6628,6 +6633,8 @@ python.Execution = class {
                     for (const item of value) {
                         hash += torch._C.get_hash(item);
                     }
+                } else if (value instanceof builtins.complex) {
+                    hash += (value.real | 0) + (value.imag | 0);
                 }
             }
             return hash;
@@ -9099,15 +9106,24 @@ python.Execution = class {
         this.registerOperator('aten::replace', (value, oldvalue, newvalue /*, max */) => {
             return value.replace(oldvalue, newvalue);
         });
-        this.registerOperator('aten::add', (left, right) => {
-            if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
-                return left + right;
+        this.registerOperator('aten::add', (a, b) => {
+            if ((typeof a === 'number' || a instanceof Number) && (typeof b === 'number' || b instanceof Number)) {
+                return a + b;
             }
-            if (Array.isArray(left) && Array.isArray(right)) {
-                return left.concat(right);
+            if (typeof a === 'number' && b instanceof builtins.complex) {
+                return new builtins.complex(a + b.real, b.imag);
             }
-            if (typeof left === 'string' && typeof right === 'string') {
-                return left + right;
+            if (a instanceof builtins.complex && typeof b === 'number') {
+                return new builtins.complex(a.real + b, a.imag);
+            }
+            if (a instanceof builtins.complex && b instanceof builtins.complex) {
+                return new builtins.complex(a.real + b.real, a.imag + b.imag);
+            }
+            if (Array.isArray(a) && Array.isArray(b)) {
+                return a.concat(b);
+            }
+            if (typeof a === 'string' && typeof b === 'string') {
+                return a + b;
             }
             throw new python.Error('Unsupported aten::add expression type.');
         });
@@ -9461,6 +9477,15 @@ python.Execution = class {
             }
             return _legacy_load(f);
         });
+        this.registerOperator('prim::abs', (a) => {
+            if (typeof a === 'number' || a instanceof Number) {
+                return Math.abs(a);
+            }
+            if (a instanceof builtins.complex) {
+                return Math.hypot(a.real, a.imag);
+            }
+            throw new python.Error('Unsupported prim::abs expression type.');
+        });
         this.registerOperator('prim::unchecked_cast', (type, value) => {
             return value;
         });
@@ -9744,6 +9769,15 @@ python.Execution = class {
             if ((typeof left === 'number' || left instanceof Number) && (typeof right === 'number' || right instanceof Number)) {
                 return left - right;
             }
+            if (typeof left === 'number' && right instanceof builtins.complex) {
+                return new builtins.complex(left - right.real, right.imag);
+            }
+            if (left instanceof builtins.complex && typeof right === 'number') {
+                return new builtins.complex(left.real - right, left.imag);
+            }
+            if (left instanceof builtins.complex && right instanceof builtins.complex) {
+                return new builtins.complex(left.real - right.real, left.imag - right.imag);
+            }
             throw new python.Error("Unsupported 'torch.sub' expression type.");
         });
         this.registerFunction('torch.sym_int');
@@ -10741,7 +10775,7 @@ python.Execution = class {
                 return this.kind() === rhs.kind();
             }
             isSubtypeOf(rhs) {
-                return this.kind() === 'NumberType' || super.isSubtypeOf(rhs);
+                return rhs.kind() === 'NumberType' || super.isSubtypeOf(rhs);
             }
             str() {
                 return 'complex';
@@ -12439,6 +12473,13 @@ python.Execution = class {
             f(name) {
                 return this._values.get(name)[0];
             }
+            c_(name, value) {
+                this._values.set(name, [value, 'c']);
+                return this;
+            }
+            c(name) {
+                return this._values.get(name)[0];
+            }
             t_(name, value) {
                 this._values.set(name, [value, 't']);
                 return this;
@@ -12762,6 +12803,8 @@ python.Execution = class {
                     this.tag = 'Int';
                 } else if (typeof value === 'number') {
                     this.tag = 'Double';
+                } else if (value instanceof builtins.complex) {
+                    this.tag = 'ComplexDouble';
                 } else if (value instanceof torch._C.EnumHolder) {
                     this.tag = 'Enum';
                 } else {
@@ -12796,7 +12839,10 @@ python.Execution = class {
                 return this.value;
             }
             isComplexDouble() {
-                return this.tag === 'ComplexDouble';
+                return this.tag === 'ComplexDouble' || this.tag === 'Complex';
+            }
+            toComplexDouble() {
+                return this.value;
             }
             isInt() {
                 return this.tag === 'Int';
@@ -14975,7 +15021,7 @@ python.Execution = class {
                 return new torch._C.IValue(node.i('value'), 'Int');
             } else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'f') {
                 return new torch._C.IValue(node.f('value'), 'Double');
-            } else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'c') {
+            } else if (type.isSubtypeOf(torch.ComplexType.get()) && node.kindOf('value') === 'c') {
                 return new torch._C.IValue(node.c('value'), 'Complex');
             } else if (type instanceof torch.ListType && node.kindOf('value') === 'ival') {
                 let list = node.ival('value');
@@ -16634,7 +16680,7 @@ python.Execution = class {
                 if (val.node().kind() !== opSymbol) {
                     return val;
                 }
-                const maybe_out_stack = this.runNodeIfInputsAreConstant(val.node());
+                const maybe_out_stack = torch._C.runNodeIfInputsAreConstant(val.node());
                 if (!maybe_out_stack) {
                     return val;
                 }
@@ -16990,6 +17036,12 @@ python.Execution = class {
                     return this.emitListLiteral(tree, type_hint);
                 } else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub && tree.operand instanceof ast.Name && tree.operand.id === 'inf') {
                     return this.emitConst(new ast.Constant(-Infinity, 'float'));
+                } else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub && tree.operand instanceof ast.Constant) {
+                    const c = tree.operand;
+                    if (c.type === 'complex') {
+                        return this.emitConst(new ast.Constant(new builtins.complex(-c.value.real, -c.value.imag), 'complex'));
+                    }
+                    return this.emitConst(new ast.Constant(-c.value, c.type));
                 } else if (tree instanceof ast.UnaryOp && tree.op instanceof ast.USub) {
                     return this.emitUnaryOp(tree, '__neg__', 'aten::neg');
                 } else if (tree instanceof ast.BinOp) {

+ 3 - 0
source/pytorch.js

@@ -419,6 +419,9 @@ pytorch.Node = class {
                         if (value && value instanceof torch._C.IValue) {
                             value = pytorch.Utility.toString(value);
                         }
+                        if (value && value instanceof builtins.complex) {
+                            value = new base.Complex(value.real, value.imag);
+                        }
                         argument = new pytorch.Argument(name, value, type || 'attribute');
                     } else if (input.type() instanceof torch.ListType) {
                         if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 &&

+ 2 - 0
source/view.js

@@ -4698,6 +4698,8 @@ view.Formatter = class {
                 return value ? value.toString() : '(null)';
             case 'type[]':
                 return value ? value.map((item) => item.toString()).join(', ') : '(null)';
+            case 'complex':
+                return value ? value.toString() : '(null)';
             default:
                 break;
         }

+ 8 - 0
test/models.json

@@ -7028,6 +7028,14 @@
     "assert":   "model.modules[0].nodes[8].type.name == 'conv1d_prepack'",
     "link":     "https://github.com/lutzroeder/netron/issues/842"
   },
+  {
+    "type":     "pytorch",
+    "target":   "test_complex.pt",
+    "source":   "https://github.com/user-attachments/files/24839070/test_complex.pt.zip[test_complex.pt]",
+    "assert":   "model.modules[0].nodes[1].inputs[0].value == '5 + 6i'",
+    "format":   "TorchScript v1.6",
+    "link":     "https://github.com/lutzroeder/netron/issues/720"
+  },
   {
     "type":     "pytorch",
     "target":   "test_model.pt2",

+ 1 - 1
test/worker.js

@@ -312,7 +312,7 @@ export class Target {
                     }
                     throw new Error(`Invalid property path '${parts[0]}'.`);
                 }
-                if (context !== value) {
+                if (context !== value && context.toString() !== value) {
                     throw new Error(`Invalid '${context}' != '${assert}'.`);
                 }
             }