inference.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
  16. from os.path import dirname
  17. from subprocess import call
  18. parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
  19. parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4], help="Fold number")
  20. parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
  21. parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
  22. parser.add_argument("--val_batch_size", type=int, default=4, help="Batch size")
  23. parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
  24. parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
  25. parser.add_argument("--save_preds", action="store_true", help="Save predicted masks")
  26. if __name__ == "__main__":
  27. args = parser.parse_args()
  28. path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
  29. cmd = f"python {path_to_main} --exec_mode evaluate --task 01 --gpus 1 "
  30. cmd += f"--dim {args.dim} "
  31. cmd += f"--fold {args.fold} "
  32. cmd += f"--ckpt_path {args.ckpt_path} "
  33. cmd += f"--val_batch_size {args.val_batch_size} "
  34. cmd += "--amp " if args.amp else ""
  35. cmd += "--tta " if args.tta else ""
  36. cmd += "--save_preds " if args.save_preds else ""
  37. call(cmd, shell=True)