# dali_loader.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import itertools
  15. import horovod.tensorflow as hvd
  16. import numpy as np
  17. import nvidia.dali.fn as fn
  18. import nvidia.dali.ops as ops
  19. import nvidia.dali.plugin.tf as dali_tf
  20. import nvidia.dali.types as types
  21. import tensorflow as tf
  22. from nvidia.dali.pipeline import Pipeline
  23. def get_numpy_reader(files, shard_id, num_shards, seed, shuffle):
  24. return ops.readers.Numpy(
  25. seed=seed,
  26. files=files,
  27. device="cpu",
  28. read_ahead=True,
  29. shard_id=shard_id,
  30. pad_last_batch=True,
  31. num_shards=num_shards,
  32. dont_use_mmap=True,
  33. shuffle_after_epoch=shuffle,
  34. )
  35. def random_augmentation(probability, augmented, original):
  36. condition = fn.cast(fn.random.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
  37. neg_condition = condition ^ True
  38. return condition * augmented + neg_condition * original
  39. class GenericPipeline(Pipeline):
  40. def __init__(
  41. self,
  42. batch_size,
  43. num_threads,
  44. shard_id,
  45. seed,
  46. num_gpus,
  47. dim,
  48. shuffle_input=True,
  49. input_x_files=None,
  50. input_y_files=None,
  51. ):
  52. super().__init__(
  53. batch_size=batch_size,
  54. num_threads=num_threads,
  55. device_id=hvd.rank(),
  56. seed=seed,
  57. )
  58. if input_x_files is not None:
  59. self.input_x = get_numpy_reader(
  60. files=input_x_files,
  61. shard_id=shard_id,
  62. seed=seed,
  63. num_shards=num_gpus,
  64. shuffle=shuffle_input,
  65. )
  66. if input_y_files is not None:
  67. self.input_y = get_numpy_reader(
  68. files=input_y_files,
  69. shard_id=shard_id,
  70. seed=seed,
  71. num_shards=num_gpus,
  72. shuffle=shuffle_input,
  73. )
  74. self.dim = dim
  75. self.internal_seed = seed
  76. class TrainPipeline(GenericPipeline):
  77. def __init__(self, imgs, lbls, oversampling, patch_size, batch_size_2d=None, **kwargs):
  78. super().__init__(input_x_files=imgs, input_y_files=lbls, shuffle_input=True, **kwargs)
  79. self.oversampling = oversampling
  80. self.patch_size = patch_size
  81. if self.dim == 2 and batch_size_2d is not None:
  82. self.patch_size = [batch_size_2d] + self.patch_size
  83. self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
  84. self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
  85. def load_data(self):
  86. img, lbl = self.input_x(name="ReaderX"), self.input_y(name="ReaderY")
  87. img, lbl = fn.reshape(img, layout="DHWC"), fn.reshape(lbl, layout="DHWC")
  88. return img, lbl
  89. @staticmethod
  90. def slice_fn(img):
  91. return fn.slice(img, 1, 3, axes=[0])
  92. def biased_crop_fn(self, img, lbl):
  93. roi_start, roi_end = fn.segmentation.random_object_bbox(
  94. lbl,
  95. format="start_end",
  96. foreground_prob=self.oversampling,
  97. k_largest=2,
  98. device="cpu",
  99. cache_objects=True,
  100. )
  101. anchor = fn.roi_random_crop(
  102. lbl,
  103. roi_start=roi_start,
  104. roi_end=roi_end,
  105. crop_shape=[*self.patch_size, 1],
  106. )
  107. anchor = fn.slice(anchor, 0, 3, axes=[0])
  108. img, lbl = fn.slice(
  109. [img, lbl],
  110. anchor,
  111. self.crop_shape,
  112. axis_names="DHW",
  113. out_of_bounds_policy="pad",
  114. device="cpu",
  115. )
  116. img, lbl = img.gpu(), lbl.gpu()
  117. return img, lbl
  118. def zoom_fn(self, img, lbl):
  119. scale = random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.0)), 1.0)
  120. d, h, w = [scale * x for x in self.patch_size]
  121. if self.dim == 2:
  122. d = self.patch_size[0]
  123. img, lbl = fn.crop(img, crop_h=h, crop_w=w, crop_d=d), fn.crop(lbl, crop_h=h, crop_w=w, crop_d=d)
  124. img = fn.resize(
  125. img,
  126. interp_type=types.DALIInterpType.INTERP_CUBIC,
  127. size=self.crop_shape_float,
  128. )
  129. lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float)
  130. return img, lbl
  131. def noise_fn(self, img):
  132. img_noised = fn.noise.gaussian(img, stddev=fn.random.uniform(range=(0.0, 0.3)))
  133. return random_augmentation(0.15, img_noised, img)
  134. def blur_fn(self, img):
  135. img_blurred = fn.gaussian_blur(img, sigma=fn.random.uniform(range=(0.5, 1.5)))
  136. return random_augmentation(0.15, img_blurred, img)
  137. def brightness_contrast_fn(self, img):
  138. img_transformed = fn.brightness_contrast(
  139. img, brightness=fn.random.uniform(range=(0.7, 1.3)), contrast=fn.random.uniform(range=(0.65, 1.5))
  140. )
  141. return random_augmentation(0.15, img_transformed, img)
  142. def flips_fn(self, img, lbl):
  143. kwargs = {
  144. "horizontal": fn.random.coin_flip(probability=0.5),
  145. "vertical": fn.random.coin_flip(probability=0.5),
  146. }
  147. if self.dim == 3:
  148. kwargs.update({"depthwise": fn.random.coin_flip(probability=0.5)})
  149. return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
  150. def define_graph(self):
  151. img, lbl = self.load_data()
  152. img, lbl = self.biased_crop_fn(img, lbl)
  153. img, lbl = self.zoom_fn(img, lbl)
  154. img, lbl = self.flips_fn(img, lbl)
  155. img = self.noise_fn(img)
  156. img = self.blur_fn(img)
  157. img = self.brightness_contrast_fn(img)
  158. return img, lbl
  159. class EvalPipeline(GenericPipeline):
  160. def __init__(self, imgs, lbls, patch_size, **kwargs):
  161. super().__init__(input_x_files=imgs, input_y_files=lbls, shuffle_input=False, **kwargs)
  162. self.patch_size = patch_size
  163. def define_graph(self):
  164. img, lbl = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
  165. img, lbl = fn.reshape(img, layout="DHWC"), fn.reshape(lbl, layout="DHWC")
  166. return img, lbl
  167. class TestPipeline(GenericPipeline):
  168. def __init__(self, imgs, meta, **kwargs):
  169. super().__init__(input_x_files=imgs, input_y_files=meta, shuffle_input=False, **kwargs)
  170. def define_graph(self):
  171. img, meta = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
  172. img = fn.reshape(img, layout="DHWC")
  173. return img, meta
  174. class BenchmarkPipeline(GenericPipeline):
  175. def __init__(self, imgs, lbls, patch_size, batch_size_2d=None, **kwargs):
  176. super().__init__(input_x_files=imgs, input_y_files=lbls, shuffle_input=False, **kwargs)
  177. self.patch_size = patch_size
  178. if self.dim == 2 and batch_size_2d is not None:
  179. self.patch_size = [batch_size_2d] + self.patch_size
  180. def crop_fn(self, img, lbl):
  181. img = fn.crop(img, crop=self.patch_size, out_of_bounds_policy="pad")
  182. lbl = fn.crop(lbl, crop=self.patch_size, out_of_bounds_policy="pad")
  183. return img, lbl
  184. def define_graph(self):
  185. img, lbl = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
  186. img, lbl = self.crop_fn(img, lbl)
  187. img, lbl = fn.reshape(img, layout="DHWC"), fn.reshape(lbl, layout="DHWC")
  188. return img, lbl
  189. def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
  190. assert len(imgs) > 0, "No images found"
  191. if lbls is not None:
  192. assert len(imgs) == len(lbls), f"Got {len(imgs)} images but {len(lbls)} lables"
  193. gpus = hvd.size()
  194. device_id = hvd.rank()
  195. if kwargs["benchmark"]:
  196. # Just to make sure the number of examples is large enough for benchmark run.
  197. nbs = kwargs["bench_steps"]
  198. if kwargs["dim"] == 3:
  199. nbs *= batch_size
  200. imgs = list(itertools.chain(*(100 * [imgs])))[: nbs * gpus]
  201. lbls = list(itertools.chain(*(100 * [lbls])))[: nbs * gpus]
  202. pipe_kwargs = {
  203. "dim": kwargs["dim"],
  204. "num_gpus": gpus,
  205. "seed": kwargs["seed"],
  206. "batch_size": batch_size,
  207. "num_threads": kwargs["num_workers"],
  208. "shard_id": device_id,
  209. }
  210. if kwargs["dim"] == 2:
  211. if kwargs["benchmark"]:
  212. pipe_kwargs.update({"batch_size_2d": batch_size})
  213. batch_size = 1
  214. elif mode == "train":
  215. pipe_kwargs.update({"batch_size_2d": batch_size // kwargs["nvol"]})
  216. batch_size = kwargs["nvol"]
  217. if mode == "eval": # Validation data is manually sharded beforehand.
  218. pipe_kwargs["shard_id"] = 0
  219. pipe_kwargs["num_gpus"] = 1
  220. output_dtypes = (tf.float32, tf.uint8)
  221. if kwargs["benchmark"]:
  222. pipeline = BenchmarkPipeline(imgs, lbls, kwargs["patch_size"], **pipe_kwargs)
  223. elif mode == "train":
  224. pipeline = TrainPipeline(imgs, lbls, kwargs["oversampling"], kwargs["patch_size"], **pipe_kwargs)
  225. elif mode == "eval":
  226. pipeline = EvalPipeline(imgs, lbls, kwargs["patch_size"], **pipe_kwargs)
  227. else:
  228. pipeline = TestPipeline(imgs, kwargs["meta"], **pipe_kwargs)
  229. output_dtypes = (tf.float32, tf.int64)
  230. tf_pipe = dali_tf.DALIDataset(pipeline, batch_size=batch_size, device_id=device_id, output_dtypes=output_dtypes)
  231. return tf_pipe