|
|
@@ -3076,7 +3076,6 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
case 'torch.relu_':
|
|
|
case 'torch.hardtanh_':
|
|
|
case 'torch.upsample_bilinear2d':
|
|
|
- case 'torch.unsqueeze':
|
|
|
case 'ops.prepacked.conv2d_clamp_run': {
|
|
|
parameter.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
break;
|
|
|
@@ -3160,6 +3159,41 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
|
|
|
parameter.resize_(evalArgs[1]);
|
|
|
break;
|
|
|
}
|
|
|
+ case 'torch.squeeze': {
|
|
|
+ const input = evalArgs[0];
|
|
|
+ const size = input.size();
|
|
|
+ if (Array.isArray(size)) {
|
|
|
+ switch (evalArgs.length) {
|
|
|
+ case 1: {
|
|
|
+ parameter.resize_(size.filter((value) => value !== 1));
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 2: {
|
|
|
+ const dim = evalArgs[1];
|
|
|
+ parameter.resize_(size.filter((value, index) => value !== 1 || index !== dim));
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'torch.unsqueeze': {
|
|
|
+ const input = evalArgs[0];
|
|
|
+ const size = input.size();
|
|
|
+ const dim = evalArgs[1];
|
|
|
+ if (Array.isArray(size) && dim !== undefined) {
|
|
|
+ const shape = size.slice();
|
|
|
+ shape.splice(dim, 0, 1);
|
|
|
+ parameter.resize_(shape);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ parameter.resize_([ NaN, NaN, NaN, NaN ]);
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
case 'torch.transpose': {
|
|
|
const input = evalArgs[0];
|
|
|
let dim0 = evalArgs[1];
|