preprocessor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. import json
  16. import math
  17. import os
  18. import pickle
  19. import monai.transforms as transforms
  20. import nibabel
  21. import numpy as np
  22. from joblib import Parallel, delayed
  23. from skimage.transform import resize
  24. from utils.utils import get_task_code, make_empty_dir
  25. from data_preprocessing.configs import (ct_max, ct_mean, ct_min, ct_std,
  26. patch_size, spacings, task)
  27. class Preprocessor:
  28. def __init__(self, args):
  29. self.args = args
  30. self.ct_min = 0
  31. self.ct_max = 0
  32. self.ct_mean = 0
  33. self.ct_std = 0
  34. self.target_spacing = None
  35. self.task = args.task
  36. self.task_code = get_task_code(args)
  37. self.patch_size = patch_size[self.task_code]
  38. self.training = args.exec_mode == "training"
  39. self.data_path = os.path.join(args.data, task[args.task])
  40. self.results = os.path.join(args.results, self.task_code)
  41. if not self.training:
  42. self.results = os.path.join(self.results, "test")
  43. self.metadata = json.load(open(os.path.join(self.data_path, "dataset.json"), "r"))
  44. self.modality = self.metadata["modality"]["0"]
  45. self.crop_foreg = transforms.CropForegroundd(keys=["image", "label"], source_key="image")
  46. self.normalize_intensity = transforms.NormalizeIntensity(nonzero=True, channel_wise=True)
  47. def run(self):
  48. make_empty_dir(self.results)
  49. print(f"Preprocessing {self.data_path}")
  50. try:
  51. self.target_spacing = spacings[self.task_code]
  52. except:
  53. self.collect_spacings()
  54. print(f"Target spacing {self.target_spacing}")
  55. if self.modality == "CT":
  56. try:
  57. self.ct_min = ct_min[self.task]
  58. self.ct_max = ct_max[self.task]
  59. self.ct_mean = ct_mean[self.task]
  60. self.ct_std = ct_std[self.task]
  61. except:
  62. self.collect_intensities()
  63. _mean = round(self.ct_mean, 2)
  64. _std = round(self.ct_std, 2)
  65. print(f"[CT] min: {self.ct_min}, max: {self.ct_max}, mean: {_mean}, std: {_std}")
  66. self.run_parallel(self.preprocess_pair, self.args.exec_mode)
  67. pickle.dump(
  68. {
  69. "patch_size": self.patch_size,
  70. "spacings": self.target_spacing,
  71. "n_class": len(self.metadata["labels"]),
  72. "in_channels": len(self.metadata["modality"]),
  73. },
  74. open(os.path.join(self.results, "config.pkl"), "wb"),
  75. )
  76. def preprocess_pair(self, pair):
  77. fname = os.path.basename(pair["image"] if self.training else pair)
  78. image, label, image_spacings = self.load_pair(pair)
  79. if self.training:
  80. data = self.crop_foreg({"image": image, "label": label})
  81. image, label = data["image"], data["label"]
  82. if self.args.dim == 3:
  83. image, label = self.resample(image, label, image_spacings)
  84. if self.modality == "CT":
  85. image = np.clip(image, self.ct_min, self.ct_max)
  86. if self.training:
  87. image, label = self.standardize(image, label)
  88. image = self.normalize(image)
  89. self.save(image, label, fname)
  90. def resample(self, image, label, image_spacings):
  91. if self.target_spacing != image_spacings:
  92. image, label = self.resample_pair(image, label, image_spacings)
  93. return image, label
  94. def standardize(self, image, label):
  95. pad_shape = self.calculate_pad_shape(image)
  96. img_shape = image.shape[1:]
  97. if pad_shape != img_shape:
  98. paddings = [(pad_sh - img_sh) / 2 for (pad_sh, img_sh) in zip(pad_shape, img_shape)]
  99. image = self.pad(image, paddings)
  100. label = self.pad(label, paddings)
  101. if self.args.dim == 2: # Center cropping 2D images.
  102. _, _, height, weight = image.shape
  103. start_h = (height - self.patch_size[0]) // 2
  104. start_w = (weight - self.patch_size[1]) // 2
  105. image = image[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
  106. label = label[:, :, start_h : start_h + self.patch_size[0], start_w : start_w + self.patch_size[1]]
  107. return image, label
  108. def normalize(self, image):
  109. if self.modality == "CT":
  110. return (image - self.ct_mean) / self.ct_std
  111. return self.normalize_intensity(image)
  112. def save(self, image, label, fname):
  113. mean, std = np.round(np.mean(image, (1, 2, 3)), 2), np.round(np.std(image, (1, 2, 3)), 2)
  114. print(f"Saving {fname} shape {image.shape} mean {mean} std {std}")
  115. self.save_3d(image, label, fname)
  116. def load_pair(self, pair):
  117. image = self.load_nifty(pair["image"] if self.training else pair)
  118. image_spacing = self.load_spacing(image)
  119. image = image.get_fdata().astype(np.float32)
  120. image = self.standardize_layout(image)
  121. label = None
  122. if self.training:
  123. label = self.load_nifty(pair["label"]).get_fdata().astype(np.uint8)
  124. label = self.standardize_layout(label)
  125. return image, label, image_spacing
  126. def resample_pair(self, image, label, spacing):
  127. shape = self.calculate_new_shape(spacing, image.shape[1:])
  128. if self.check_anisotrophy(spacing):
  129. image = self.resample_anisotrophic_image(image, shape)
  130. if self.training:
  131. label = self.resample_anisotrophic_label(label, shape)
  132. else:
  133. image = self.resample_regular_image(image, shape)
  134. if self.training:
  135. label = self.resample_regular_label(label, shape)
  136. image = image.astype(np.float32)
  137. if self.training:
  138. label = label.astype(np.uint8)
  139. return image, label
  140. def calculate_pad_shape(self, image):
  141. min_shape = self.patch_size[:]
  142. img_shape = image.shape[1:]
  143. if len(min_shape) == 2: # In 2D case we don't want to pad depth axis.
  144. min_shape.insert(0, img_shape[0])
  145. pad_shape = [max(mshape, ishape) for mshape, ishape in zip(min_shape, img_shape)]
  146. return pad_shape
  147. def get_intensities(self, pair):
  148. image = self.load_nifty(pair["image"]).get_fdata().astype(np.float32)
  149. label = self.load_nifty(pair["label"]).get_fdata().astype(np.uint8)
  150. foreground_idx = np.where(label > 0)
  151. intensities = image[foreground_idx].tolist()
  152. return intensities
  153. def collect_intensities(self):
  154. intensities = self.run_parallel(self.get_intensities, "training")
  155. intensities = list(itertools.chain(*intensities))
  156. self.ct_min, self.ct_max = np.percentile(intensities, [0.5, 99.5])
  157. self.ct_mean, self.ct_std = np.mean(intensities), np.std(intensities)
  158. def get_spacing(self, pair):
  159. image = nibabel.load(os.path.join(self.data_path, pair["image"]))
  160. spacing = self.load_spacing(image)
  161. return spacing
  162. def collect_spacings(self):
  163. spacing = self.run_parallel(self.get_spacing, "training")
  164. spacing = np.array(spacing)
  165. target_spacing = np.median(spacing, axis=0)
  166. if max(target_spacing) / min(target_spacing) >= 3:
  167. lowres_axis = np.argmin(target_spacing)
  168. target_spacing[lowres_axis] = np.percentile(spacing[:, lowres_axis], 10)
  169. self.target_spacing = list(target_spacing)
  170. def check_anisotrophy(self, spacing):
  171. def check(spacing):
  172. return np.max(spacing) / np.min(spacing) >= 3
  173. return check(spacing) or check(self.target_spacing)
  174. def calculate_new_shape(self, spacing, shape):
  175. spacing_ratio = np.array(spacing) / np.array(self.target_spacing)
  176. new_shape = (spacing_ratio * np.array(shape)).astype(int).tolist()
  177. return new_shape
  178. def save_3d(self, image, label, fname):
  179. self.save_npy(image, fname, "_x.npy")
  180. if self.training:
  181. self.save_npy(label, fname, "_y.npy")
  182. def save_npy(self, img, fname, suffix):
  183. np.save(os.path.join(self.results, fname.replace(".nii.gz", suffix)), img, allow_pickle=False)
  184. def run_parallel(self, func, exec_mode):
  185. return Parallel(n_jobs=self.args.n_jobs)(delayed(func)(pair) for pair in self.metadata[exec_mode])
  186. def load_nifty(self, fname):
  187. return nibabel.load(os.path.join(self.data_path, fname))
  188. @staticmethod
  189. def load_spacing(image):
  190. return image.header["pixdim"][1:4].tolist()[::-1]
  191. @staticmethod
  192. def pad(image, padding):
  193. pad_d, pad_w, pad_h = padding
  194. return np.pad(
  195. image,
  196. (
  197. (0, 0),
  198. (math.floor(pad_d), math.ceil(pad_d)),
  199. (math.floor(pad_w), math.ceil(pad_w)),
  200. (math.floor(pad_h), math.ceil(pad_h)),
  201. ),
  202. )
  203. @staticmethod
  204. def standardize_layout(data):
  205. if len(data.shape) == 3:
  206. data = np.expand_dims(data, 3)
  207. return np.transpose(data, (3, 2, 1, 0))
  208. @staticmethod
  209. def resize_fn(image, shape, order, mode):
  210. return resize(image, shape, order=order, mode=mode, cval=0, clip=True, anti_aliasing=False)
  211. def resample_anisotrophic_image(self, image, shape):
  212. resized_channels = []
  213. for image_c in image:
  214. resized = [self.resize_fn(i, shape[1:], 3, "edge") for i in image_c]
  215. resized = np.stack(resized, axis=0)
  216. resized = self.resize_fn(resized, shape, 0, "constant")
  217. resized_channels.append(resized)
  218. resized = np.stack(resized_channels, axis=0)
  219. return resized
  220. def resample_regular_image(self, image, shape):
  221. resized_channels = []
  222. for image_c in image:
  223. resized_channels.append(self.resize_fn(image_c, shape, 3, "edge"))
  224. resized = np.stack(resized_channels, axis=0)
  225. return resized
  226. def resample_anisotrophic_label(self, label, shape):
  227. depth = label.shape[1]
  228. reshaped = np.zeros(shape, dtype=np.uint8)
  229. shape_2d = shape[1:]
  230. reshaped_2d = np.zeros((depth, *shape_2d), dtype=np.uint8)
  231. n_class = np.max(label)
  232. for class_ in range(1, n_class + 1):
  233. for depth_ in range(depth):
  234. mask = label[0, depth_] == class_
  235. resized_2d = self.resize_fn(mask.astype(float), shape_2d, 1, "edge")
  236. reshaped_2d[depth_][resized_2d >= 0.5] = class_
  237. for class_ in range(1, n_class + 1):
  238. mask = reshaped_2d == class_
  239. resized = self.resize_fn(mask.astype(float), shape, 0, "constant")
  240. reshaped[resized >= 0.5] = class_
  241. reshaped = np.expand_dims(reshaped, 0)
  242. return reshaped
  243. def resample_regular_label(self, label, shape):
  244. reshaped = np.zeros(shape, dtype=np.uint8)
  245. n_class = np.max(label)
  246. for class_ in range(1, n_class + 1):
  247. mask = label[0] == class_
  248. resized = self.resize_fn(mask.astype(float), shape, 1, "edge")
  249. reshaped[resized >= 0.5] = class_
  250. reshaped = np.expand_dims(reshaped, 0)
  251. return reshaped