# dali.py
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from packaging.version import Version
  16. from nvidia import dali
  17. from nvidia.dali.pipeline import Pipeline
  18. import nvidia.dali.ops as ops
  19. import nvidia.dali.types as types
  20. from nvidia.dali.plugin.mxnet import DALIClassificationIterator
  21. import horovod.mxnet as hvd
  22. def add_dali_args(parser):
  23. group = parser.add_argument_group('DALI data backend', 'entire group applies only to dali data backend')
  24. group.add_argument('--dali-separ-val', action='store_true',
  25. help='each process will perform independent validation on whole val-set')
  26. group.add_argument('--dali-threads', type=int, default=6, help="number of threads" +\
  27. "per GPU for DALI")
  28. group.add_argument('--dali-validation-threads', type=int, default=10, help="number of threads" +\
  29. "per GPU for DALI for validation")
  30. group.add_argument('--dali-prefetch-queue', type=int, default=5, help="DALI prefetch queue depth")
  31. group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=64, help="Memory padding value for nvJPEG (in MB)")
  32. group.add_argument('--dali-fuse-decoder', type=int, default=1, help="0 or 1 whether to fuse decoder or not")
  33. group.add_argument('--dali-nvjpeg-width-hint', type=int, default=5980, help="Width hint value for nvJPEG (in pixels)")
  34. group.add_argument('--dali-nvjpeg-height-hint', type=int, default=6430, help="Height hint value for nvJPEG (in pixels)")
  35. group.add_argument('--dali-dont-use-mmap', default=False, action='store_true', help="Use plain I/O instead of MMAP for datasets")
  36. return parser
  37. class HybridTrainPipe(Pipeline):
  38. def __init__(self, args, batch_size, num_threads, device_id, rec_path, idx_path,
  39. shard_id, num_shards, crop_shape, nvjpeg_padding, prefetch_queue=3,
  40. output_layout=types.NCHW, pad_output=True, dtype='float16', dali_cpu=False,
  41. nvjpeg_width_hint=5980, nvjpeg_height_hint=6430,
  42. ):
  43. super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth = prefetch_queue)
  44. self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
  45. random_shuffle=True, shard_id=shard_id, num_shards=num_shards,
  46. dont_use_mmap=args.dali_dont_use_mmap)
  47. if dali_cpu:
  48. dali_device = "cpu"
  49. decoder_device = "cpu"
  50. else:
  51. dali_device = "gpu"
  52. decoder_device = "mixed"
  53. dali_kwargs_fallback = {}
  54. if Version(dali.__version__) >= Version("1.2.0"):
  55. dali_kwargs_fallback = {
  56. "preallocate_width_hint": nvjpeg_width_hint,
  57. "preallocate_height_hint": nvjpeg_height_hint,
  58. }
  59. if args.dali_fuse_decoder:
  60. self.decode = ops.ImageDecoderRandomCrop(device=decoder_device, output_type=types.RGB,
  61. device_memory_padding=nvjpeg_padding,
  62. host_memory_padding=nvjpeg_padding,
  63. **dali_kwargs_fallback)
  64. else:
  65. self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
  66. device_memory_padding=nvjpeg_padding,
  67. host_memory_padding=nvjpeg_padding,
  68. **dali_kwargs_fallback)
  69. if args.dali_fuse_decoder:
  70. self.resize = ops.Resize(device=dali_device, resize_x=crop_shape[1], resize_y=crop_shape[0])
  71. else:
  72. self.resize = ops.RandomResizedCrop(device=dali_device, size=crop_shape)
  73. self.cmnp = ops.CropMirrorNormalize(device="gpu",
  74. output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
  75. output_layout=output_layout, crop=crop_shape, pad_output=pad_output,
  76. image_type=types.RGB, mean=args.rgb_mean, std=args.rgb_std)
  77. self.coin = ops.CoinFlip(probability=0.5)
  78. def define_graph(self):
  79. rng = self.coin()
  80. self.jpegs, self.labels = self.input(name="Reader")
  81. images = self.decode(self.jpegs)
  82. images = self.resize(images)
  83. output = self.cmnp(images.gpu(), mirror=rng)
  84. return [output, self.labels]
  85. class HybridValPipe(Pipeline):
  86. def __init__(self, args, batch_size, num_threads, device_id, rec_path, idx_path,
  87. shard_id, num_shards, crop_shape, nvjpeg_padding, prefetch_queue=3, resize_shp=None,
  88. output_layout=types.NCHW, pad_output=True, dtype='float16', dali_cpu=False,
  89. nvjpeg_width_hint=5980, nvjpeg_height_hint=6430):
  90. super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id, prefetch_queue_depth=prefetch_queue)
  91. self.input = ops.MXNetReader(path=[rec_path], index_path=[idx_path],
  92. random_shuffle=False, shard_id=shard_id, num_shards=num_shards,
  93. dont_use_mmap=args.dali_dont_use_mmap)
  94. if dali_cpu:
  95. dali_device = "cpu"
  96. decoder_device = "cpu"
  97. else:
  98. dali_device = "gpu"
  99. decoder_device = "mixed"
  100. dali_kwargs_fallback = {}
  101. if Version(dali.__version__) >= Version("1.2.0"):
  102. dali_kwargs_fallback = {
  103. "preallocate_width_hint": nvjpeg_width_hint,
  104. "preallocate_height_hint": nvjpeg_height_hint
  105. }
  106. self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB,
  107. device_memory_padding=nvjpeg_padding,
  108. host_memory_padding=nvjpeg_padding,
  109. **dali_kwargs_fallback)
  110. self.resize = ops.Resize(device=dali_device, resize_shorter=resize_shp) if resize_shp else None
  111. self.cmnp = ops.CropMirrorNormalize(device="gpu",
  112. output_dtype=types.FLOAT16 if dtype == 'float16' else types.FLOAT,
  113. output_layout=output_layout, crop=crop_shape, pad_output=pad_output,
  114. image_type=types.RGB, mean=args.rgb_mean, std=args.rgb_std)
  115. def define_graph(self):
  116. self.jpegs, self.labels = self.input(name="Reader")
  117. images = self.decode(self.jpegs)
  118. if self.resize:
  119. images = self.resize(images)
  120. output = self.cmnp(images.gpu())
  121. return [output, self.labels]
  122. def get_rec_iter(args, kv=None, dali_cpu=False):
  123. gpus = args.gpus
  124. num_threads = args.dali_threads
  125. num_validation_threads = args.dali_validation_threads
  126. pad_output = (args.image_shape[0] == 4)
  127. # the input_layout w.r.t. the model is the output_layout of the image pipeline
  128. output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
  129. if 'horovod' in args.kv_store:
  130. rank = hvd.rank()
  131. nWrk = hvd.size()
  132. else:
  133. rank = kv.rank if kv else 0
  134. nWrk = kv.num_workers if kv else 1
  135. batch_size = args.batch_size // nWrk // len(gpus)
  136. trainpipes = [HybridTrainPipe(args = args,
  137. batch_size = batch_size,
  138. num_threads = num_threads,
  139. device_id = gpu_id,
  140. rec_path = args.data_train,
  141. idx_path = args.data_train_idx,
  142. shard_id = gpus.index(gpu_id) + len(gpus)*rank,
  143. num_shards = len(gpus)*nWrk,
  144. crop_shape = args.image_shape[1:],
  145. output_layout = output_layout,
  146. dtype = args.dtype,
  147. pad_output = pad_output,
  148. dali_cpu = dali_cpu,
  149. nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
  150. prefetch_queue = args.dali_prefetch_queue,
  151. nvjpeg_width_hint = args.dali_nvjpeg_width_hint,
  152. nvjpeg_height_hint = args.dali_nvjpeg_height_hint) for gpu_id in gpus]
  153. if args.data_val:
  154. valpipes = [HybridValPipe(args = args,
  155. batch_size = batch_size,
  156. num_threads = num_validation_threads,
  157. device_id = gpu_id,
  158. rec_path = args.data_val,
  159. idx_path = args.data_val_idx,
  160. shard_id = 0 if args.dali_separ_val
  161. else gpus.index(gpu_id) + len(gpus)*rank,
  162. num_shards = 1 if args.dali_separ_val else len(gpus)*nWrk,
  163. crop_shape = args.image_shape[1:],
  164. resize_shp = args.data_val_resize,
  165. output_layout = output_layout,
  166. dtype = args.dtype,
  167. pad_output = pad_output,
  168. dali_cpu = dali_cpu,
  169. nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
  170. prefetch_queue = args.dali_prefetch_queue,
  171. nvjpeg_width_hint = args.dali_nvjpeg_width_hint,
  172. nvjpeg_height_hint = args.dali_nvjpeg_height_hint) for gpu_id in gpus] if args.data_val else None
  173. trainpipes[0].build()
  174. if args.data_val:
  175. valpipes[0].build()
  176. worker_val_examples = valpipes[0].epoch_size("Reader")
  177. if not args.dali_separ_val:
  178. worker_val_examples = worker_val_examples // nWrk
  179. if rank < valpipes[0].epoch_size("Reader") % nWrk:
  180. worker_val_examples += 1
  181. if args.num_examples < trainpipes[0].epoch_size("Reader"):
  182. warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
  183. dali_train_iter = DALIClassificationIterator(trainpipes, args.num_examples // nWrk)
  184. if args.data_val:
  185. dali_val_iter = DALIClassificationIterator(valpipes, worker_val_examples, fill_last_batch = False) if args.data_val else None
  186. else:
  187. dali_val_iter = None
  188. return dali_train_iter, dali_val_iter