| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577 |
- # Copyright (c) 2018-2019, NVIDIA CORPORATION
- # Copyright (c) 2017- Facebook, Inc
- #
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # * Redistributions of source code must retain the above copyright notice, this
- # list of conditions and the following disclaimer.
- #
- # * Redistributions in binary form must reproduce the above copyright notice,
- # this list of conditions and the following disclaimer in the documentation
- # and/or other materials provided with the distribution.
- #
- # * Neither the name of the copyright holder nor the names of its
- # contributors may be used to endorse or promote products derived from
- # this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- import os
- import torch
- import numpy as np
- import torchvision.datasets as datasets
- import torchvision.transforms as transforms
- from PIL import Image
- from functools import partial
- from torchvision.transforms.functional import InterpolationMode
- from image_classification.autoaugment import AutoaugmentImageNetPolicy
- DATA_BACKEND_CHOICES = ["pytorch", "synthetic"]
- try:
- from nvidia.dali.plugin.pytorch import DALIClassificationIterator
- from nvidia.dali.pipeline import Pipeline
- import nvidia.dali.ops as ops
- import nvidia.dali.types as types
- DATA_BACKEND_CHOICES.append("dali-gpu")
- DATA_BACKEND_CHOICES.append("dali-cpu")
- except ImportError:
- print(
- "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
- )
def load_jpeg_from_file(path, cuda=True):
    """Load one image from *path* and return a normalized batch of size 1.

    The short side is resized to 256, a 224x224 center crop is taken, and the
    result is normalized with ImageNet statistics. Returns a (1, 3, 224, 224)
    float tensor, moved to the GPU when *cuda* is True.
    """
    preprocess = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
    )
    img = preprocess(Image.open(path))

    with torch.no_grad():
        # mean and std are not multiplied by 255 as they are in training script
        # torch dataloader reads data into bytes whereas loading directly
        # through PIL creates a tensor with floats in [0,1] range
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        if cuda:
            mean, std, img = mean.cuda(), std.cuda(), img.cuda()
        img = img.float()

        input = img.unsqueeze(0).sub_(mean).div_(std)

    return input
class HybridTrainPipe(Pipeline):
    """DALI training pipeline for ImageNet-style data.

    Reads JPEGs from a sharded file reader, decodes them (GPU "mixed"
    decoder unless *dali_cpu*), applies a random resized crop, and finally
    normalizes to NCHW float with ImageNet statistics plus a 50% random
    horizontal mirror.
    """

    def __init__(
        self,
        batch_size,
        num_threads,
        device_id,
        data_dir,
        interpolation,
        crop,
        dali_cpu=False,
    ):
        super(HybridTrainPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id
        )

        interp_type = {
            "bicubic": types.INTERP_CUBIC,
            "bilinear": types.INTERP_LINEAR,
            "triangular": types.INTERP_TRIANGULAR,
        }[interpolation]

        # Shard the reader across distributed workers when initialized.
        if torch.distributed.is_initialized():
            shard_id = torch.distributed.get_rank()
            num_shards = torch.distributed.get_world_size()
        else:
            shard_id, num_shards = 0, 1

        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=True,
            pad_last_batch=True,
        )

        if dali_cpu:
            dali_device = "cpu"
            self.decode = ops.ImageDecoder(device="cpu", output_type=types.RGB)
        else:
            dali_device = "gpu"
            # This padding sets the size of the internal nvJPEG buffers to be
            # able to handle all images from full-sized ImageNet without
            # additional reallocations.
            self.decode = ops.ImageDecoder(
                device="mixed",
                output_type=types.RGB,
                device_memory_padding=211025920,
                host_memory_padding=140544512,
            )

        self.res = ops.RandomResizedCrop(
            device=dali_device,
            size=[crop, crop],
            interp_type=interp_type,
            random_aspect_ratio=[0.75, 4.0 / 3.0],
            random_area=[0.08, 1.0],
            num_attempts=100,
            antialias=False,
        )
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        )
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        """Decode -> random resized crop -> normalize with a coin-flip mirror."""
        mirror = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images.gpu(), mirror=mirror)
        return [output, self.labels]
