# main.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Standard library
import os

# Third-party: profiling, core framework, structured logging, training loop
import pyprof
import torch
from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Project-local modules
from data_loading.data_module import DataModule
from models.nn_unet import NNUnet
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import get_main_args, is_main_process, make_empty_dir, set_cuda_devices, verify_ckpt_path
  25. def log(logname, dice, epoch=None, dice_tta=None):
  26. dllogger = Logger(
  27. backends=[
  28. JSONStreamBackend(Verbosity.VERBOSE, os.path.join(args.results, logname)),
  29. StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: ""),
  30. ]
  31. )
  32. metrics = {}
  33. if epoch is not None:
  34. metrics.update({"Epoch": epoch})
  35. metrics.update({"Mean dice": round(dice.mean().item(), 2)})
  36. if dice_tta is not None:
  37. metrics.update({"Mean TTA dice": round(dice_tta.mean().item(), 2)})
  38. metrics.update({f"L{j+1}": round(m.item(), 2) for j, m in enumerate(dice)})
  39. if dice_tta is not None:
  40. metrics.update({f"TTA_L{j+1}": round(m.item(), 2) for j, m in enumerate(dice_tta)})
  41. dllogger.log(step=(), data=metrics)
  42. dllogger.flush()
if __name__ == "__main__":
    # Entry point: parse CLI arguments and dispatch to benchmark / train /
    # evaluate / predict according to args.exec_mode.
    args = get_main_args()

    # Optional NVIDIA pyprof profiling (NVTX annotations for nvprof/Nsight).
    if args.profile:
        pyprof.init(enable_function_stack=True)
        print("Profiling enabled")

    # Pin the process to CPUs near its GPU; LOCAL_RANK is set by the DDP
    # launcher and defaults to "0" for single-process runs.
    if args.affinity != "disabled":
        affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)

    set_cuda_devices(args)
    if is_main_process():
        print(f"{args.exec_mode.upper()} TASK {args.task} FOLD {args.fold} SEED {args.seed}")
    seed_everything(args.seed)

    # Data pipeline is prepared eagerly so dataloaders exist for every mode.
    data_module = DataModule(args)
    data_module.prepare_data()
    data_module.setup()
    ckpt_path = verify_ckpt_path(args)

    callbacks = None
    model_ckpt = None
    if args.benchmark:
        # Throughput benchmarking: LoggingCallback records performance stats.
        model = NNUnet(args)
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
                warmup=args.warmup,
                dim=args.dim,
                profile=args.profile,
            )
        ]
    elif args.exec_mode == "train":
        model = NNUnet(args)
        if args.save_ckpt:
            # Keep the checkpoint with the best summed dice, plus the last one.
            model_ckpt = ModelCheckpoint(monitor="dice_sum", mode="max", save_last=True)
        callbacks = [EarlyStopping(monitor="dice_sum", patience=args.patience, verbose=True, mode="max")]
    else:  # Evaluation or inference
        if ckpt_path is not None:
            model = NNUnet.load_from_checkpoint(ckpt_path)
        else:
            model = NNUnet(args)

    trainer = Trainer(
        logger=False,
        gpus=args.gpus,
        precision=16 if args.amp else 32,  # mixed precision when --amp is set
        benchmark=True,
        deterministic=False,
        min_epochs=args.min_epochs,
        max_epochs=args.max_epochs,
        sync_batchnorm=args.sync_batchnorm,
        gradient_clip_val=args.gradient_clip_val,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        default_root_dir=args.results,
        resume_from_checkpoint=ckpt_path,
        accelerator="ddp" if args.gpus > 1 else None,  # multi-GPU -> DDP
        checkpoint_callback=model_ckpt,
        # A batch-limit of 0 means "use the full dataset" (fraction 1.0).
        limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
        limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
        limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
    )

    if args.benchmark:
        if args.exec_mode == "train":
            if args.profile:
                # emit_nvtx annotates autograd ops so pyprof can correlate
                # kernel timings with Python-level operations.
                with torch.autograd.profiler.emit_nvtx():
                    trainer.fit(model, train_dataloader=data_module.train_dataloader())
            else:
                trainer.fit(model, train_dataloader=data_module.train_dataloader())
        else:
            trainer.test(model, test_dataloaders=data_module.test_dataloader())
    elif args.exec_mode == "train":
        trainer.fit(model, data_module)
        if model_ckpt is not None:
            # After training, re-run the model on the validation set in
            # evaluation mode with test-time augmentation enabled.
            model.args.exec_mode = "evaluate"
            model.args.tta = True
            # NOTE(review): presumably clears an interrupted flag set during
            # fit (e.g. by early stopping / signal handling) so that
            # trainer.test() is allowed to run — confirm against the
            # pytorch_lightning version in use.
            trainer.interrupted = False
            trainer.test(test_dataloaders=data_module.val_dataloader())
            if is_main_process():
                log_name = args.logname if args.logname is not None else "train_log.json"
                log(log_name, model.best_sum_dice, model.best_sum_epoch, model.eval_dice)
    elif args.exec_mode == "evaluate":
        model.args = args
        trainer.test(model, test_dataloaders=data_module.val_dataloader())
        if is_main_process():
            log(args.logname if args.logname is not None else "eval_log.json", model.eval_dice)
    elif args.exec_mode == "predict":
        model.args = args
        if args.save_preds:
            # Predictions go to e.g. <results>/preds_task_<T>_dim_<D>_fold_<F>[_tta]
            dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}"
            if args.tta:
                dir_name += "_tta"
            save_dir = os.path.join(args.results, dir_name)
            model.save_dir = save_dir
            make_empty_dir(save_dir)
        trainer.test(model, test_dataloaders=data_module.test_dataloader())