train.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
  15. from pathlib import Path
  16. from subprocess import run
  17. parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
  18. parser.add_argument("--task", type=str, default="01", help="Path to data")
  19. parser.add_argument("--gpus", type=int, required=True, help="Number of GPUs")
  20. parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
  21. parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
  22. parser.add_argument("--seed", type=int, default=1, help="Random seed")
  23. parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
  24. parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
  25. parser.add_argument("--bind", action="store_true", help="Enable test time augmentation")
  26. parser.add_argument("--resume_training", action="store_true", help="Resume training from checkpoint")
  27. parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
  28. parser.add_argument("--logname", type=str, default="train_logs.json", help="Name of dlloger output")
  29. parser.add_argument("--learning_rate", type=float, default=8e-4, help="Learning rate")
  30. if __name__ == "__main__":
  31. args = parser.parse_args()
  32. skip = 100 if args.gpus == 1 else 150
  33. path_to_main = Path(__file__).resolve().parent.parent / "main.py"
  34. cmd = ""
  35. if args.bind:
  36. cmd += "bindpcie --cpu=exclusive,nosmt "
  37. cmd = f"python {path_to_main} --exec_mode train --save_ckpt --deep_supervision --skip_first_n_eval {skip} "
  38. cmd += f"--task {args.task} "
  39. cmd += f"--results {args.results} "
  40. cmd += f"--logname {args.logname} "
  41. cmd += f"--dim {args.dim} "
  42. cmd += f"--batch_size {2 if args.dim == 3 else 64} "
  43. cmd += f"--val_batch_size {1 if args.dim == 3 else 64} "
  44. cmd += f"--norm {'instance_nvfuser' if args.dim == 3 else 'instance'} "
  45. cmd += f"--layout {'NDHWC' if args.dim == 3 else 'NCDHW'} "
  46. cmd += f"--fold {args.fold} "
  47. cmd += f"--gpus {args.gpus} "
  48. cmd += f"--epochs {300 if args.gpus == 1 else 600} "
  49. cmd += f"--learning_rate {args.learning_rate} "
  50. cmd += "--amp " if args.amp else ""
  51. cmd += "--tta " if args.tta else ""
  52. cmd += "--resume_training " if args.resume_training else ""
  53. cmd += f"--seed {args.seed} "
  54. run(cmd, shell=True)