# data_module.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import horovod.tensorflow as hvd
from sklearn.model_selection import KFold

from data_loading.dali_loader import fetch_dali_loader
from data_loading.utils import get_path, get_split, get_test_fnames, load_data
from runtime.utils import get_config_file, is_main_process
  19. class DataModule:
  20. def __init__(self, args):
  21. super().__init__()
  22. self.args = args
  23. self.train_imgs = []
  24. self.train_lbls = []
  25. self.val_imgs = []
  26. self.val_lbls = []
  27. self.test_imgs = []
  28. self.kfold = KFold(n_splits=self.args.nfolds, shuffle=True, random_state=12345)
  29. self.data_path = get_path(args)
  30. configs = get_config_file(self.args)
  31. self.patch_size = configs["patch_size"]
  32. self.kwargs = {
  33. "dim": self.args.dim,
  34. "patch_size": self.patch_size,
  35. "seed": self.args.seed,
  36. "gpus": hvd.size(),
  37. "num_workers": self.args.num_workers,
  38. "oversampling": self.args.oversampling,
  39. "benchmark": self.args.benchmark,
  40. "nvol": self.args.nvol,
  41. "bench_steps": self.args.bench_steps,
  42. "meta": load_data(self.data_path, "*_meta.npy"),
  43. }
  44. def setup(self, stage=None):
  45. imgs = load_data(self.data_path, "*_x.npy")
  46. lbls = load_data(self.data_path, "*_y.npy")
  47. self.test_imgs, self.kwargs["meta"] = get_test_fnames(self.args, self.data_path, self.kwargs["meta"])
  48. if self.args.exec_mode != "predict" or self.args.benchmark:
  49. train_idx, val_idx = list(self.kfold.split(imgs))[self.args.fold]
  50. self.train_imgs = get_split(imgs, train_idx)
  51. self.train_lbls = get_split(lbls, train_idx)
  52. self.val_imgs = get_split(imgs, val_idx)
  53. self.val_lbls = get_split(lbls, val_idx)
  54. if is_main_process():
  55. ntrain, nval = len(self.train_imgs), len(self.val_imgs)
  56. print(f"Number of examples: Train {ntrain} - Val {nval}")
  57. # Shard the validation data
  58. self.val_imgs = self.val_imgs[hvd.rank() :: hvd.size()]
  59. self.val_lbls = self.val_lbls[hvd.rank() :: hvd.size()]
  60. self.cached_val_loader = None
  61. elif is_main_process():
  62. print(f"Number of test examples: {len(self.test_imgs)}")
  63. def train_dataset(self):
  64. return fetch_dali_loader(
  65. self.train_imgs,
  66. self.train_lbls,
  67. self.args.batch_size,
  68. "train",
  69. **self.kwargs,
  70. )
  71. def train_size(self):
  72. return len(self.train_imgs)
  73. def val_dataset(self):
  74. if self.cached_val_loader is None:
  75. self.cached_val_loader = fetch_dali_loader(self.val_imgs, self.val_lbls, 1, "eval", **self.kwargs)
  76. return self.cached_val_loader
  77. def val_size(self):
  78. return len(self.val_imgs)
  79. def test_dataset(self):
  80. if self.kwargs["benchmark"]:
  81. return fetch_dali_loader(
  82. self.train_imgs,
  83. self.train_lbls,
  84. self.args.batch_size,
  85. "test",
  86. **self.kwargs,
  87. )
  88. return fetch_dali_loader(self.test_imgs, None, 1, "test", **self.kwargs)
  89. def test_size(self):
  90. return len(self.test_imgs)
  91. def test_fname(self, idx):
  92. return self.test_imgs[idx]