# data_module.py
  1. import glob
  2. import os
  3. from subprocess import call
  4. import numpy as np
  5. from joblib import Parallel, delayed
  6. from pytorch_lightning import LightningDataModule
  7. from sklearn.model_selection import KFold
  8. from tqdm import tqdm
  9. from utils.utils import get_config_file, get_task_code, is_main_process, make_empty_dir
  10. from data_loading.dali_loader import fetch_dali_loader
  11. class DataModule(LightningDataModule):
  12. def __init__(self, args):
  13. super().__init__()
  14. self.args = args
  15. self.tfrecords_train = []
  16. self.tfrecords_val = []
  17. self.tfrecords_test = []
  18. self.train_idx = []
  19. self.val_idx = []
  20. self.test_idx = []
  21. self.kfold = KFold(n_splits=self.args.nfolds, shuffle=True, random_state=12345)
  22. self.data_path = os.path.join(self.args.data, get_task_code(self.args))
  23. if self.args.exec_mode == "predict" and not args.benchmark:
  24. self.data_path = os.path.join(self.data_path, "test")
  25. configs = get_config_file(self.args)
  26. self.kwargs = {
  27. "dim": self.args.dim,
  28. "patch_size": configs["patch_size"],
  29. "seed": self.args.seed,
  30. "gpus": self.args.gpus,
  31. "num_workers": self.args.num_workers,
  32. "oversampling": self.args.oversampling,
  33. "create_idx": self.args.create_idx,
  34. "benchmark": self.args.benchmark,
  35. }
  36. def prepare_data(self):
  37. if self.args.create_idx:
  38. tfrecords_train, tfrecords_val, tfrecords_test = self.load_tfrecords()
  39. make_empty_dir("train_idx")
  40. make_empty_dir("val_idx")
  41. make_empty_dir("test_idx")
  42. self.create_idx("train_idx", tfrecords_train)
  43. self.create_idx("val_idx", tfrecords_val)
  44. self.create_idx("test_idx", tfrecords_test)
  45. def setup(self, stage=None):
  46. self.tfrecords_train, self.tfrecords_val, self.tfrecords_test = self.load_tfrecords()
  47. self.train_idx, self.val_idx, self.test_idx = self.load_idx_files()
  48. if is_main_process():
  49. ntrain, nval, ntest = len(self.tfrecords_train), len(self.tfrecords_val), len(self.tfrecords_test)
  50. print(f"Number of examples: Train {ntrain} - Val {nval} - Test {ntest}")
  51. def train_dataloader(self):
  52. return fetch_dali_loader(self.tfrecords_train, self.train_idx, self.args.batch_size, "training", **self.kwargs)
  53. def val_dataloader(self):
  54. return fetch_dali_loader(self.tfrecords_val, self.val_idx, 1, "eval", **self.kwargs)
  55. def test_dataloader(self):
  56. if self.kwargs["benchmark"]:
  57. return fetch_dali_loader(
  58. self.tfrecords_train, self.train_idx, self.args.val_batch_size, "eval", **self.kwargs
  59. )
  60. return fetch_dali_loader(self.tfrecords_test, self.test_idx, 1, "test", **self.kwargs)
  61. def load_tfrecords(self):
  62. if self.args.dim == 2:
  63. train_tfrecords = self.load_data(self.data_path, "*.train_tfrecord")
  64. val_tfrecords = self.load_data(self.data_path, "*.val_tfrecord")
  65. else:
  66. train_tfrecords = self.load_data(self.data_path, "*.tfrecord")
  67. val_tfrecords = self.load_data(self.data_path, "*.tfrecord")
  68. train_idx, val_idx = list(self.kfold.split(train_tfrecords))[self.args.fold]
  69. train_tfrecords = self.get_split(train_tfrecords, train_idx)
  70. val_tfrecords = self.get_split(val_tfrecords, val_idx)
  71. return train_tfrecords, val_tfrecords, self.load_data(os.path.join(self.data_path, "test"), "*.tfrecord")
  72. def load_idx_files(self):
  73. if self.args.create_idx:
  74. test_idx = sorted(glob.glob(os.path.join("test_idx", "*.idx")))
  75. else:
  76. test_idx = self.get_idx_list("test/idx_files", self.tfrecords_test)
  77. if self.args.create_idx:
  78. train_idx = sorted(glob.glob(os.path.join("train_idx", "*.idx")))
  79. val_idx = sorted(glob.glob(os.path.join("val_idx", "*.idx")))
  80. elif self.args.dim == 3:
  81. train_idx = self.get_idx_list("idx_files", self.tfrecords_train)
  82. val_idx = self.get_idx_list("idx_files", self.tfrecords_val)
  83. else:
  84. train_idx = self.get_idx_list("train_idx_files", self.tfrecords_train)
  85. val_idx = self.get_idx_list("val_idx_files", self.tfrecords_val)
  86. return train_idx, val_idx, test_idx
  87. def create_idx(self, idx_dir, tfrecords):
  88. idx_files = [os.path.join(idx_dir, os.path.basename(tfrec).split(".")[0] + ".idx") for tfrec in tfrecords]
  89. Parallel(n_jobs=-1)(
  90. delayed(self.tfrecord2idx)(tfrec, tfidx)
  91. for tfrec, tfidx in tqdm(zip(tfrecords, idx_files), total=len(tfrecords))
  92. )
  93. def get_idx_list(self, dir_name, tfrecords):
  94. root_dir = os.path.join(self.data_path, dir_name)
  95. return sorted([os.path.join(root_dir, os.path.basename(tfr).split(".")[0] + ".idx") for tfr in tfrecords])
  96. @staticmethod
  97. def get_split(data, idx):
  98. return list(np.array(data)[idx])
  99. @staticmethod
  100. def load_data(path, files_pattern):
  101. return sorted(glob.glob(os.path.join(path, files_pattern)))
  102. @staticmethod
  103. def tfrecord2idx(tfrecord, tfidx):
  104. call(["tfrecord2idx", tfrecord, tfidx])