Browse Source

Update pytorch.js (#918)

Lutz Roeder 3 years ago
parent
commit
a9fb89fe9f
1 changed files with 35 additions and 1 deletions
  1. 35 1
      source/pytorch.js

+ 35 - 1
source/pytorch.js

@@ -3076,7 +3076,6 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                         case 'torch.relu_':
                                         case 'torch.hardtanh_':
                                         case 'torch.upsample_bilinear2d':
-                                        case 'torch.unsqueeze':
                                         case 'ops.prepacked.conv2d_clamp_run': {
                                             parameter.resize_([ NaN, NaN, NaN, NaN ]);
                                             break;
@@ -3160,6 +3159,41 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                                             parameter.resize_(evalArgs[1]);
                                             break;
                                         }
+                                        case 'torch.squeeze': {
+                                            const input = evalArgs[0];
+                                            const size = input.size();
+                                            if (Array.isArray(size)) {
+                                                switch (evalArgs.length) {
+                                                    case 1: {
+                                                        parameter.resize_(size.filter((value) => value !== 1));
+                                                        break;
+                                                    }
+                                                    case 2: {
+                                                        const dim = evalArgs[1];
+                                                        parameter.resize_(size.filter((value, index) => value !== 1 || index !== dim));
+                                                        break;
+                                                    }
+                                                    default: {
+                                                        break;
+                                                    }
+                                                }
+                                            }
+                                            break;
+                                        }
+                                        case 'torch.unsqueeze': {
+                                            const input = evalArgs[0];
+                                            const size = input.size();
+                                            const dim = evalArgs[1];
+                                            if (Array.isArray(size) && dim !== undefined) {
+                                                const shape = size.slice();
+                                                shape.splice(dim, 0, 1);
+                                                parameter.resize_(shape);
+                                            }
+                                            else {
+                                                parameter.resize_([ NaN, NaN, NaN, NaN ]);
+                                            }
+                                            break;
+                                        }
                                         case 'torch.transpose': {
                                             const input = evalArgs[0];
                                             let dim0 = evalArgs[1];