|
@@ -34,6 +34,7 @@ import torchvision.datasets as datasets
|
|
|
import torchvision.transforms as transforms
|
|
import torchvision.transforms as transforms
|
|
|
from PIL import Image
|
|
from PIL import Image
|
|
|
from functools import partial
|
|
from functools import partial
|
|
|
|
|
+from torchvision.transforms.functional import InterpolationMode
|
|
|
|
|
|
|
|
from image_classification.autoaugment import AutoaugmentImageNetPolicy
|
|
from image_classification.autoaugment import AutoaugmentImageNetPolicy
|
|
|
|
|
|
|
@@ -422,9 +423,10 @@ def get_pytorch_train_loader(
|
|
|
prefetch_factor=2,
|
|
prefetch_factor=2,
|
|
|
memory_format=torch.contiguous_format,
|
|
memory_format=torch.contiguous_format,
|
|
|
):
|
|
):
|
|
|
- interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
|
|
|
|
|
- interpolation
|
|
|
|
|
- ]
|
|
|
|
|
|
|
+ interpolation = {
|
|
|
|
|
+ "bicubic": InterpolationMode.BICUBIC,
|
|
|
|
|
+ "bilinear": InterpolationMode.BILINEAR,
|
|
|
|
|
+ }[interpolation]
|
|
|
traindir = os.path.join(data_path, "train")
|
|
traindir = os.path.join(data_path, "train")
|
|
|
transforms_list = [
|
|
transforms_list = [
|
|
|
transforms.RandomResizedCrop(image_size, interpolation=interpolation),
|
|
transforms.RandomResizedCrop(image_size, interpolation=interpolation),
|
|
@@ -474,9 +476,10 @@ def get_pytorch_val_loader(
|
|
|
memory_format=torch.contiguous_format,
|
|
memory_format=torch.contiguous_format,
|
|
|
prefetch_factor=2,
|
|
prefetch_factor=2,
|
|
|
):
|
|
):
|
|
|
- interpolation = {"bicubic": Image.BICUBIC, "bilinear": Image.BILINEAR}[
|
|
|
|
|
- interpolation
|
|
|
|
|
- ]
|
|
|
|
|
|
|
+ interpolation = {
|
|
|
|
|
+ "bicubic": InterpolationMode.BICUBIC,
|
|
|
|
|
+ "bilinear": InterpolationMode.BILINEAR,
|
|
|
|
|
+ }[interpolation]
|
|
|
valdir = os.path.join(data_path, "val")
|
|
valdir = os.path.join(data_path, "val")
|
|
|
val_dataset = datasets.ImageFolder(
|
|
val_dataset = datasets.ImageFolder(
|
|
|
valdir,
|
|
valdir,
|