args.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
  16. def positive_int(value):
  17. ivalue = int(value)
  18. assert ivalue > 0, f"Argparse error. Expected positive integer but got {value}"
  19. return ivalue
  20. def non_negative_int(value):
  21. ivalue = int(value)
  22. assert ivalue >= 0, f"Argparse error. Expected non-negative integer but got {value}"
  23. return ivalue
  24. def float_0_1(value):
  25. fvalue = float(value)
  26. assert 0 <= fvalue <= 1, f"Argparse error. Expected float value to be in range (0, 1), but got {value}"
  27. return fvalue
  28. def get_main_args(strings=None):
  29. parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
  30. arg = parser.add_argument
  31. arg(
  32. "--exec_mode",
  33. type=str,
  34. choices=["train", "evaluate", "predict"],
  35. default="train",
  36. help="Execution mode to run the model",
  37. )
  38. arg("--data", type=str, default="/data", help="Path to data directory")
  39. arg("--results", type=str, default="/results", help="Path to results directory")
  40. arg("--config", type=str, default=None, help="Config file with arguments")
  41. arg("--logname", type=str, default="logs.json", help="Name of dlloger output")
  42. arg("--task", type=str, default="01", help="Task number. MSD uses numbers 01-10")
  43. arg("--gpus", type=non_negative_int, default=1, help="Number of gpus")
  44. arg("--nodes", type=non_negative_int, default=1, help="Number of nodes")
  45. arg("--learning_rate", type=float, default=0.0008, help="Learning rate")
  46. arg("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
  47. arg("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
  48. arg("--tta", action="store_true", help="Enable test time augmentation")
  49. arg("--brats", action="store_true", help="Enable BraTS specific training and inference")
  50. arg("--deep_supervision", action="store_true", help="Enable deep supervision")
  51. arg("--invert_resampled_y", action="store_true", help="Resize predictions to match label size before resampling")
  52. arg("--amp", action="store_true", help="Enable automatic mixed precision")
  53. arg("--benchmark", action="store_true", help="Run model benchmarking")
  54. arg("--focal", action="store_true", help="Use focal loss instead of cross entropy")
  55. arg("--save_ckpt", action="store_true", help="Enable saving checkpoint")
  56. arg("--nfolds", type=positive_int, default=5, help="Number of cross-validation folds")
  57. arg("--seed", type=non_negative_int, default=None, help="Random seed")
  58. arg("--skip_first_n_eval", type=non_negative_int, default=0, help="Skip the evaluation for the first n epochs.")
  59. arg("--ckpt_path", type=str, default=None, help="Path for loading checkpoint")
  60. arg("--ckpt_store_dir", type=str, default="/results", help="Path for saving checkpoint")
  61. arg("--fold", type=non_negative_int, default=0, help="Fold number")
  62. arg("--patience", type=positive_int, default=100, help="Early stopping patience")
  63. arg("--batch_size", type=positive_int, default=2, help="Batch size")
  64. arg("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
  65. arg("--momentum", type=float, default=0.99, help="Momentum factor")
  66. arg("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
  67. arg("--save_preds", action="store_true", help="Enable prediction saving")
  68. arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
  69. arg("--resume_training", action="store_true", help="Resume training from the last checkpoint")
  70. arg("--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading")
  71. arg("--epochs", type=non_negative_int, default=1000, help="Number of training epochs.")
  72. arg("--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics")
  73. arg("--nvol", type=positive_int, default=4, help="Number of volumes which come into single batch size for 2D model")
  74. arg("--depth", type=non_negative_int, default=5, help="The depth of the encoder")
  75. arg("--min_fmap", type=non_negative_int, default=4, help="Minimal dimension of feature map in the bottleneck")
  76. arg("--deep_supr_num", type=non_negative_int, default=2, help="Number of deep supervision heads")
  77. arg("--res_block", action="store_true", help="Enable residual blocks")
  78. arg("--filters", nargs="+", help="[Optional] Set U-Net filters", default=None, type=int)
  79. arg("--layout", type=str, default="NCDHW")
  80. arg("--brats22_model", action="store_true", help="Use BraTS22 model")
  81. arg(
  82. "--norm",
  83. type=str,
  84. choices=["instance", "instance_nvfuser", "batch", "group"],
  85. default="instance",
  86. help="Normalization layer",
  87. )
  88. arg(
  89. "--data2d_dim",
  90. choices=[2, 3],
  91. type=int,
  92. default=3,
  93. help="Input data dimension for 2d model",
  94. )
  95. arg(
  96. "--oversampling",
  97. type=float_0_1,
  98. default=0.4,
  99. help="Probability of crop to have some region with positive label",
  100. )
  101. arg(
  102. "--overlap",
  103. type=float_0_1,
  104. default=0.25,
  105. help="Amount of overlap between scans during sliding window inference",
  106. )
  107. arg(
  108. "--scheduler",
  109. action="store_true",
  110. help="Enable cosine rate scheduler with warmup",
  111. )
  112. arg(
  113. "--optimizer",
  114. type=str,
  115. default="adam",
  116. choices=["sgd", "adam"],
  117. help="Optimizer",
  118. )
  119. arg(
  120. "--blend",
  121. type=str,
  122. choices=["gaussian", "constant"],
  123. default="constant",
  124. help="How to blend output of overlapping windows",
  125. )
  126. arg(
  127. "--train_batches",
  128. type=non_negative_int,
  129. default=0,
  130. help="Limit number of batches for training (used for benchmarking mode only)",
  131. )
  132. arg(
  133. "--test_batches",
  134. type=non_negative_int,
  135. default=0,
  136. help="Limit number of batches for inference (used for benchmarking mode only)",
  137. )
  138. if strings is not None:
  139. arg(
  140. "strings",
  141. metavar="STRING",
  142. nargs="*",
  143. help="String for searching",
  144. )
  145. args = parser.parse_args(strings.split())
  146. else:
  147. args = parser.parse_args()
  148. if args.config is not None:
  149. config = json.load(open(args.config, "r"))
  150. args = vars(args)
  151. args.update(config)
  152. args = Namespace(**args)
  153. with open(f"{args.results}/params.json", "w") as f:
  154. json.dump(vars(args), f)
  155. return args