Browse Source

[ConvNets/PyT] Fix interpolation type from Image.* to InterpolationMode.*

Adam Rajfer 3 năm trước cách đây
mục cha
commit
a684cf0527

+ 9 - 6
PyTorch/Classification/ConvNets/image_classification/dataloaders.py

@@ -34,6 +34,7 @@ import torchvision.datasets as datasets
 import torchvision.transforms as transforms
 from PIL import Image
 from functools import partial
+from torchvision.transforms.functional import InterpolationMode
 
 from image_classification.autoaugment import AutoaugmentImageNetPolicy
 
@@ -422,9 +423,10 @@ def get_pytorch_train_loader(
     prefetch_factor=2,
     memory_format=torch.contiguous_format,
 ):
-    interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
-        interpolation
-    ]
+    interpolation = {
+        "bicubic": InterpolationMode.BICUBIC,
+        "bilinear": InterpolationMode.BILINEAR,
+    }[interpolation]
     traindir = os.path.join(data_path, "train")
     transforms_list = [
         transforms.RandomResizedCrop(image_size, interpolation=interpolation),
@@ -474,9 +476,10 @@ def get_pytorch_val_loader(
     memory_format=torch.contiguous_format,
     prefetch_factor=2,
 ):
-    interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
-        interpolation
-    ]
+    interpolation = {
+        "bicubic": InterpolationMode.BICUBIC,
+        "bilinear": InterpolationMode.BILINEAR,
+    }[interpolation]
     valdir = os.path.join(data_path, "val")
     val_dataset = datasets.ImageFolder(
         valdir,