# preprocess_data.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Preprocess dataset and prepare it for training

Example usage:
$ python preprocess_data.py --input_dir ./src --output_dir ./dst --vol_per_file 2

All arguments are listed under `python preprocess_data.py -h`.
"""
import argparse
import os
from random import shuffle

import nibabel as nib
import numpy as np
import tensorflow as tf
  26. PARSER = argparse.ArgumentParser()
  27. PARSER.add_argument('--input_dir', '-i',
  28. type=str, help='path to the input directory with data')
  29. PARSER.add_argument('--output_dir', '-o',
  30. type=str, help='path to the output directory where tfrecord files will be stored')
  31. PARSER.add_argument('--verbose', '-v', dest='verbose', action='store_true', default=False)
  32. PARSER.add_argument('--vol_per_file', default=4, dest='vol_per_file',
  33. type=int, help='how many volumes to pack into a single tfrecord file')
  34. PARSER.add_argument('--single_data_dir', dest='single_data_dir', action='store_true', default=False)
  35. def load_features(path):
  36. """ Load features from Nifti
  37. :param path: Path to dataset
  38. :return: Loaded data
  39. """
  40. data = np.zeros((240, 240, 155, 4), dtype=np.uint8)
  41. name = os.path.basename(path)
  42. for i, modality in enumerate(["_t1.nii.gz", "_t1ce.nii.gz", "_t2.nii.gz", "_flair.nii.gz"]):
  43. vol = load_single_nifti(os.path.join(path, name + modality)).astype(np.float32)
  44. vol[vol > 0.85 * vol.max()] = 0.85 * vol.max()
  45. vol = 255 * vol / vol.max()
  46. data[..., i] = vol.astype(np.uint8)
  47. return data
  48. def load_segmentation(path):
  49. """ Load segmentations from Nifti
  50. :param path: Path to dataset
  51. :return: Loaded data
  52. """
  53. path = os.path.join(path, os.path.basename(path)) + "_seg.nii.gz"
  54. return load_single_nifti(path).astype(np.uint8)
  55. def load_single_nifti(path):
  56. """ Load Nifti file as numpy
  57. :param path: Path to file
  58. :return: Loaded data
  59. """
  60. data = nib.load(path).get_fdata().astype(np.int16)
  61. return np.transpose(data, (1, 0, 2))
  62. def write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, # pylint: disable=R0913
  63. count):
  64. """ Dump numpy array to tfrecord
  65. :param features_list: List of features
  66. :param labels_list: List of labels
  67. :param foreground_mean_list: List of means for each volume
  68. :param foreground_std_list: List of std for each volume
  69. :param output_dir: Directory where to write
  70. :param count: Index of the record
  71. :return:
  72. """
  73. output_filename = os.path.join(output_dir, "volume-{}.tfrecord".format(count))
  74. filelist = list(zip(np.array(features_list),
  75. np.array(labels_list),
  76. np.array(foreground_mean_list),
  77. np.array(foreground_std_list)))
  78. np_to_tfrecords(filelist, output_filename)
  79. def np_to_tfrecords(filelist, output_filename):
  80. """ Convert numpy array to tfrecord
  81. :param filelist: List of files
  82. :param output_filename: Destination directory
  83. """
  84. writer = tf.io.TFRecordWriter(output_filename)
  85. for file_item in filelist:
  86. sample = file_item[0].flatten().tostring()
  87. label = file_item[1].flatten().tostring()
  88. mean = file_item[2].astype(np.float32).flatten()
  89. stdev = file_item[3].astype(np.float32).flatten()
  90. d_feature = {}
  91. d_feature['X'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[sample]))
  92. d_feature['Y'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))
  93. d_feature['mean'] = tf.train.Feature(float_list=tf.train.FloatList(value=mean))
  94. d_feature['stdev'] = tf.train.Feature(float_list=tf.train.FloatList(value=stdev))
  95. features = tf.train.Features(feature=d_feature)
  96. example = tf.train.Example(features=features)
  97. serialized = example.SerializeToString()
  98. writer.write(serialized)
  99. writer.close()
  100. def main(): # pylint: disable=R0914
  101. """ Starting point of the application"""
  102. params = PARSER.parse_args()
  103. input_dir = params.input_dir
  104. output_dir = params.output_dir
  105. os.makedirs(params.output_dir, exist_ok=True)
  106. patient_list = []
  107. if params.single_data_dir:
  108. patient_list.extend([os.path.join(input_dir, folder) for folder in os.listdir(input_dir)])
  109. else:
  110. assert "HGG" in os.listdir(input_dir) and "LGG" in os.listdir(input_dir), \
  111. "Data directory has to contain folders named HGG and LGG. " \
  112. "If you have a single folder with patient's data please set --single_data_dir flag"
  113. path_hgg = os.path.join(input_dir, "HGG")
  114. path_lgg = os.path.join(input_dir, "LGG")
  115. patient_list.extend([os.path.join(path_hgg, folder) for folder in os.listdir(path_hgg)])
  116. patient_list.extend([os.path.join(path_lgg, folder) for folder in os.listdir(path_lgg)])
  117. shuffle(patient_list)
  118. features_list = []
  119. labels_list = []
  120. foreground_mean_list = []
  121. foreground_std_list = []
  122. count = 0
  123. total_tfrecord_files = len(patient_list) // params.vol_per_file + (1 if len(patient_list) % params.vol_per_file
  124. else 0)
  125. for i, folder in enumerate(patient_list):
  126. # Calculate mean and stdev only for foreground voxels
  127. features = load_features(folder)
  128. foreground = features > 0
  129. fg_mean = np.array([(features[..., i][foreground[..., i]]).mean() for i in range(features.shape[-1])])
  130. fg_std = np.array([(features[..., i][foreground[..., i]]).std() for i in range(features.shape[-1])])
  131. # BraTS labels are 0,1,2,4 -> switching to 0,1,2,3
  132. labels = load_segmentation(folder)
  133. labels[labels == 4] = 3
  134. features_list.append(features)
  135. labels_list.append(labels)
  136. foreground_mean_list.append(fg_mean)
  137. foreground_std_list.append(fg_std)
  138. if (i + 1) % params.vol_per_file == 0:
  139. write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, count)
  140. # Clear lists
  141. features_list = []
  142. labels_list = []
  143. foreground_mean_list = []
  144. foreground_std_list = []
  145. count += 1
  146. if params.verbose:
  147. print("{}/{} tfrecord files created".format(count, total_tfrecord_files))
  148. # create one more file if there are any remaining unpacked volumes
  149. if features_list:
  150. write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, count)
  151. count += 1
  152. if params.verbose:
  153. print("{}/{} tfrecord files created".format(count, total_tfrecord_files))
  154. if __name__ == '__main__':
  155. main()