Răsfoiți Sursa

PyTorch Python typing support

Lutz Roeder 6 ani în urmă
părinte
comite
3ec0ee8fb3
1 a modificat fișierele cu 48 adăugiri și 19 ștergeri
  1. 48 19
      src/pytorch.js

+ 48 - 19
src/pytorch.js

@@ -987,8 +987,22 @@ pytorch.Execution = class {
         this._context.scope.builtins = {};
         this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' };
         this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type };
-        this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__:this._context.scope.builtins.type };
-        this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__:this._context.scope.builtins.type };
+        this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__: this._context.scope.builtins.type };
+        this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__: this._context.scope.builtins.type };
+        this._context.scope.builtins.dict = { __module__: 'builtins', __name__: 'dict', __class__: this._context.scope.builtins.type };
+        this._context.scope.builtins.list = { __module__: 'builtins', __name__: 'list', __class__: this._context.scope.builtins.type };
+        this._context.scope.builtins.str = { __module__: 'builtins', __name__: 'str', __class__: this._context.scope.builtins.type };
+        this._context.scope.builtins.tuple = { __module__: 'builtins', __name__: 'tuple', __class__: this._context.scope.builtins.type };
+        this._context.scope.typing = { __name__: 'typing', __class__: this._context.scope.builtins.module };
+        this._context.scope.typing._GenericAlias = { __module__: 'typing', __name__: '_GenericAlias', __class__: this._context.scope.builtins.type };
+        this._context.scope.typing._SpecialForm = { __module__: 'typing', __name__: '_SpecialForm', __class__: this._context.scope.builtins.type };
+        this._context.scope.typing._VariadicGenericAlias = { __module__: 'typing', __name__: '_VariadicGenericAlias', __class__: this._context.scope.builtins.type };
+        this._context.scope.typing.Dict = { __module__: 'typing', __name__: 'Dict', __class__: this._context.scope.typing._VariadicGenericAlias, __origin__: this._context.scope.builtins.dict };
+        this._context.scope.typing.List = { __module__: 'typing', __name__: 'List', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.list };
+        this._context.scope.typing.Optional = { __module__: 'typing', __class__: this._context.scope.typing._SpecialForm };
+        this._context.scope.typing.Tuple = { __module__: 'typing', __name__: 'Tuple', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.tuple };
+        this._context.scope.torch = { __name__: 'torch', __class__: this._context.scope.builtins.module };
+        this._context.scope.torch.Tensor = { __module__: 'torch', __name__: 'Tensor', __class__: this._context.scope.builtins.type };
         this._registerConstructor('argparse.Namespace', function (args) {
             this.args = args;
         });
@@ -1691,7 +1705,19 @@ pytorch.Execution = class {
         this._registerFunction('torch.warn', function() {
         });
         this._registerFunction('uninitialized', function(type) {
-            return ({ __module__: 'torch', __name__: type, __origin__: 'uninitialized' });
+            if (type && type.__module__ === 'typing' && type.__name__ === 'Tuple') {
+                return [];
+            }
+            if (type && type.__module__ === 'typing' && type.__name__ === 'List') {
+                return [];
+            }
+            if (type && type.__module__ === 'typing' && type.__name__ === 'Dict') {
+                return {};
+            }
+            if (type && type.__module__ === 'torch' && type.__name__ === 'Tensor') {
+                return { __module__: type.__module__, __name__: type.__name__ };
+            }
+            throw new pytorch.Error("Unsupported uninitialized argument '" + JSON.stringify(type) + "'.");
         });
     }
 
@@ -1986,14 +2012,16 @@ pytorch.Execution = class {
                         const index = this.expression(expression.arguments.value[0], context);
                         return context.get(expression.target.value)[index];
                     }
-                    if (expression.target.value === 'List' || expression.target.value === 'Optional') {
-                        if (expression.arguments.value.every((item) => item.type === 'id')) {
-                            throw new pytorch.Error('Unsupported index expression.');
-                            // return { __typeref__: expression.target.value + '[' + expression.arguments.value.map((item) => item.value).join(',') + ']' };
-                        }
-                    }
                 }
                 const target = this.expression(expression.target, context);
+                if (target && expression.arguments.type === 'list' &&
+                    (target.__class__ === this.context.scope.typing._VariadicGenericAlias ||
+                     target.__class__ === this.context.scope.typing._GenericAlias ||
+                     target.__class__ === this.context.scope.typing._SpecialForm)) {
+                    const type = Object.assign({}, target);
+                    type.__args__ = expression.arguments.value.map((arg) => this.expression(arg, context));
+                    return type;
+                }
                 if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
                     const index = this.expression(expression.arguments.value[0], context);
                     return target[index];
@@ -2008,16 +2036,6 @@ pytorch.Execution = class {
                 throw new pytorch.Error("Unsupported field expression.");
             }
             case 'call': {
-                if (expression.target.type === 'id' && expression.target.value === 'uninitialized' && expression.arguments.length === 1) {
-                    const argument = expression.arguments[0];
-                    if (argument.type === 'id' && argument.value === 'Tensor') {
-                        return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' };
-                    }
-                    if (argument.type === '[]' && argument.target.type === 'id' && argument.target.value === 'Tuple' &&
-                        argument.arguments.type === 'list' && argument.arguments.value.every((item) => item.type === 'id' && item.value === 'Tensor')) {
-                        return argument.arguments.value.map((/* item */) => { return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' }; });
-                    }
-                }
                 if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) {
                     return this.expression(expression.arguments[1], context);
                 }
@@ -2036,6 +2054,17 @@ pytorch.Execution = class {
                     case 'True': return true;
                     case 'False': return false;
                 }
+                const type =
+                    this._context.scope.builtins[expression.value] ||
+                    this._context.scope.typing[expression.value] || 
+                    this._context.scope.torch[expression.value];
+                if (type && 
+                    (type.__class__ === this._context.scope.builtins.type ||
+                     type.__class__ === this._context.scope.typing._VariadicGenericAlias ||
+                     type.__class__ === this._context.scope.typing._GenericAlias ||
+                     type.__class__ === this._context.scope.typing._SpecialForm)) {
+                    return type;
+                }
                 return context.get(expression.value);
             }
             case 'tuple': {