Explorar el Código

Workaround pytorch/pytorch#48525 (#779)

Lutz Roeder hace 4 años
padre
commit
ed3f8e6f43
Se ha modificado 1 fichero con 24 adiciones y 8 borrados
  1. 24 8
      source/pytorch.js

+ 24 - 8
source/pytorch.js

@@ -2989,7 +2989,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                     condition.then.statements.length == 1 &&
                     pytorch.Utility.isCall(condition.then.statements[0], 'ops.prim.RaiseException', 1)) {
                     const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
-                    if (tensor && tensor.size) {
+                    if (pytorch.Utility.isTensor(tensor) && tensor.size) {
                         const number = this.expression(assign.expression.arguments[1], context);
                         const size = tensor.size();
                         if (size && size.length && size.length !== number &&
@@ -3014,19 +3014,35 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                     condition.else.statements.length == 1 &&
                     pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
                     const tensor = this.expression(assign.expression.arguments[0].arguments[0].arguments[0], context);
-                    if (tensor && tensor.shape === undefined) {
+                    if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
                         const number = this.expression(assign.expression.arguments[1], context);
                         tensor.resize_(Array(number).fill(NaN));
                     }
                 }
             }
+            if (statements.length > 1) {
+                const size = statements[0];
+                const statement = statements[1];
+                // getattr_1 = torch.size(x)
+                // getitem = torch.slice(getattr_1, -2, 9223372036854775807, 1)
+                if (size.type === '=' && statement.type === '=' &&
+                    size.target.type === 'id' &&
+                    pytorch.Utility.isCall(size.expression, 'torch.size', 1) &&
+                    pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
+                    statement.expression.arguments[0].type === 'id' && size.target.value === statement.expression.arguments[0].value) {
+                    const tensor = this.expression(size.expression.arguments[0], context);
+                    if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                        tensor.resize_([ 1, 3, 299, 299 ]);
+                    }
+                }
+            }
             const statement = statements.shift();
             // input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
             if (statement.type === '=' &&
                 pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
                 pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.size', 1)) {
                 const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
-                if (tensor && tensor.shape === undefined) {
+                if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
                     tensor.resize_([ 1, 3, 299, 299 ]);
                 }
             }
@@ -3035,7 +3051,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 pytorch.Utility.isCall(statement.expression, 'torch.slice', 4) &&
                 pytorch.Utility.isCall(statement.expression.arguments[0], 'ops.prim.shape', 1)) {
                 const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
-                if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
                     tensor.resize_([ NaN, NaN, NaN, NaN ]);
                 }
             }
@@ -3044,7 +3060,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 pytorch.Utility.isCall(statement.expression, 'torch.le', 2) &&
                 pytorch.Utility.isCall(statement.expression.arguments[1], 'torch.dim', 1)) {
                 const tensor = this.expression(statement.expression.arguments[1].arguments[0], context);
-                if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
                     tensor.resize_([ NaN, NaN, NaN, NaN ]);
                 }
             }
@@ -3058,7 +3074,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 pytorch.Utility.isCall(statement.then.statements.slice(-1).pop(), 'ops.prim.RaiseException', 1)) {
                 const tensor = this.expression(statement.condition.arguments[0].arguments[0], context);
                 const size = this.expression(statement.condition.arguments[1], context);
-                if (tensor && Number.isInteger(size) && size < 10) {
+                if (pytorch.Utility.isTensor(tensor) && Number.isInteger(size) && size < 10) {
                     tensor.resize_(Array.isArray(tensor.shape) && tensor.shape.length > size ? tensor.shape.slice(-size) : Array(size).fill(NaN));
                 }
             }
@@ -3068,7 +3084,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 pytorch.Utility.isCall(statement.expression, 'torch.sub', 2) &&
                 pytorch.Utility.isCall(statement.expression.arguments[0], 'torch.dim', 1)) {
                 const tensor = this.expression(statement.expression.arguments[0].arguments[0], context);
-                if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
                     tensor.resize_([ NaN, NaN, NaN, NaN ]);
                 }
             }
@@ -3083,7 +3099,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                 statement.target.type === 'id' &&
                 pytorch.Utility.isCall(statement.expression, 'torch.len', 1)) {
                 const tensor = this.expression(statement.expression.arguments[0], context);
-                if (tensor && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
+                if (pytorch.Utility.isTensor(tensor) && tensor.__origin__ === 'graph-input' && tensor.shape === undefined) {
                     tensor.resize_([ NaN, NaN, NaN, NaN ]);
                 }
             }