|
|
@@ -2989,7 +2989,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
condition.then.statements.length == 1 &&
|
|
|
pytorch.Utility.isCall(condition.then.statements[0], 'ops.prim.RaiseException', 1)) {
|
|
|
const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
|
|
|
- if (tensor && tensor.size) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.size) {
|
|
|
const number = this.expression(assign.expression.arguments[1], context);
|
|
|
const size = tensor.size();
|
|
|
if (size && size.length && size.length !== number &&
|
|
|
@@ -3014,19 +3014,35 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
condition.else.statements.length == 1 &&
|
|
|
pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
|
|
|
const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
|
|
|
- if (tensor && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
|
|
|
const number = this.expression(assign.expression.arguments[1], context);
|
|
|
tensor.resize_(Array(number).fill(NaN));
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ if (statements.length > 1) {
|
|
|
+ const size = statements[0];
|
|
|
+ const statement = statements[1];
|
|
|
+ // getattr_1 = torch.size(x)
|
|
|
+ // getitem = torch.slice(getattr_1, -2, 9223372036854775807, 1)
|
|
|
+ if (size.type === '=' && statement.type === '=' &&
|
|
|
+ size.target.type === 'id' &&
|
|
|
+ pytorch.Utility.isCall(size.expression, 'torch.size', 1) &&
|
|
|
+ pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
|
|
|
+ statement.expression.arguments[0].type === 'id' && size.target.value === statement.expression.arguments[0].value) {
|
|
|
+ const tensor = this.expression(size.expression.arguments[0], context);
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
+ tensor.resize_([ 1, 3, 299, 299 ]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
const statement = statements.shift();
|
|
|
// input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
|
|
|
if (statement.type === '=' &&
|
|
|
pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
|
|
|
pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 1)) {
|
|
|
const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
|
|
|
- if (tensor && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
|
|
|
tensor.resize_([ 1, 3, 299, 299 ]);
|
|
|
}
|
|
|
}
|
|
|
@@ -3035,7 +3051,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
|
|
|
pytorch.Utility.isCall(statement.expression.arguments[0], 'ops.prim.shape', 1)) {
|
|
|
const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
|
|
|
- if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
tensor.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
}
|
|
|
}
|
|
|
@@ -3044,7 +3060,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
pytorch.Utility.isCall(statement.expression, 'torch.le', 2) &&
|
|
|
pytorch.Utility.isCall(statement.expression.arguments[1], 'torch.dim', 1)) {
|
|
|
const tensor = this.expression(statement.expression.arguments[1].arguments[0], context);
|
|
|
- if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
tensor.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
}
|
|
|
}
|
|
|
@@ -3058,7 +3074,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
pytorch.Utility.isCall(statement.then.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
|
|
|
const tensor = this.expression(statement.condition.arguments[0].arguments[0], context);
|
|
|
const size = this.expression(statement.condition.arguments[1], context);
|
|
|
- if (tensor && Number.isInteger(size) && size < 10) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && Number.isInteger(size) && size < 10) {
|
|
|
tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
|
|
|
}
|
|
|
}
|
|
|
@@ -3068,7 +3084,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
pytorch.Utility.isCall(statement.expression, 'torch.sub', 2) &&
|
|
|
pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.dim', 1)) {
|
|
|
const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
|
|
|
- if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
tensor.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
}
|
|
|
}
|
|
|
@@ -3083,7 +3099,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
statement.target.type === 'id' &&
|
|
|
pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
|
|
|
const tensor = this.expression(statement.expression.arguments[0], context);
|
|
|
- if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
+ if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
|
|
|
tensor.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
}
|
|
|
}
|