class HybridValPipe(Pipeline):
    """DALI validation pipeline: decode, resize the short side to *size*,
    then center-crop to *crop* and normalize with ImageNet statistics."""

    def __init__(
        self, batch_size, num_threads, device_id, data_dir, interpolation, crop, size
    ):
        super(HybridValPipe, self).__init__(
            batch_size, num_threads, device_id, seed=12 + device_id
        )

        interp_type = {
            "bicubic": types.INTERP_CUBIC,
            "bilinear": types.INTERP_LINEAR,
            "triangular": types.INTERP_TRIANGULAR,
        }[interpolation]

        # Shard the reader across distributed workers when initialized.
        if torch.distributed.is_initialized():
            shard_id = torch.distributed.get_rank()
            num_shards = torch.distributed.get_world_size()
        else:
            shard_id, num_shards = 0, 1

        self.input = ops.FileReader(
            file_root=data_dir,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=False,
            pad_last_batch=True,
        )
        self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        self.res = ops.Resize(
            device="gpu",
            resize_shorter=size,
            interp_type=interp_type,
            antialias=False,
        )
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        )

    def define_graph(self):
        """Decode -> resize -> center-crop normalize (no mirroring)."""
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        return [self.cmnp(images), self.labels]
class DALIWrapper(object):
    """Adapt a DALIClassificationIterator to yield (input, target) pairs.

    Each DALI batch is converted to the requested memory format, labels are
    flattened to a CUDA LongTensor and optionally one-hot encoded via
    ``expand``. The pipeline is reset after a full pass so the wrapper can be
    iterated once per epoch.
    """

    # Fix: this was a plain function accessed through the class; mark it as
    # the staticmethod it is actually used as (no self is ever passed).
    @staticmethod
    def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
        for data in dalipipeline:
            # Renamed from `input` to avoid shadowing the builtin.
            batch = data[0]["data"].contiguous(memory_format=memory_format)
            target = torch.reshape(data[0]["label"], [-1]).cuda().long()
            if one_hot:
                target = expand(num_classes, torch.float, target)
            yield batch, target
        dalipipeline.reset()

    def __init__(self, dalipipeline, num_classes, one_hot, memory_format):
        self.dalipipeline = dalipipeline
        self.num_classes = num_classes
        self.one_hot = one_hot
        self.memory_format = memory_format

    def __iter__(self):
        return DALIWrapper.gen_wrapper(
            self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
        )
def get_dali_train_loader(dali_cpu=False):
    """Return a factory for a DALI-backed training loader.

    The returned callable mirrors the signature of the other loader
    factories and produces (DALIWrapper, steps_per_epoch).
    """

    def gdtl(
        data_path,
        image_size,
        batch_size,
        num_classes,
        one_hot,
        interpolation="bilinear",
        augmentation=None,
        start_epoch=0,
        workers=5,
        _worker_init_fn=None,
        memory_format=torch.contiguous_format,
        **kwargs,
    ):
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            rank, world_size = 0, 1

        # Extra augmentations are only implemented for the pytorch backend.
        if augmentation is not None:
            raise NotImplementedError(
                f"Augmentation {augmentation} for dali loader is not supported"
            )

        pipe = HybridTrainPipe(
            batch_size=batch_size,
            num_threads=workers,
            device_id=rank % torch.cuda.device_count(),
            data_dir=os.path.join(data_path, "train"),
            interpolation=interpolation,
            crop=image_size,
            dali_cpu=dali_cpu,
        )
        pipe.build()

        train_loader = DALIClassificationIterator(
            pipe, reader_name="Reader", fill_last_batch=False
        )
        steps_per_epoch = int(pipe.epoch_size("Reader") / (world_size * batch_size))
        return (
            DALIWrapper(train_loader, num_classes, one_hot, memory_format),
            steps_per_epoch,
        )

    return gdtl
def get_dali_val_loader():
    """Return a factory for a DALI-backed validation loader.

    The returned callable mirrors the signature of the other loader
    factories and produces (DALIWrapper, steps_per_epoch).
    """

    def gdvl(
        data_path,
        image_size,
        batch_size,
        num_classes,
        one_hot,
        interpolation="bilinear",
        crop_padding=32,
        workers=5,
        _worker_init_fn=None,
        memory_format=torch.contiguous_format,
        **kwargs,
    ):
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            rank, world_size = 0, 1

        pipe = HybridValPipe(
            batch_size=batch_size,
            num_threads=workers,
            device_id=rank % torch.cuda.device_count(),
            data_dir=os.path.join(data_path, "val"),
            interpolation=interpolation,
            crop=image_size,
            # Resize the short side slightly larger than the final crop.
            size=image_size + crop_padding,
        )
        pipe.build()

        val_loader = DALIClassificationIterator(
            pipe, reader_name="Reader", fill_last_batch=False
        )
        steps_per_epoch = int(pipe.epoch_size("Reader") / (world_size * batch_size))
        return (
            DALIWrapper(val_loader, num_classes, one_hot, memory_format),
            steps_per_epoch,
        )

    return gdvl
