# dali_loader.py
  1. import itertools
  2. import os
  3. import numpy as np
  4. import nvidia.dali.fn as fn
  5. import nvidia.dali.math as math
  6. import nvidia.dali.ops as ops
  7. import nvidia.dali.tfrecord as tfrec
  8. import nvidia.dali.types as types
  9. from nvidia.dali.pipeline import Pipeline
  10. from nvidia.dali.plugin.pytorch import DALIGenericIterator
class TFRecordTrain(Pipeline):
    """DALI pipeline for training.

    Reads sharded, shuffled TFRecords and applies random augmentations:
    foreground-biased crop, zoom, Gaussian noise, Gaussian blur, brightness,
    contrast, and axis flips.

    NOTE(review): DALI derives per-op seeds from the pipeline seed in op
    creation order, so reordering the ops built in define_graph would change
    the random stream — keep the order stable for reproducibility.
    """

    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        super(TFRecordTrain, self).__init__(batch_size, num_threads, device_id)
        self.dim = kwargs["dim"]  # spatial rank of the volumes: 2 or 3
        self.seed = kwargs["seed"]
        self.oversampling = kwargs["oversampling"]  # probability of centering crops on foreground
        # Sharded TFRecord reader. X is a flat float32 buffer, Y a byte string;
        # both are reshaped via their stored *_shape features in load_data.
        self.input = ops.TFRecordReader(
            path=kwargs["tfrecords"],
            index_path=kwargs["tfrecords_idx"],
            features={
                # Shapes have dim + 1 entries: channel + spatial dims.
                "X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
                "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
                "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
            },
            num_shards=kwargs["gpus"],
            shard_id=device_id,
            random_shuffle=True,
            pad_last_batch=True,  # keeps shard sizes equal across GPUs
            read_ahead=True,
            seed=self.seed,
        )
        self.patch_size = kwargs["patch_size"]
        # Crop shape as DALI constants: int64 for slicing, float for fn.resize.
        self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
        self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
        self.layout = "CDHW" if self.dim == 3 else "CHW"
        self.axis_name = "DHW" if self.dim == 3 else "HW"

    def load_data(self, features):
        """Reshape the flat buffers to their stored shapes; view labels as uint8."""
        img = fn.reshape(features["X"], shape=features["X_shape"], layout=self.layout)
        lbl = fn.reshape(features["Y"], shape=features["Y_shape"], layout=self.layout)
        lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
        return img, lbl

    def random_augmentation(self, probability, augmented, original):
        """Per-sample: return `augmented` with `probability`, else `original`."""
        condition = fn.cast(fn.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
        neg_condition = condition ^ True
        # Arithmetic select — exactly one of the two terms is non-zero.
        return condition * augmented + neg_condition * original

    @staticmethod
    def slice_fn(img, start_idx, length):
        # Slice along axis 0 only (used to drop the leading channel entry).
        return fn.slice(img, start_idx, length, axes=[0])

    def crop_fn(self, img, lbl):
        """Random patch crop, biased toward foreground with prob. `oversampling`."""
        # Pick a (biased) center pixel from the label mask.
        center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling))
        # Spatial anchor = center - patch/2, dropping the channel coordinate.
        crop_anchor = self.slice_fn(center, 1, self.dim) - self.crop_shape // 2
        # Clamp the anchor into [0, shape - patch] so the crop stays in bounds.
        adjusted_anchor = math.max(0, crop_anchor)
        max_anchor = self.slice_fn(fn.shapes(lbl), 1, self.dim) - self.crop_shape
        crop_anchor = math.min(adjusted_anchor, max_anchor)
        # out_of_bounds_policy="pad" covers volumes smaller than the patch.
        img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
        lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
        return img, lbl

    def zoom_fn(self, img, lbl):
        """Random zoom: crop to 70–100% of patch size, then resize back (p=0.15)."""
        resized_shape = self.crop_shape * self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.0)), 1.0)
        img, lbl = fn.crop(img, crop=resized_shape), fn.crop(lbl, crop=resized_shape)
        # Cubic for the image, nearest-neighbor for labels (no label mixing).
        img = fn.resize(img, interp_type=types.DALIInterpType.INTERP_CUBIC, size=self.crop_shape_float)
        lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float)
        return img, lbl

    def noise_fn(self, img):
        """Additive Gaussian noise with random stddev in [0, 0.33] (p=0.15)."""
        img_noised = img + fn.normal_distribution(img, stddev=fn.uniform(range=(0.0, 0.33)))
        return self.random_augmentation(0.15, img_noised, img)

    def blur_fn(self, img):
        """Gaussian blur with random sigma in [0.5, 1.5] (p=0.15)."""
        img_blured = fn.gaussian_blur(img, sigma=fn.uniform(range=(0.5, 1.5)))
        return self.random_augmentation(0.15, img_blured, img)

    def brightness_fn(self, img):
        """Multiplicative brightness scale in [0.7, 1.3] (p=0.15)."""
        brightness_scale = self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.3)), 1.0)
        return img * brightness_scale

    def contrast_fn(self, img):
        """Contrast scale in [0.65, 1.5], clamped to the original intensity range (p=0.15)."""
        min_, max_ = fn.reductions.min(img), fn.reductions.max(img)
        scale = self.random_augmentation(0.15, fn.uniform(range=(0.65, 1.5)), 1.0)
        img = math.clamp(img * scale, min_, max_)
        return img

    def flips_fn(self, img, lbl):
        """Random flips (p=0.33 per axis); same flips applied to image and label."""
        kwargs = {"horizontal": fn.coin_flip(probability=0.33), "vertical": fn.coin_flip(probability=0.33)}
        if self.dim == 3:
            kwargs.update({"depthwise": fn.coin_flip(probability=0.33)})
        return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)

    def define_graph(self):
        """Assemble the training graph: load → crop → zoom → intensity augs → flips."""
        features = self.input(name="Reader")
        img, lbl = self.load_data(features)
        img, lbl = self.crop_fn(img, lbl)
        img, lbl = self.zoom_fn(img, lbl)
        img = self.noise_fn(img)
        img = self.blur_fn(img)
        img = self.brightness_fn(img)
        img = self.contrast_fn(img)
        img, lbl = self.flips_fn(img, lbl)
        return img, lbl
  96. class TFRecordEval(Pipeline):
  97. def __init__(self, batch_size, num_threads, device_id, **kwargs):
  98. super(TFRecordEval, self).__init__(batch_size, num_threads, device_id)
  99. self.input = ops.TFRecordReader(
  100. path=kwargs["tfrecords"],
  101. index_path=kwargs["tfrecords_idx"],
  102. features={
  103. "X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
  104. "Y_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
  105. "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
  106. "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
  107. "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
  108. },
  109. shard_id=device_id,
  110. num_shards=kwargs["gpus"],
  111. read_ahead=True,
  112. random_shuffle=False,
  113. pad_last_batch=True,
  114. )
  115. def load_data(self, features):
  116. img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
  117. lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout="CDHW")
  118. lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
  119. return img, lbl
  120. def define_graph(self):
  121. features = self.input(name="Reader")
  122. img, lbl = self.load_data(features)
  123. return img, lbl, features["fname"]
  124. class TFRecordTest(Pipeline):
  125. def __init__(self, batch_size, num_threads, device_id, **kwargs):
  126. super(TFRecordTest, self).__init__(batch_size, num_threads, device_id)
  127. self.input = ops.TFRecordReader(
  128. path=kwargs["tfrecords"],
  129. index_path=kwargs["tfrecords_idx"],
  130. features={
  131. "X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
  132. "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
  133. "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
  134. },
  135. shard_id=device_id,
  136. num_shards=kwargs["gpus"],
  137. read_ahead=True,
  138. random_shuffle=False,
  139. pad_last_batch=True,
  140. )
  141. def define_graph(self):
  142. features = self.input(name="Reader")
  143. img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
  144. return img, features["fname"]
  145. class TFRecordBenchmark(Pipeline):
  146. def __init__(self, batch_size, num_threads, device_id, **kwargs):
  147. super(TFRecordBenchmark, self).__init__(batch_size, num_threads, device_id)
  148. self.dim = kwargs["dim"]
  149. self.input = ops.TFRecordReader(
  150. path=kwargs["tfrecords"],
  151. index_path=kwargs["tfrecords_idx"],
  152. features={
  153. "X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
  154. "Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
  155. "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
  156. "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
  157. "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
  158. },
  159. shard_id=device_id,
  160. num_shards=kwargs["gpus"],
  161. read_ahead=True,
  162. )
  163. self.patch_size = kwargs["patch_size"]
  164. self.layout = "CDHW" if self.dim == 3 else "CHW"
  165. def load_data(self, features):
  166. img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout=self.layout)
  167. lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout=self.layout)
  168. lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
  169. return img, lbl
  170. def crop_fn(self, img, lbl):
  171. img = fn.crop(img, crop=self.patch_size)
  172. lbl = fn.crop(lbl, crop=self.patch_size)
  173. return img, lbl
  174. def define_graph(self):
  175. features = self.input(name="Reader")
  176. img, lbl = self.load_data(features)
  177. img, lbl = self.crop_fn(img, lbl)
  178. return img, lbl
  179. class LightningWrapper(DALIGenericIterator):
  180. def __init__(self, pipe, **kwargs):
  181. super().__init__(pipe, **kwargs)
  182. def __next__(self):
  183. out = super().__next__()
  184. out = out[0]
  185. return out
  186. def fetch_dali_loader(tfrecords, idx_files, batch_size, mode, **kwargs):
  187. assert len(tfrecords) > 0, "Got empty tfrecord list"
  188. assert len(idx_files) == len(tfrecords), f"Got {len(idx_files)} index files but {len(tfrecords)} tfrecords"
  189. if kwargs["benchmark"]:
  190. tfrecords = list(itertools.chain(*(20 * [tfrecords])))
  191. idx_files = list(itertools.chain(*(20 * [idx_files])))
  192. pipe_kwargs = {
  193. "tfrecords": tfrecords,
  194. "tfrecords_idx": idx_files,
  195. "gpus": kwargs["gpus"],
  196. "seed": kwargs["seed"],
  197. "patch_size": kwargs["patch_size"],
  198. "dim": kwargs["dim"],
  199. "oversampling": kwargs["oversampling"],
  200. }
  201. if kwargs["benchmark"] and mode == "eval":
  202. pipeline = TFRecordBenchmark
  203. output_map = ["image", "label"]
  204. dynamic_shape = False
  205. elif mode == "training":
  206. pipeline = TFRecordTrain
  207. output_map = ["image", "label"]
  208. dynamic_shape = False
  209. elif mode == "eval":
  210. pipeline = TFRecordEval
  211. output_map = ["image", "label", "fname"]
  212. dynamic_shape = True
  213. else:
  214. pipeline = TFRecordTest
  215. output_map = ["image", "fname"]
  216. dynamic_shape = True
  217. device_id = int(os.getenv("LOCAL_RANK", "0"))
  218. pipe = pipeline(batch_size, kwargs["num_workers"], device_id, **pipe_kwargs)
  219. return LightningWrapper(
  220. pipe,
  221. auto_reset=True,
  222. reader_name="Reader",
  223. output_map=output_map,
  224. dynamic_shape=dynamic_shape,
  225. )