# preprocess_data.py
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import argparse
  16. from random import shuffle
  17. import numpy as np
  18. import nibabel as nib
  19. import tensorflow as tf
  20. PARSER = argparse.ArgumentParser()
  21. PARSER.add_argument('--input_dir', '-i',
  22. type=str, help='path to the input directory with data')
  23. PARSER.add_argument('--output_dir', '-o',
  24. type=str, help='path to the output directory where tfrecord files will be stored')
  25. PARSER.add_argument('--verbose', '-v', dest='verbose', action='store_true', default=False)
  26. PARSER.add_argument('--vol_per_file', default=4, dest='vol_per_file',
  27. type=int, help='how many volumes to pack into a single tfrecord file')
  28. PARSER.add_argument('--single_data_dir', dest='single_data_dir', action='store_true', default=False)
  29. def load_features(path):
  30. data = np.zeros((240, 240, 155, 4), dtype=np.uint8)
  31. name = os.path.basename(path)
  32. for i, modality in enumerate(["_t1.nii.gz", "_t1ce.nii.gz", "_t2.nii.gz", "_flair.nii.gz"]):
  33. vol = load_single_nifti(os.path.join(path, name+modality)).astype(np.float32)
  34. vol[vol > 0.85 * vol.max()] = 0.85 * vol.max()
  35. vol = 255 * vol / vol.max()
  36. data[..., i] = vol.astype(np.uint8)
  37. return data
  38. def load_segmentation(path):
  39. path = os.path.join(path, os.path.basename(path)) + "_seg.nii.gz"
  40. return load_single_nifti(path).astype(np.uint8)
  41. def load_single_nifti(path):
  42. data = nib.load(path).get_fdata().astype(np.int16)
  43. return np.transpose(data, (1, 0, 2))
  44. def write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, count):
  45. output_filename = os.path.join(output_dir, "volume-{}.tfrecord".format(count))
  46. filelist = list(zip(np.array(features_list),
  47. np.array(labels_list),
  48. np.array(foreground_mean_list),
  49. np.array(foreground_std_list)))
  50. np_to_tfrecords(filelist, output_filename)
  51. def np_to_tfrecords(filelist, output_filename):
  52. writer = tf.io.TFRecordWriter(output_filename)
  53. for idx in range(len(filelist)):
  54. X = filelist[idx][0].flatten().tostring()
  55. Y = filelist[idx][1].flatten().tostring()
  56. mean = filelist[idx][2].astype(np.float32).flatten()
  57. stdev = filelist[idx][3].astype(np.float32).flatten()
  58. d_feature = {}
  59. d_feature['X'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[X]))
  60. d_feature['Y'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[Y]))
  61. d_feature['mean'] = tf.train.Feature(float_list=tf.train.FloatList(value=mean))
  62. d_feature['stdev'] = tf.train.Feature(float_list=tf.train.FloatList(value=stdev))
  63. features = tf.train.Features(feature=d_feature)
  64. example = tf.train.Example(features=features)
  65. serialized = example.SerializeToString()
  66. writer.write(serialized)
  67. writer.close()
  68. def main():
  69. # parse arguments
  70. params = PARSER.parse_args()
  71. input_dir = params.input_dir
  72. output_dir = params.output_dir
  73. os.makedirs(params.output_dir, exist_ok=True)
  74. patient_list = []
  75. if params.single_data_dir:
  76. patient_list.extend([os.path.join(input_dir, folder) for folder in os.listdir(input_dir)])
  77. else:
  78. assert "HGG" in os.listdir(input_dir) and "LGG" in os.listdir(input_dir),\
  79. "Data directory has to contain folders named HGG and LGG. " \
  80. "If you have a single folder with patient's data please set --single_data_dir flag"
  81. path_hgg = os.path.join(input_dir, "HGG")
  82. path_lgg = os.path.join(input_dir, "LGG")
  83. patient_list.extend([os.path.join(path_hgg, folder) for folder in os.listdir(path_hgg)])
  84. patient_list.extend([os.path.join(path_lgg, folder) for folder in os.listdir(path_lgg)])
  85. shuffle(patient_list)
  86. features_list = []
  87. labels_list = []
  88. foreground_mean_list = []
  89. foreground_std_list = []
  90. count = 0
  91. total_tfrecord_files = len(patient_list) // params.vol_per_file + (1 if len(patient_list) % params.vol_per_file
  92. else 0)
  93. for i, folder in enumerate(patient_list):
  94. # Calculate mean and stdev only for foreground voxels
  95. features = load_features(folder)
  96. foreground = features > 0
  97. fg_mean = np.array([(features[..., i][foreground[..., i]]).mean() for i in range(features.shape[-1])])
  98. fg_std = np.array([(features[..., i][foreground[..., i]]).std() for i in range(features.shape[-1])])
  99. # BraTS labels are 0,1,2,4 -> switching to 0,1,2,3
  100. labels = load_segmentation(folder)
  101. labels[labels == 4] = 3
  102. features_list.append(features)
  103. labels_list.append(labels)
  104. foreground_mean_list.append(fg_mean)
  105. foreground_std_list.append(fg_std)
  106. if (i+1) % params.vol_per_file == 0:
  107. write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, count)
  108. # Clear lists
  109. features_list = []
  110. labels_list = []
  111. foreground_mean_list = []
  112. foreground_std_list = []
  113. count += 1
  114. if params.verbose:
  115. print("{}/{} tfrecord files created".format(count, total_tfrecord_files))
  116. # create one more file if there are any remaining unpacked volumes
  117. if features_list:
  118. write_to_file(features_list, labels_list, foreground_mean_list, foreground_std_list, output_dir, count)
  119. count += 1
  120. if params.verbose:
  121. print("{}/{} tfrecord files created".format(count, total_tfrecord_files))
# Run the conversion only when executed as a script (not on import).
if __name__ == '__main__':
    main()