def fast_collate(memory_format, batch):
    """Collate (PIL image, int label) pairs into uint8 tensors.

    Skips torchvision's ToTensor: images are packed as uint8 NCHW with
    channels rolled to the front, leaving float conversion and normalization
    to the GPU-side prefetcher. All images are assumed to share the size of
    the first one (true after RandomResizedCrop/CenterCrop).

    Returns (tensor, targets): tensor is (N, 3, H, W) uint8 in the requested
    memory_format, targets is an int64 vector.
    """
    imgs = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    w = imgs[0].size[0]  # PIL .size is (width, height)
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
        memory_format=memory_format
    )
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale: add a channel axis so it broadcasts across all 3
            # output channels below.
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW
        # Fix: write in place with copy_ instead of `+= from_numpy(arr.copy())`
        # — that form paid for an extra numpy copy plus an elementwise add
        # into a zero tensor. copy_ still broadcasts 1-channel sources.
        tensor[i].copy_(torch.from_numpy(nump_array))
    return tensor, targets
def expand(num_classes, dtype, tensor):
    """One-hot encode a vector of class indices on the GPU.

    Returns an (N, num_classes) CUDA tensor of *dtype* with 1 at each row's
    class index and 0 elsewhere.
    """
    one_hot = torch.zeros(
        tensor.size(0), num_classes, dtype=dtype, device=torch.device("cuda")
    )
    return one_hot.scatter(1, tensor.unsqueeze(1), 1.0)
class PrefetchedWrapper(object):
    """Wrap a DataLoader to overlap host-to-device copies with compute.

    A side CUDA stream uploads and normalizes the *next* batch while the
    caller consumes the current one. Each new iteration also advances the
    DistributedSampler epoch so shuffling differs between epochs.
    """

    def prefetched_loader(loader, num_classes, one_hot):
        """Generator yielding (input, target) CUDA batches one step ahead.

        mean/std are scaled by 255 because the collate function delivers
        uint8 tensors in [0, 255] rather than ToTensor floats in [0, 1].
        NOTE(review): assumes the wrapped loader yields at least one batch —
        an empty loader would hit the trailing yield with unbound locals.
        """
        mean = (
            torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
            .cuda()
            .view(1, 3, 1, 1)
        )
        std = (
            torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
            .cuda()
            .view(1, 3, 1, 1)
        )

        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in loader:
            with torch.cuda.stream(stream):
                # Upload + normalize on the side stream while the main stream
                # is still busy with the previous batch.
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                next_input = next_input.float()
                if one_hot:
                    next_target = expand(num_classes, torch.float, next_target)

                next_input = next_input.sub_(mean).div_(std)

            if not first:
                yield input, target
            else:
                first = False

            # Ensure the side-stream work for the batch about to be handed
            # out has finished before the main stream touches it.
            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        # Flush the final prefetched batch.
        yield input, target

    def __init__(self, dataloader, start_epoch, num_classes, one_hot):
        self.dataloader = dataloader
        # Epoch counter fed to DistributedSampler.set_epoch on each __iter__.
        self.epoch = start_epoch
        self.one_hot = one_hot
        self.num_classes = num_classes

    def __iter__(self):
        if self.dataloader.sampler is not None and isinstance(
            self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler
        ):
            # Re-seed the distributed shuffle for this epoch.
            self.dataloader.sampler.set_epoch(self.epoch)
        self.epoch += 1
        return PrefetchedWrapper.prefetched_loader(
            self.dataloader, self.num_classes, self.one_hot
        )

    def __len__(self):
        return len(self.dataloader)
