data_module.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pytorch_lightning import LightningDataModule
  15. from sklearn.model_selection import KFold
  16. from utils.utils import get_config_file, get_path, get_split, get_test_fnames, is_main_process, load_data
  17. from data_loading.dali_loader import fetch_dali_loader
  18. class DataModule(LightningDataModule):
  19. def __init__(self, args):
  20. super().__init__()
  21. self.args = args
  22. self.train_imgs = []
  23. self.train_lbls = []
  24. self.val_imgs = []
  25. self.val_lbls = []
  26. self.test_imgs = []
  27. self.kfold = KFold(n_splits=self.args.nfolds, shuffle=True, random_state=12345)
  28. self.data_path = get_path(args)
  29. configs = get_config_file(self.args)
  30. self.kwargs = {
  31. "dim": self.args.dim,
  32. "patch_size": configs["patch_size"],
  33. "seed": self.args.seed,
  34. "gpus": self.args.gpus,
  35. "num_workers": self.args.num_workers,
  36. "oversampling": self.args.oversampling,
  37. "benchmark": self.args.benchmark,
  38. "nvol": self.args.nvol,
  39. "train_batches": self.args.train_batches,
  40. "test_batches": self.args.test_batches,
  41. "meta": load_data(self.data_path, "*_meta.npy"),
  42. }
  43. def setup(self, stage=None):
  44. imgs = load_data(self.data_path, "*_x.npy")
  45. lbls = load_data(self.data_path, "*_y.npy")
  46. self.test_imgs, self.kwargs["meta"] = get_test_fnames(self.args, self.data_path, self.kwargs["meta"])
  47. if self.args.exec_mode != "predict" or self.args.benchmark:
  48. train_idx, val_idx = list(self.kfold.split(imgs))[self.args.fold]
  49. self.train_imgs = get_split(imgs, train_idx)
  50. self.train_lbls = get_split(lbls, train_idx)
  51. self.val_imgs = get_split(imgs, val_idx)
  52. self.val_lbls = get_split(lbls, val_idx)
  53. if is_main_process():
  54. ntrain, nval = len(self.train_imgs), len(self.val_imgs)
  55. print(f"Number of examples: Train {ntrain} - Val {nval}")
  56. elif is_main_process():
  57. print(f"Number of test examples: {len(self.test_imgs)}")
  58. def train_dataloader(self):
  59. return fetch_dali_loader(self.train_imgs, self.train_lbls, self.args.batch_size, "train", **self.kwargs)
  60. def val_dataloader(self):
  61. return fetch_dali_loader(self.val_imgs, self.val_lbls, 1, "eval", **self.kwargs)
  62. def test_dataloader(self):
  63. if self.kwargs["benchmark"]:
  64. return fetch_dali_loader(self.train_imgs, self.train_lbls, self.args.val_batch_size, "test", **self.kwargs)
  65. return fetch_dali_loader(self.test_imgs, None, 1, "test", **self.kwargs)