"""
Convert LiTS 2017 (Liver Tumor Segmentation) data into UNet3+ data format
LiTS: https://competitions.codalab.org/competitions/17094
"""
import os
import sys
from glob import glob
from pathlib import Path
from tqdm import tqdm
import numpy as np
import multiprocessing as mp
import cv2
import nibabel as nib
import hydra
from omegaconf import DictConfig

# make repo-local utils importable when run as a script
sys.path.append(os.path.abspath("./"))
from utils.general_utils import create_directory, join_paths
from utils.images_utils import resize_image
  19. def read_nii(filepath):
  20. """
  21. Reads .nii file and returns pixel array
  22. """
  23. ct_scan = nib.load(filepath).get_fdata()
  24. # TODO: Verify images orientation
  25. # in both train and test set, especially on train scan 130
  26. ct_scan = np.rot90(np.array(ct_scan))
  27. return ct_scan
  28. def crop_center(img, croph, cropw):
  29. """
  30. Center crop on given height and width
  31. """
  32. height, width = img.shape[:2]
  33. starth = height // 2 - (croph // 2)
  34. startw = width // 2 - (cropw // 2)
  35. return img[starth:starth + croph, startw:startw + cropw, :]
  36. def linear_scale(img):
  37. """
  38. First convert image to range of 0-1 and them scale to 255
  39. """
  40. img = (img - img.min(axis=(0, 1))) / (img.max(axis=(0, 1)) - img.min(axis=(0, 1)))
  41. return img * 255
  42. def clip_scan(img, min_value, max_value):
  43. """
  44. Clip scan to given range
  45. """
  46. return np.clip(img, min_value, max_value)
  47. def resize_scan(scan, new_height, new_width, scan_type):
  48. """
  49. Resize CT scan to given size
  50. """
  51. scan_shape = scan.shape
  52. resized_scan = np.zeros((new_height, new_width, scan_shape[2]), dtype=scan.dtype)
  53. resize_method = cv2.INTER_CUBIC if scan_type == "image" else cv2.INTER_NEAREST
  54. for start in range(0, scan_shape[2], scan_shape[1]):
  55. end = start + scan_shape[1]
  56. if end >= scan_shape[2]: end = scan_shape[2]
  57. resized_scan[:, :, start:end] = resize_image(
  58. scan[:, :, start:end],
  59. new_height, new_width,
  60. resize_method
  61. )
  62. return resized_scan
  63. def save_images(scan, save_path, img_index):
  64. """
  65. Based on UNet3+ requirement "input image had three channels, including
  66. the slice to be segmented and the upper and lower slices, which was
  67. cropped to 320×320" save each scan as separate image with previous and
  68. next scan concatenated.
  69. """
  70. scan_shape = scan.shape
  71. for index in range(scan_shape[-1]):
  72. before_index = index - 1 if (index - 1) > 0 else 0
  73. after_index = index + 1 if (index + 1) < scan_shape[-1] else scan_shape[-1] - 1
  74. new_img_path = join_paths(save_path, f"image_{img_index}_{index}.png")
  75. new_image = np.stack(
  76. (
  77. scan[:, :, before_index],
  78. scan[:, :, index],
  79. scan[:, :, after_index]
  80. )
  81. , axis=-1)
  82. new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR) # RGB to BGR
  83. cv2.imwrite(new_img_path, new_image) # save the images as .png
  84. def save_mask(scan, save_path, mask_index):
  85. """
  86. Save each scan as separate mask
  87. """
  88. for index in range(scan.shape[-1]):
  89. new_mask_path = join_paths(save_path, f"mask_{mask_index}_{index}.png")
  90. cv2.imwrite(new_mask_path, scan[:, :, index]) # save grey scale image
  91. def extract_image(cfg, image_path, save_path, scan_type="image", ):
  92. """
  93. Extract image from given scan path
  94. """
  95. _, index = str(Path(image_path).stem).split("-")
  96. scan = read_nii(image_path)
  97. scan = resize_scan(
  98. scan,
  99. cfg.DATA_PREPARATION.RESIZED_HEIGHT,
  100. cfg.DATA_PREPARATION.RESIZED_WIDTH,
  101. scan_type
  102. )
  103. if scan_type == "image":
  104. scan = clip_scan(
  105. scan,
  106. cfg.DATA_PREPARATION.SCAN_MIN_VALUE,
  107. cfg.DATA_PREPARATION.SCAN_MAX_VALUE
  108. )
  109. scan = linear_scale(scan)
  110. scan = np.uint8(scan)
  111. save_images(scan, save_path, index)
  112. else:
  113. # 0 for background/non-lesion, 1 for liver, 2 for lesion/tumor
  114. # merging label 2 into label 1, because lesion/tumor is part of liver
  115. scan = np.where(scan != 0, 1, scan)
  116. # scan = np.where(scan==2, 1, scan)
  117. scan = np.uint8(scan)
  118. save_mask(scan, save_path, index)
  119. def extract_images(cfg, images_path, save_path, scan_type="image", ):
  120. """
  121. Extract images paths using multiprocessing and pass to
  122. extract_image function for further processing .
  123. """
  124. # create pool
  125. process_count = np.clip(mp.cpu_count() - 2, 1, 20) # less than 20 workers
  126. pool = mp.Pool(process_count)
  127. for image_path in tqdm(images_path):
  128. pool.apply_async(extract_image,
  129. args=(cfg, image_path, save_path, scan_type),
  130. )
  131. # close pool
  132. pool.close()
  133. pool.join()
  134. @hydra.main(version_base=None, config_path="../configs", config_name="config")
  135. def preprocess_lits_data(cfg: DictConfig):
  136. """
  137. Preprocess LiTS 2017 (Liver Tumor Segmentation) data by extractions
  138. images and mask into UNet3+ data format
  139. """
  140. train_images_names = glob(
  141. join_paths(
  142. cfg.WORK_DIR,
  143. cfg.DATA_PREPARATION.SCANS_TRAIN_DATA_PATH,
  144. "volume-*.nii"
  145. )
  146. )
  147. train_mask_names = glob(
  148. join_paths(
  149. cfg.WORK_DIR,
  150. cfg.DATA_PREPARATION.SCANS_TRAIN_DATA_PATH,
  151. "segmentation-*.nii"
  152. )
  153. )
  154. assert len(train_images_names) == len(train_mask_names), \
  155. "Train volumes and segmentations are not same in length"
  156. val_images_names = glob(
  157. join_paths(
  158. cfg.WORK_DIR,
  159. cfg.DATA_PREPARATION.SCANS_VAL_DATA_PATH,
  160. "volume-*.nii"
  161. )
  162. )
  163. val_mask_names = glob(
  164. join_paths(
  165. cfg.WORK_DIR,
  166. cfg.DATA_PREPARATION.SCANS_VAL_DATA_PATH,
  167. "segmentation-*.nii"
  168. )
  169. )
  170. assert len(val_images_names) == len(val_mask_names), \
  171. "Validation volumes and segmentations are not same in length"
  172. train_images_names = sorted(train_images_names)
  173. train_mask_names = sorted(train_mask_names)
  174. val_images_names = sorted(val_images_names)
  175. val_mask_names = sorted(val_mask_names)
  176. train_images_path = join_paths(
  177. cfg.WORK_DIR, cfg.DATASET.TRAIN.IMAGES_PATH
  178. )
  179. train_mask_path = join_paths(
  180. cfg.WORK_DIR, cfg.DATASET.TRAIN.MASK_PATH
  181. )
  182. val_images_path = join_paths(
  183. cfg.WORK_DIR, cfg.DATASET.VAL.IMAGES_PATH
  184. )
  185. val_mask_path = join_paths(
  186. cfg.WORK_DIR, cfg.DATASET.VAL.MASK_PATH
  187. )
  188. create_directory(train_images_path)
  189. create_directory(train_mask_path)
  190. create_directory(val_images_path)
  191. create_directory(val_mask_path)
  192. print("\nExtracting train images")
  193. extract_images(
  194. cfg, train_images_names, train_images_path, scan_type="image"
  195. )
  196. print("\nExtracting train mask")
  197. extract_images(
  198. cfg, train_mask_names, train_mask_path, scan_type="mask"
  199. )
  200. print("\nExtracting val images")
  201. extract_images(
  202. cfg, val_images_names, val_images_path, scan_type="image"
  203. )
  204. print("\nExtracting val mask")
  205. extract_images(
  206. cfg, val_mask_names, val_mask_path, scan_type="mask"
  207. )
  208. if __name__ == '__main__':
  209. preprocess_lits_data()