# dataloaders.py
  1. # Copyright (c) 2018-2019, NVIDIA CORPORATION
  2. # Copyright (c) 2017- Facebook, Inc
  3. #
  4. # All rights reserved.
  5. #
  6. # Redistribution and use in source and binary forms, with or without
  7. # modification, are permitted provided that the following conditions are met:
  8. #
  9. # * Redistributions of source code must retain the above copyright notice, this
  10. # list of conditions and the following disclaimer.
  11. #
  12. # * Redistributions in binary form must reproduce the above copyright notice,
  13. # this list of conditions and the following disclaimer in the documentation
  14. # and/or other materials provided with the distribution.
  15. #
  16. # * Neither the name of the copyright holder nor the names of its
  17. # contributors may be used to endorse or promote products derived from
  18. # this software without specific prior written permission.
  19. #
  20. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  23. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  24. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  25. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  26. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  27. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  28. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  29. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. import os
  31. import torch
  32. import numpy as np
  33. import torchvision.datasets as datasets
  34. import torchvision.transforms as transforms
  35. from PIL import Image
  36. from functools import partial
  37. from torchvision.transforms.functional import InterpolationMode
  38. from image_classification.autoaugment import AutoaugmentImageNetPolicy
  39. DATA_BACKEND_CHOICES = ["pytorch", "synthetic"]
  40. try:
  41. from nvidia.dali.plugin.pytorch import DALIClassificationIterator
  42. from nvidia.dali.pipeline import Pipeline
  43. import nvidia.dali.ops as ops
  44. import nvidia.dali.types as types
  45. DATA_BACKEND_CHOICES.append("dali-gpu")
  46. DATA_BACKEND_CHOICES.append("dali-cpu")
  47. except ImportError:
  48. print(
  49. "Please install DALI from https://www.github.com/NVIDIA/DALI to run this example."
  50. )
  51. def load_jpeg_from_file(path, cuda=True):
  52. img_transforms = transforms.Compose(
  53. [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
  54. )
  55. img = img_transforms(Image.open(path))
  56. with torch.no_grad():
  57. # mean and std are not multiplied by 255 as they are in training script
  58. # torch dataloader reads data into bytes whereas loading directly
  59. # through PIL creates a tensor with floats in [0,1] range
  60. mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
  61. std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
  62. if cuda:
  63. mean = mean.cuda()
  64. std = std.cuda()
  65. img = img.cuda()
  66. img = img.float()
  67. input = img.unsqueeze(0).sub_(mean).div_(std)
  68. return input
  69. class HybridTrainPipe(Pipeline):
  70. def __init__(
  71. self,
  72. batch_size,
  73. num_threads,
  74. device_id,
  75. data_dir,
  76. interpolation,
  77. crop,
  78. dali_cpu=False,
  79. ):
  80. super(HybridTrainPipe, self).__init__(
  81. batch_size, num_threads, device_id, seed=12 + device_id
  82. )
  83. interpolation = {
  84. "bicubic": types.INTERP_CUBIC,
  85. "bilinear": types.INTERP_LINEAR,
  86. "triangular": types.INTERP_TRIANGULAR,
  87. }[interpolation]
  88. if torch.distributed.is_initialized():
  89. rank = torch.distributed.get_rank()
  90. world_size = torch.distributed.get_world_size()
  91. else:
  92. rank = 0
  93. world_size = 1
  94. self.input = ops.FileReader(
  95. file_root=data_dir,
  96. shard_id=rank,
  97. num_shards=world_size,
  98. random_shuffle=True,
  99. pad_last_batch=True,
  100. )
  101. if dali_cpu:
  102. dali_device = "cpu"
  103. self.decode = ops.ImageDecoder(device=dali_device, output_type=types.RGB)
  104. else:
  105. dali_device = "gpu"
  106. # This padding sets the size of the internal nvJPEG buffers to be able to handle all images from full-sized ImageNet
  107. # without additional reallocations
  108. self.decode = ops.ImageDecoder(
  109. device="mixed",
  110. output_type=types.RGB,
  111. device_memory_padding=211025920,
  112. host_memory_padding=140544512,
  113. )
  114. self.res = ops.RandomResizedCrop(
  115. device=dali_device,
  116. size=[crop, crop],
  117. interp_type=interpolation,
  118. random_aspect_ratio=[0.75, 4.0 / 3.0],
  119. random_area=[0.08, 1.0],
  120. num_attempts=100,
  121. antialias=False,
  122. )
  123. self.cmnp = ops.CropMirrorNormalize(
  124. device="gpu",
  125. dtype=types.FLOAT,
  126. output_layout=types.NCHW,
  127. crop=(crop, crop),
  128. mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
  129. std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
  130. )
  131. self.coin = ops.CoinFlip(probability=0.5)
  132. def define_graph(self):
  133. rng = self.coin()
  134. self.jpegs, self.labels = self.input(name="Reader")
  135. images = self.decode(self.jpegs)
  136. images = self.res(images)
  137. output = self.cmnp(images.gpu(), mirror=rng)
  138. return [output, self.labels]
  139. class HybridValPipe(Pipeline):
  140. def __init__(
  141. self, batch_size, num_threads, device_id, data_dir, interpolation, crop, size
  142. ):
  143. super(HybridValPipe, self).__init__(
  144. batch_size, num_threads, device_id, seed=12 + device_id
  145. )
  146. interpolation = {
  147. "bicubic": types.INTERP_CUBIC,
  148. "bilinear": types.INTERP_LINEAR,
  149. "triangular": types.INTERP_TRIANGULAR,
  150. }[interpolation]
  151. if torch.distributed.is_initialized():
  152. rank = torch.distributed.get_rank()
  153. world_size = torch.distributed.get_world_size()
  154. else:
  155. rank = 0
  156. world_size = 1
  157. self.input = ops.FileReader(
  158. file_root=data_dir,
  159. shard_id=rank,
  160. num_shards=world_size,
  161. random_shuffle=False,
  162. pad_last_batch=True,
  163. )
  164. self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
  165. self.res = ops.Resize(
  166. device="gpu",
  167. resize_shorter=size,
  168. interp_type=interpolation,
  169. antialias=False,
  170. )
  171. self.cmnp = ops.CropMirrorNormalize(
  172. device="gpu",
  173. dtype=types.FLOAT,
  174. output_layout=types.NCHW,
  175. crop=(crop, crop),
  176. mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
  177. std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
  178. )
  179. def define_graph(self):
  180. self.jpegs, self.labels = self.input(name="Reader")
  181. images = self.decode(self.jpegs)
  182. images = self.res(images)
  183. output = self.cmnp(images)
  184. return [output, self.labels]
  185. class DALIWrapper(object):
  186. def gen_wrapper(dalipipeline, num_classes, one_hot, memory_format):
  187. for data in dalipipeline:
  188. input = data[0]["data"].contiguous(memory_format=memory_format)
  189. target = torch.reshape(data[0]["label"], [-1]).cuda().long()
  190. if one_hot:
  191. target = expand(num_classes, torch.float, target)
  192. yield input, target
  193. dalipipeline.reset()
  194. def __init__(self, dalipipeline, num_classes, one_hot, memory_format):
  195. self.dalipipeline = dalipipeline
  196. self.num_classes = num_classes
  197. self.one_hot = one_hot
  198. self.memory_format = memory_format
  199. def __iter__(self):
  200. return DALIWrapper.gen_wrapper(
  201. self.dalipipeline, self.num_classes, self.one_hot, self.memory_format
  202. )
  203. def get_dali_train_loader(dali_cpu=False):
  204. def gdtl(
  205. data_path,
  206. image_size,
  207. batch_size,
  208. num_classes,
  209. one_hot,
  210. interpolation="bilinear",
  211. augmentation=None,
  212. start_epoch=0,
  213. workers=5,
  214. _worker_init_fn=None,
  215. memory_format=torch.contiguous_format,
  216. **kwargs,
  217. ):
  218. if torch.distributed.is_initialized():
  219. rank = torch.distributed.get_rank()
  220. world_size = torch.distributed.get_world_size()
  221. else:
  222. rank = 0
  223. world_size = 1
  224. traindir = os.path.join(data_path, "train")
  225. if augmentation is not None:
  226. raise NotImplementedError(
  227. f"Augmentation {augmentation} for dali loader is not supported"
  228. )
  229. pipe = HybridTrainPipe(
  230. batch_size=batch_size,
  231. num_threads=workers,
  232. device_id=rank % torch.cuda.device_count(),
  233. data_dir=traindir,
  234. interpolation=interpolation,
  235. crop=image_size,
  236. dali_cpu=dali_cpu,
  237. )
  238. pipe.build()
  239. train_loader = DALIClassificationIterator(
  240. pipe, reader_name="Reader", fill_last_batch=False
  241. )
  242. return (
  243. DALIWrapper(train_loader, num_classes, one_hot, memory_format),
  244. int(pipe.epoch_size("Reader") / (world_size * batch_size)),
  245. )
  246. return gdtl
  247. def get_dali_val_loader():
  248. def gdvl(
  249. data_path,
  250. image_size,
  251. batch_size,
  252. num_classes,
  253. one_hot,
  254. interpolation="bilinear",
  255. crop_padding=32,
  256. workers=5,
  257. _worker_init_fn=None,
  258. memory_format=torch.contiguous_format,
  259. **kwargs,
  260. ):
  261. if torch.distributed.is_initialized():
  262. rank = torch.distributed.get_rank()
  263. world_size = torch.distributed.get_world_size()
  264. else:
  265. rank = 0
  266. world_size = 1
  267. valdir = os.path.join(data_path, "val")
  268. pipe = HybridValPipe(
  269. batch_size=batch_size,
  270. num_threads=workers,
  271. device_id=rank % torch.cuda.device_count(),
  272. data_dir=valdir,
  273. interpolation=interpolation,
  274. crop=image_size,
  275. size=image_size + crop_padding,
  276. )
  277. pipe.build()
  278. val_loader = DALIClassificationIterator(
  279. pipe, reader_name="Reader", fill_last_batch=False
  280. )
  281. return (
  282. DALIWrapper(val_loader, num_classes, one_hot, memory_format),
  283. int(pipe.epoch_size("Reader") / (world_size * batch_size)),
  284. )
  285. return gdvl
  286. def fast_collate(memory_format, batch):
  287. imgs = [img[0] for img in batch]
  288. targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
  289. w = imgs[0].size[0]
  290. h = imgs[0].size[1]
  291. tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
  292. memory_format=memory_format
  293. )
  294. for i, img in enumerate(imgs):
  295. nump_array = np.asarray(img, dtype=np.uint8)
  296. if nump_array.ndim < 3:
  297. nump_array = np.expand_dims(nump_array, axis=-1)
  298. nump_array = np.rollaxis(nump_array, 2)
  299. tensor[i] += torch.from_numpy(nump_array.copy())
  300. return tensor, targets
  301. def expand(num_classes, dtype, tensor):
  302. e = torch.zeros(
  303. tensor.size(0), num_classes, dtype=dtype, device=torch.device("cuda")
  304. )
  305. e = e.scatter(1, tensor.unsqueeze(1), 1.0)
  306. return e
  307. class PrefetchedWrapper(object):
  308. def prefetched_loader(loader, num_classes, one_hot):
  309. mean = (
  310. torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
  311. .cuda()
  312. .view(1, 3, 1, 1)
  313. )
  314. std = (
  315. torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
  316. .cuda()
  317. .view(1, 3, 1, 1)
  318. )
  319. stream = torch.cuda.Stream()
  320. first = True
  321. for next_input, next_target in loader:
  322. with torch.cuda.stream(stream):
  323. next_input = next_input.cuda(non_blocking=True)
  324. next_target = next_target.cuda(non_blocking=True)
  325. next_input = next_input.float()
  326. if one_hot:
  327. next_target = expand(num_classes, torch.float, next_target)
  328. next_input = next_input.sub_(mean).div_(std)
  329. if not first:
  330. yield input, target
  331. else:
  332. first = False
  333. torch.cuda.current_stream().wait_stream(stream)
  334. input = next_input
  335. target = next_target
  336. yield input, target
  337. def __init__(self, dataloader, start_epoch, num_classes, one_hot):
  338. self.dataloader = dataloader
  339. self.epoch = start_epoch
  340. self.one_hot = one_hot
  341. self.num_classes = num_classes
  342. def __iter__(self):
  343. if self.dataloader.sampler is not None and isinstance(
  344. self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler
  345. ):
  346. self.dataloader.sampler.set_epoch(self.epoch)
  347. self.epoch += 1
  348. return PrefetchedWrapper.prefetched_loader(
  349. self.dataloader, self.num_classes, self.one_hot
  350. )
  351. def __len__(self):
  352. return len(self.dataloader)
  353. def get_pytorch_train_loader(
  354. data_path,
  355. image_size,
  356. batch_size,
  357. num_classes,
  358. one_hot,
  359. interpolation="bilinear",
  360. augmentation=None,
  361. start_epoch=0,
  362. workers=5,
  363. _worker_init_fn=None,
  364. prefetch_factor=2,
  365. memory_format=torch.contiguous_format,
  366. ):
  367. interpolation = {
  368. "bicubic": InterpolationMode.BICUBIC,
  369. "bilinear": InterpolationMode.BILINEAR,
  370. }[interpolation]
  371. traindir = os.path.join(data_path, "train")
  372. transforms_list = [
  373. transforms.RandomResizedCrop(image_size, interpolation=interpolation),
  374. transforms.RandomHorizontalFlip(),
  375. ]
  376. if augmentation == "autoaugment":
  377. transforms_list.append(AutoaugmentImageNetPolicy())
  378. train_dataset = datasets.ImageFolder(traindir, transforms.Compose(transforms_list))
  379. if torch.distributed.is_initialized():
  380. train_sampler = torch.utils.data.distributed.DistributedSampler(
  381. train_dataset, shuffle=True
  382. )
  383. else:
  384. train_sampler = None
  385. train_loader = torch.utils.data.DataLoader(
  386. train_dataset,
  387. sampler=train_sampler,
  388. batch_size=batch_size,
  389. shuffle=(train_sampler is None),
  390. num_workers=workers,
  391. worker_init_fn=_worker_init_fn,
  392. pin_memory=True,
  393. collate_fn=partial(fast_collate, memory_format),
  394. drop_last=True,
  395. persistent_workers=True,
  396. prefetch_factor=prefetch_factor,
  397. )
  398. return (
  399. PrefetchedWrapper(train_loader, start_epoch, num_classes, one_hot),
  400. len(train_loader),
  401. )
  402. def get_pytorch_val_loader(
  403. data_path,
  404. image_size,
  405. batch_size,
  406. num_classes,
  407. one_hot,
  408. interpolation="bilinear",
  409. workers=5,
  410. _worker_init_fn=None,
  411. crop_padding=32,
  412. memory_format=torch.contiguous_format,
  413. prefetch_factor=2,
  414. ):
  415. interpolation = {
  416. "bicubic": InterpolationMode.BICUBIC,
  417. "bilinear": InterpolationMode.BILINEAR,
  418. }[interpolation]
  419. valdir = os.path.join(data_path, "val")
  420. val_dataset = datasets.ImageFolder(
  421. valdir,
  422. transforms.Compose(
  423. [
  424. transforms.Resize(
  425. image_size + crop_padding, interpolation=interpolation
  426. ),
  427. transforms.CenterCrop(image_size),
  428. ]
  429. ),
  430. )
  431. if torch.distributed.is_initialized():
  432. val_sampler = torch.utils.data.distributed.DistributedSampler(
  433. val_dataset, shuffle=False
  434. )
  435. else:
  436. val_sampler = None
  437. val_loader = torch.utils.data.DataLoader(
  438. val_dataset,
  439. sampler=val_sampler,
  440. batch_size=batch_size,
  441. shuffle=(val_sampler is None),
  442. num_workers=workers,
  443. worker_init_fn=_worker_init_fn,
  444. pin_memory=True,
  445. collate_fn=partial(fast_collate, memory_format),
  446. drop_last=False,
  447. persistent_workers=True,
  448. prefetch_factor=prefetch_factor,
  449. )
  450. return PrefetchedWrapper(val_loader, 0, num_classes, one_hot), len(val_loader)
  451. class SynteticDataLoader(object):
  452. def __init__(
  453. self,
  454. batch_size,
  455. num_classes,
  456. num_channels,
  457. height,
  458. width,
  459. one_hot,
  460. memory_format=torch.contiguous_format,
  461. ):
  462. input_data = (
  463. torch.randn(batch_size, num_channels, height, width)
  464. .contiguous(memory_format=memory_format)
  465. .cuda()
  466. .normal_(0, 1.0)
  467. )
  468. if one_hot:
  469. input_target = torch.empty(batch_size, num_classes).cuda()
  470. input_target[:, 0] = 1.0
  471. else:
  472. input_target = torch.randint(0, num_classes, (batch_size,))
  473. input_target = input_target.cuda()
  474. self.input_data = input_data
  475. self.input_target = input_target
  476. def __iter__(self):
  477. while True:
  478. yield self.input_data, self.input_target
  479. def get_synthetic_loader(
  480. data_path,
  481. image_size,
  482. batch_size,
  483. num_classes,
  484. one_hot,
  485. interpolation=None,
  486. augmentation=None,
  487. start_epoch=0,
  488. workers=None,
  489. _worker_init_fn=None,
  490. memory_format=torch.contiguous_format,
  491. **kwargs,
  492. ):
  493. return (
  494. SynteticDataLoader(
  495. batch_size,
  496. num_classes,
  497. 3,
  498. image_size,
  499. image_size,
  500. one_hot,
  501. memory_format=memory_format,
  502. ),
  503. -1,
  504. )