# dali.py
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import warnings
  15. from nvidia import dali
  16. from nvidia.dali.pipeline import Pipeline
  17. import nvidia.dali.ops as ops
  18. import nvidia.dali.types as types
  19. from nvidia.dali.plugin.mxnet import DALIClassificationIterator
  20. def add_dali_args(parser):
  21. group = parser.add_argument_group('DALI', 'pipeline and augumentation')
  22. group.add_argument('--use-dali', action='store_true',
  23. help='use dalli pipeline and augunetation')
  24. group.add_argument('--separ-val', action='store_true',
  25. help='each process will perform independent validation on whole val-set')
  26. group.add_argument('--dali-threads', type=int, default=3, help="number of threads" +\
  27. "per GPU for DALI")
  28. group.add_argument('--validation-dali-threads', type=int, default=10, help="number of threads" +\
  29. "per GPU for DALI for validation")
  30. group.add_argument('--dali-prefetch-queue', type=int, default=3, help="DALI prefetch queue depth")
  31. group.add_argument('--dali-nvjpeg-memory-padding', type=int, default=16, help="Memory padding value for nvJPEG (in MB)")
  32. return parser
  33. _mean_pixel = [255 * x for x in (0.485, 0.456, 0.406)]
  34. _std_pixel = [255 * x for x in (0.229, 0.224, 0.225)]
  35. class HybridTrainPipe(Pipeline):
  36. def __init__(self, batch_size, num_threads, device_id, rec_path, idx_path,
  37. shard_id, num_shards, crop_shape,
  38. nvjpeg_padding, prefetch_queue=3,
  39. output_layout=types.NCHW, pad_output=True, dtype='float16'):
  40. super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed = 12 + device_id, prefetch_queue_depth = prefetch_queue)
  41. self.input = ops.MXNetReader(path = [rec_path], index_path=[idx_path],
  42. random_shuffle=True, shard_id=shard_id, num_shards=num_shards)
  43. self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB,
  44. device_memory_padding = nvjpeg_padding,
  45. host_memory_padding = nvjpeg_padding)
  46. self.rrc = ops.RandomResizedCrop(device = "gpu", size = crop_shape)
  47. self.cmnp = ops.CropMirrorNormalize(device = "gpu",
  48. output_dtype = types.FLOAT16 if dtype == 'float16' else types.FLOAT,
  49. output_layout = output_layout,
  50. crop = crop_shape,
  51. pad_output = pad_output,
  52. image_type = types.RGB,
  53. mean = _mean_pixel,
  54. std = _std_pixel)
  55. self.coin = ops.CoinFlip(probability = 0.5)
  56. def define_graph(self):
  57. rng = self.coin()
  58. self.jpegs, self.labels = self.input(name = "Reader")
  59. images = self.decode(self.jpegs)
  60. images = self.rrc(images)
  61. output = self.cmnp(images, mirror = rng)
  62. return [output, self.labels]
  63. class HybridValPipe(Pipeline):
  64. def __init__(self, batch_size, num_threads, device_id, rec_path, idx_path,
  65. shard_id, num_shards, crop_shape,
  66. nvjpeg_padding, prefetch_queue=3,
  67. resize_shp=None,
  68. output_layout=types.NCHW, pad_output=True, dtype='float16'):
  69. super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed = 12 + device_id, prefetch_queue_depth = prefetch_queue)
  70. self.input = ops.MXNetReader(path = [rec_path], index_path=[idx_path],
  71. random_shuffle=False, shard_id=shard_id, num_shards=num_shards)
  72. self.decode = ops.nvJPEGDecoder(device = "mixed", output_type = types.RGB,
  73. device_memory_padding = nvjpeg_padding,
  74. host_memory_padding = nvjpeg_padding)
  75. self.resize = ops.Resize(device = "gpu", resize_shorter=resize_shp) if resize_shp else None
  76. self.cmnp = ops.CropMirrorNormalize(device = "gpu",
  77. output_dtype = types.FLOAT16 if dtype == 'float16' else types.FLOAT,
  78. output_layout = output_layout,
  79. crop = crop_shape,
  80. pad_output = pad_output,
  81. image_type = types.RGB,
  82. mean = _mean_pixel,
  83. std = _std_pixel)
  84. def define_graph(self):
  85. self.jpegs, self.labels = self.input(name = "Reader")
  86. images = self.decode(self.jpegs)
  87. if self.resize:
  88. images = self.resize(images)
  89. output = self.cmnp(images)
  90. return [output, self.labels]
  91. def get_rec_iter(args, kv=None):
  92. # resize is default base length of shorter edge for dataset;
  93. # all images will be reshaped to this size
  94. resize = int(args.resize)
  95. # target shape is final shape of images pipelined to network;
  96. # all images will be cropped to this size
  97. target_shape = tuple([int(l) for l in args.image_shape.split(',')])
  98. pad_output = target_shape[0] == 4
  99. gpus = list(map(int, filter(None, args.gpus.split(',')))) # filter to not encount eventually empty strings
  100. batch_size = args.batch_size//len(gpus)
  101. num_threads = args.dali_threads
  102. num_validation_threads = args.validation_dali_threads
  103. #db_folder = "/data/imagenet/train-480-val-256-recordio/"
  104. # the input_layout w.r.t. the model is the output_layout of the image pipeline
  105. output_layout = types.NHWC if args.input_layout == 'NHWC' else types.NCHW
  106. rank = kv.rank if kv else 0
  107. nWrk = kv.num_workers if kv else 1
  108. trainpipes = [HybridTrainPipe(batch_size = batch_size,
  109. num_threads = num_threads,
  110. device_id = gpu_id,
  111. rec_path = args.data_train,
  112. idx_path = args.data_train_idx,
  113. shard_id = gpus.index(gpu_id) + len(gpus)*rank,
  114. num_shards = len(gpus)*nWrk,
  115. crop_shape = target_shape[1:],
  116. output_layout = output_layout,
  117. pad_output = pad_output,
  118. dtype = args.dtype,
  119. nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
  120. prefetch_queue = args.dali_prefetch_queue) for gpu_id in gpus]
  121. valpipes = [HybridValPipe(batch_size = batch_size,
  122. num_threads = num_validation_threads,
  123. device_id = gpu_id,
  124. rec_path = args.data_val,
  125. idx_path = args.data_val_idx,
  126. shard_id = 0 if args.separ_val
  127. else gpus.index(gpu_id) + len(gpus)*rank,
  128. num_shards = 1 if args.separ_val else len(gpus)*nWrk,
  129. crop_shape = target_shape[1:],
  130. resize_shp = resize,
  131. output_layout = output_layout,
  132. pad_output = pad_output,
  133. dtype = args.dtype,
  134. nvjpeg_padding = args.dali_nvjpeg_memory_padding * 1024 * 1024,
  135. prefetch_queue = args.dali_prefetch_queue) for gpu_id in gpus] if args.data_val else None
  136. trainpipes[0].build()
  137. if args.data_val:
  138. valpipes[0].build()
  139. if args.num_examples < trainpipes[0].epoch_size("Reader"):
  140. warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
  141. dali_train_iter = DALIClassificationIterator(trainpipes, args.num_examples // nWrk)
  142. dali_val_iter = DALIClassificationIterator(valpipes, valpipes[0].epoch_size("Reader") // (1 if args.separ_val else nWrk), fill_last_batch = False) if args.data_val else None
  143. return dali_train_iter, dali_val_iter