def get_pytorch_train_loader(
    data_path,
    image_size,
    batch_size,
    num_classes,
    one_hot,
    interpolation="bilinear",
    augmentation=None,
    start_epoch=0,
    workers=5,
    _worker_init_fn=None,
    prefetch_factor=2,
    memory_format=torch.contiguous_format,
):
    """Build a pytorch-backend training loader over <data_path>/train.

    Images stay uint8 through fast_collate; normalization happens on GPU in
    PrefetchedWrapper. Returns (PrefetchedWrapper, steps_per_epoch).
    """
    interp_mode = {
        "bicubic": InterpolationMode.BICUBIC,
        "bilinear": InterpolationMode.BILINEAR,
    }[interpolation]

    augmentations = [
        transforms.RandomResizedCrop(image_size, interpolation=interp_mode),
        transforms.RandomHorizontalFlip(),
    ]
    if augmentation == "autoaugment":
        augmentations.append(AutoaugmentImageNetPolicy())

    train_dataset = datasets.ImageFolder(
        os.path.join(data_path, "train"), transforms.Compose(augmentations)
    )

    # Shuffling is done by the sampler in distributed mode, by the loader
    # itself otherwise.
    if torch.distributed.is_initialized():
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True
        )
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        collate_fn=partial(fast_collate, memory_format),
        drop_last=True,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
    )

    wrapped = PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot)
    return wrapped, len(train_loader)
def get_pytorch_val_loader(
    data_path,
    image_size,
    batch_size,
    num_classes,
    one_hot,
    interpolation="bilinear",
    workers=5,
    _worker_init_fn=None,
    crop_padding=32,
    memory_format=torch.contiguous_format,
    prefetch_factor=2,
):
    """Build a pytorch-backend validation loader over <data_path>/val.

    Resizes the short side to image_size + crop_padding then center-crops;
    images stay uint8 through fast_collate and are normalized on GPU.
    Returns (PrefetchedWrapper, steps_per_epoch).
    """
    interp_mode = {
        "bicubic": InterpolationMode.BICUBIC,
        "bilinear": InterpolationMode.BILINEAR,
    }[interpolation]

    val_transform = transforms.Compose(
        [
            transforms.Resize(image_size + crop_padding, interpolation=interp_mode),
            transforms.CenterCrop(image_size),
        ]
    )
    val_dataset = datasets.ImageFolder(os.path.join(data_path, "val"), val_transform)

    if torch.distributed.is_initialized():
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False
        )
    else:
        val_sampler = None

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        sampler=val_sampler,
        batch_size=batch_size,
        shuffle=val_sampler is None,
        num_workers=workers,
        worker_init_fn=_worker_init_fn,
        pin_memory=True,
        collate_fn=partial(fast_collate, memory_format),
        drop_last=False,
        persistent_workers=True,
        prefetch_factor=prefetch_factor,
    )

    return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader)
class SynteticDataLoader(object):
    """Infinite loader yielding one fixed random batch, for benchmarking.

    A single normally-distributed input batch and a matching target are
    generated once, kept on the GPU, and yielded forever. (The misspelled
    class name is kept for backward compatibility with existing callers.)
    """

    def __init__(
        self,
        batch_size,
        num_classes,
        num_channels,
        height,
        width,
        one_hot,
        memory_format=torch.contiguous_format,
    ):
        # randn already samples N(0, 1); the original trailing
        # .normal_(0, 1.0) re-sampled the same distribution and was dropped.
        input_data = (
            torch.randn(batch_size, num_channels, height, width)
            .contiguous(memory_format=memory_format)
            .cuda()
        )
        if one_hot:
            # Fix: use zeros, not empty — torch.empty left every column
            # except the first uninitialized, producing garbage one-hot
            # targets.
            input_target = torch.zeros(batch_size, num_classes).cuda()
            input_target[:, 0] = 1.0
        else:
            input_target = torch.randint(0, num_classes, (batch_size,))
            input_target = input_target.cuda()

        self.input_data = input_data
        self.input_target = input_target

    def __iter__(self):
        while True:
            yield self.input_data, self.input_target
def get_synthetic_loader(
    data_path,
    image_size,
    batch_size,
    num_classes,
    one_hot,
    interpolation=None,
    augmentation=None,
    start_epoch=0,
    workers=None,
    _worker_init_fn=None,
    memory_format=torch.contiguous_format,
    **kwargs,
):
    """Build a synthetic-data loader for benchmarking.

    Most parameters are accepted only to mirror the real loader factories
    and are ignored. Returns (SynteticDataLoader, -1); -1 marks an
    unbounded epoch length.
    """
    loader = SynteticDataLoader(
        batch_size,
        num_classes,
        3,  # RGB channels
        image_size,
        image_size,
        one_hot,
        memory_format=memory_format,
    )
    return loader, -1
|