|
|
@@ -987,8 +987,22 @@ pytorch.Execution = class {
|
|
|
this._context.scope.builtins = {};
|
|
|
this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' };
|
|
|
this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type };
|
|
|
- this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__:this._context.scope.builtins.type };
|
|
|
- this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__:this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.dict = { __module__: 'builtins', __name__: 'dict', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.list = { __module__: 'builtins', __name__: 'list', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.str = { __module__: 'builtins', __name__: 'str', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.builtins.tuple = { __module__: 'builtins', __name__: 'tuple', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.typing = { __name__: 'typing', __class__: this._context.scope.builtins.module };
|
|
|
+ this._context.scope.typing._GenericAlias = { __module__: 'typing', __name__: '_GenericAlias', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.typing._SpecialForm = { __module__: 'typing', __name__: '_SpecialForm', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.typing._VariadicGenericAlias = { __module__: 'typing', __name__: '_VariadicGenericAlias', __class__: this._context.scope.builtins.type };
|
|
|
+ this._context.scope.typing.Dict = { __module__: 'typing', __name__: 'Dict', __class__: this._context.scope.typing._VariadicGenericAlias, __origin__: this._context.scope.builtins.dict };
|
|
|
+ this._context.scope.typing.List = { __module__: 'typing', __name__: 'List', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.list };
|
|
|
+ this._context.scope.typing.Optional = { __module__: 'typing', __class__: this._context.scope.typing._SpecialForm };
|
|
|
+ this._context.scope.typing.Tuple = { __module__: 'typing', __name__: 'Tuple', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.tuple };
|
|
|
+ this._context.scope.torch = { __name__: 'torch', __class__: this._context.scope.builtins.module };
|
|
|
+ this._context.scope.torch.Tensor = { __module__: 'torch', __name__: 'Tensor', __class__: this._context.scope.builtins.type };
|
|
|
this._registerConstructor('argparse.Namespace', function (args) {
|
|
|
this.args = args;
|
|
|
});
|
|
|
@@ -1691,7 +1705,19 @@ pytorch.Execution = class {
|
|
|
this._registerFunction('torch.warn', function() {
|
|
|
});
|
|
|
this._registerFunction('uninitialized', function(type) {
|
|
|
- return ({ __module__: 'torch', __name__: type, __origin__: 'uninitialized' });
|
|
|
+ if (type && type.__module__ === 'typing' && type.__name__ === 'Tuple') {
|
|
|
+ return [];
|
|
|
+ }
|
|
|
+ if (type && type.__module__ === 'typing' && type.__name__ === 'List') {
|
|
|
+ return [];
|
|
|
+ }
|
|
|
+ if (type && type.__module__ === 'typing' && type.__name__ === 'Dict') {
|
|
|
+ return {};
|
|
|
+ }
|
|
|
+ if (type && type.__module__ === 'torch' && type.__name__ === 'Tensor') {
|
|
|
+ return { __module__: type.__module__, __name__: type.__name__ };
|
|
|
+ }
|
|
|
+ throw new pytorch.Error("Unsupported uninitialized argument '" + JSON.stringify(type) + "'.");
|
|
|
});
|
|
|
}
|
|
|
|
|
|
@@ -1986,14 +2012,16 @@ pytorch.Execution = class {
|
|
|
const index = this.expression(expression.arguments.value[0], context);
|
|
|
return context.get(expression.target.value)[index];
|
|
|
}
|
|
|
- if (expression.target.value === 'List' || expression.target.value === 'Optional') {
|
|
|
- if (expression.arguments.value.every((item) => item.type === 'id')) {
|
|
|
- throw new pytorch.Error('Unsupported index expression.');
|
|
|
- // return { __typeref__: expression.target.value + '[' + expression.arguments.value.map((item) => item.value).join(',') + ']' };
|
|
|
- }
|
|
|
- }
|
|
|
}
|
|
|
const target = this.expression(expression.target, context);
|
|
|
+ if (target && expression.arguments.type === 'list' &&
|
|
|
+ (target.__class__ === this.context.scope.typing._VariadicGenericAlias ||
|
|
|
+ target.__class__ === this.context.scope.typing._GenericAlias ||
|
|
|
+ target.__class__ === this.context.scope.typing._SpecialForm)) {
|
|
|
+ const type = Object.assign({}, target);
|
|
|
+ type.__args__ = expression.arguments.value.map((arg) => this.expression(arg, context));
|
|
|
+ return type;
|
|
|
+ }
|
|
|
if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
|
|
|
const index = this.expression(expression.arguments.value[0], context);
|
|
|
return target[index];
|
|
|
@@ -2008,16 +2036,6 @@ pytorch.Execution = class {
|
|
|
throw new pytorch.Error("Unsupported field expression.");
|
|
|
}
|
|
|
case 'call': {
|
|
|
- if (expression.target.type === 'id' && expression.target.value === 'uninitialized' && expression.arguments.length === 1) {
|
|
|
- const argument = expression.arguments[0];
|
|
|
- if (argument.type === 'id' && argument.value === 'Tensor') {
|
|
|
- return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' };
|
|
|
- }
|
|
|
- if (argument.type === '[]' && argument.target.type === 'id' && argument.target.value === 'Tuple' &&
|
|
|
- argument.arguments.type === 'list' && argument.arguments.value.every((item) => item.type === 'id' && item.value === 'Tensor')) {
|
|
|
- return argument.arguments.value.map((/* item */) => { return { __module__: 'torch', __name__: 'Tensor', __origin__: 'uninitialized' }; });
|
|
|
- }
|
|
|
- }
|
|
|
if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) {
|
|
|
return this.expression(expression.arguments[1], context);
|
|
|
}
|
|
|
@@ -2036,6 +2054,17 @@ pytorch.Execution = class {
|
|
|
case 'True': return true;
|
|
|
case 'False': return false;
|
|
|
}
|
|
|
+ const type =
|
|
|
+ this._context.scope.builtins[expression.value] ||
|
|
|
+ this._context.scope.typing[expression.value] ||
|
|
|
+ this._context.scope.torch[expression.value];
|
|
|
+ if (type &&
|
|
|
+ (type.__class__ === this._context.scope.builtins.type ||
|
|
|
+ type.__class__ === this._context.scope.typing._VariadicGenericAlias ||
|
|
|
+ type.__class__ === this._context.scope.typing._GenericAlias ||
|
|
|
+ type.__class__ === this._context.scope.typing._SpecialForm)) {
|
|
|
+ return type;
|
|
|
+ }
|
|
|
return context.get(expression.value);
|
|
|
}
|
|
|
case 'tuple': {
|