# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from data_loading.data_module import DataModule
from models.nn_unet import NNUnet
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import get_main_args, is_main_process, log, make_empty_dir, set_cuda_devices, verify_ckpt_path


  23. if __name__ == "__main__":
  24. args = get_main_args()
  25. if args.profile:
  26. import pyprof
  27. pyprof.init(enable_function_stack=True)
  28. print("Profiling enabled")
  29. if args.affinity != "disabled":
  30. affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)
  31. set_cuda_devices(args)
  32. seed_everything(args.seed)
  33. data_module = DataModule(args)
  34. data_module.prepare_data()
  35. data_module.setup()
  36. ckpt_path = verify_ckpt_path(args)
  37. callbacks = None
  38. model_ckpt = None
  39. if args.benchmark:
  40. model = NNUnet(args)
  41. batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
  42. log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
  43. callbacks = [
  44. LoggingCallback(
  45. log_dir=log_dir,
  46. global_batch_size=batch_size * args.gpus,
  47. mode=args.exec_mode,
  48. warmup=args.warmup,
  49. dim=args.dim,
  50. profile=args.profile,
  51. )
  52. ]
  53. elif args.exec_mode == "train":
  54. model = NNUnet(args)
  55. if args.save_ckpt:
  56. model_ckpt = ModelCheckpoint(monitor="dice_sum", mode="max", save_last=True)
  57. callbacks = [EarlyStopping(monitor="dice_sum", patience=args.patience, verbose=True, mode="max")]
  58. else: # Evaluation or inference
  59. if ckpt_path is not None:
  60. model = NNUnet.load_from_checkpoint(ckpt_path)
  61. else:
  62. model = NNUnet(args)
  63. trainer = Trainer(
  64. logger=False,
  65. gpus=args.gpus,
  66. precision=16 if args.amp else 32,
  67. benchmark=True,
  68. deterministic=False,
  69. min_epochs=args.min_epochs,
  70. max_epochs=args.max_epochs,
  71. sync_batchnorm=args.sync_batchnorm,
  72. gradient_clip_val=args.gradient_clip_val,
  73. callbacks=callbacks,
  74. num_sanity_val_steps=0,
  75. default_root_dir=args.results,
  76. resume_from_checkpoint=ckpt_path,
  77. accelerator="ddp" if args.gpus > 1 else None,
  78. checkpoint_callback=model_ckpt,
  79. limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
  80. limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
  81. limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
  82. )
  83. if args.benchmark:
  84. if args.exec_mode == "train":
  85. if args.profile:
  86. with torch.autograd.profiler.emit_nvtx():
  87. trainer.fit(model, train_dataloader=data_module.train_dataloader())
  88. else:
  89. trainer.fit(model, train_dataloader=data_module.train_dataloader())
  90. else:
  91. # warmup
  92. trainer.test(model, test_dataloaders=data_module.test_dataloader())
  93. # benchmark run
  94. trainer.current_epoch = 1
  95. trainer.test(model, test_dataloaders=data_module.test_dataloader())
  96. elif args.exec_mode == "train":
  97. trainer.fit(model, data_module)
  98. elif args.exec_mode == "evaluate":
  99. model.args = args
  100. trainer.test(model, test_dataloaders=data_module.val_dataloader())
  101. if is_main_process():
  102. logname = args.logname if args.logname is not None else "eval_log.json"
  103. log(logname, model.eval_dice, results=args.results)
  104. elif args.exec_mode == "predict":
  105. model.args = args
  106. if args.save_preds:
  107. prec = "amp" if args.amp else "fp32"
  108. dir_name = f"preds_task_{args.task}_dim_{args.dim}_fold_{args.fold}_{prec}"
  109. if args.tta:
  110. dir_name += "_tta"
  111. save_dir = os.path.join(args.results, dir_name)
  112. model.save_dir = save_dir
  113. make_empty_dir(save_dir)
  114. trainer.test(model, test_dataloaders=data_module.test_dataloader())