utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import pickle
  16. from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
  17. from subprocess import call
  18. import torch
  19. def is_main_process():
  20. return int(os.getenv("LOCAL_RANK", "0")) == 0
  21. def set_cuda_devices(args):
  22. assert args.gpus <= torch.cuda.device_count(), f"Requested {args.gpus} gpus, available {torch.cuda.device_count()}."
  23. device_list = ",".join([str(i) for i in range(args.gpus)])
  24. os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", device_list)
  25. def verify_ckpt_path(args):
  26. resume_path = os.path.join(args.results, "checkpoints", "last.ckpt")
  27. ckpt_path = resume_path if args.resume_training and os.path.exists(resume_path) else args.ckpt_path
  28. return ckpt_path
  29. def get_task_code(args):
  30. return f"{args.task}_{args.dim}d"
  31. def get_config_file(args):
  32. task_code = get_task_code(args)
  33. config_file = os.path.join(args.data, task_code, "config.pkl")
  34. return pickle.load(open(config_file, "rb"))
  35. def make_empty_dir(path):
  36. call(["rm", "-rf", path])
  37. os.makedirs(path)
  38. def flip(data, axis):
  39. return torch.flip(data, dims=axis)
  40. def positive_int(value):
  41. ivalue = int(value)
  42. assert ivalue > 0, f"Argparse error. Expected positive integer but got {value}"
  43. return ivalue
  44. def non_negative_int(value):
  45. ivalue = int(value)
  46. assert ivalue >= 0, f"Argparse error. Expected positive integer but got {value}"
  47. return ivalue
  48. def float_0_1(value):
  49. ivalue = float(value)
  50. assert 0 <= ivalue <= 1, f"Argparse error. Expected float to be in range (0, 1), but got {value}"
  51. return ivalue
  52. def get_main_args(strings=None):
  53. parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
  54. parser.add_argument(
  55. "--exec_mode",
  56. type=str,
  57. choices=["train", "evaluate", "predict"],
  58. default="train",
  59. help="Execution mode to run the model",
  60. )
  61. parser.add_argument("--data", type=str, default="/data", help="Path to data directory")
  62. parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
  63. parser.add_argument("--logname", type=str, default=None, help="Name of dlloger output")
  64. parser.add_argument("--task", type=str, help="Task number. MSD uses numbers 01-10")
  65. parser.add_argument("--gpus", type=non_negative_int, default=1, help="Number of gpus")
  66. parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate")
  67. parser.add_argument("--gradient_clip_val", type=float, default=0, help="Gradient clipping norm value")
  68. parser.add_argument("--negative_slope", type=float, default=0.01, help="Negative slope for LeakyReLU")
  69. parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
  70. parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
  71. parser.add_argument("--benchmark", action="store_true", help="Run model benchmarking")
  72. parser.add_argument("--deep_supervision", action="store_true", help="Enable deep supervision")
  73. parser.add_argument("--sync_batchnorm", action="store_true", help="Enable synchronized batchnorm")
  74. parser.add_argument("--save_ckpt", action="store_true", help="Enable saving checkpoint")
  75. parser.add_argument("--nfolds", type=positive_int, default=5, help="Number of cross-validation folds")
  76. parser.add_argument("--seed", type=non_negative_int, default=1, help="Random seed")
  77. parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint")
  78. parser.add_argument("--fold", type=non_negative_int, default=0, help="Fold number")
  79. parser.add_argument("--patience", type=positive_int, default=100, help="Early stopping patience")
  80. parser.add_argument("--lr_patience", type=positive_int, default=70, help="Patience for ReduceLROnPlateau scheduler")
  81. parser.add_argument("--batch_size", type=positive_int, default=2, help="Batch size")
  82. parser.add_argument("--val_batch_size", type=positive_int, default=4, help="Validation batch size")
  83. parser.add_argument("--steps", nargs="+", type=positive_int, required=False, help="Steps for multistep scheduler")
  84. parser.add_argument("--create_idx", action="store_true", help="Create index files for tfrecord")
  85. parser.add_argument("--profile", action="store_true", help="Run dlprof profiling")
  86. parser.add_argument("--momentum", type=float, default=0.99, help="Momentum factor")
  87. parser.add_argument("--weight_decay", type=float, default=0.0001, help="Weight decay (L2 penalty)")
  88. parser.add_argument("--save_preds", action="store_true", help="Enable prediction saving")
  89. parser.add_argument("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
  90. parser.add_argument("--resume_training", action="store_true", help="Resume training from the last checkpoint")
  91. parser.add_argument("--factor", type=float, default=0.3, help="Scheduler factor")
  92. parser.add_argument(
  93. "--num_workers", type=non_negative_int, default=8, help="Number of subprocesses to use for data loading"
  94. )
  95. parser.add_argument(
  96. "--min_epochs", type=non_negative_int, default=30, help="Force training for at least these many epochs"
  97. )
  98. parser.add_argument(
  99. "--max_epochs", type=non_negative_int, default=10000, help="Stop training after this number of epochs"
  100. )
  101. parser.add_argument(
  102. "--warmup", type=non_negative_int, default=5, help="Warmup iterations before collecting statistics"
  103. )
  104. parser.add_argument(
  105. "--oversampling",
  106. type=float_0_1,
  107. default=0.33,
  108. help="Probability of crop to have some region with positive label",
  109. )
  110. parser.add_argument(
  111. "--norm", type=str, choices=["instance", "batch", "group"], default="instance", help="Normalization layer"
  112. )
  113. parser.add_argument(
  114. "--overlap",
  115. type=float_0_1,
  116. default=0.25,
  117. help="Amount of overlap between scans during sliding window inference",
  118. )
  119. parser.add_argument(
  120. "--affinity",
  121. type=str,
  122. default="socket_unique_interleaved",
  123. choices=[
  124. "socket",
  125. "single",
  126. "single_unique",
  127. "socket_unique_interleaved",
  128. "socket_unique_continuous",
  129. "disabled",
  130. ],
  131. help="type of CPU affinity",
  132. )
  133. parser.add_argument(
  134. "--data2d_dim",
  135. choices=[2, 3],
  136. type=int,
  137. default=3,
  138. help="Input data dimension for 2d model",
  139. )
  140. parser.add_argument(
  141. "--scheduler",
  142. type=str,
  143. default="none",
  144. choices=["none", "multistep", "cosine", "plateau"],
  145. help="Learning rate scheduler",
  146. )
  147. parser.add_argument(
  148. "--optimizer",
  149. type=str,
  150. default="radam",
  151. choices=["sgd", "adam", "adamw", "radam", "fused_adam"],
  152. help="Optimizer",
  153. )
  154. parser.add_argument(
  155. "--val_mode",
  156. type=str,
  157. choices=["gaussian", "constant"],
  158. default="gaussian",
  159. help="How to blend output of overlapping windows",
  160. )
  161. parser.add_argument(
  162. "--train_batches",
  163. type=non_negative_int,
  164. default=0,
  165. help="Limit number of batches for training (used for benchmarking mode only)",
  166. )
  167. parser.add_argument(
  168. "--test_batches",
  169. type=non_negative_int,
  170. default=0,
  171. help="Limit number of batches for inference (used for benchmarking mode only)",
  172. )
  173. if strings is not None:
  174. parser.add_argument(
  175. "strings",
  176. metavar="STRING",
  177. nargs="*",
  178. help="String for searching",
  179. )
  180. args = parser.parse_args(strings.split())
  181. else:
  182. args = parser.parse_args()
  183. return args