benchmark.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
  16. from os.path import dirname
  17. from subprocess import call
  18. parser = ArgumentParser(ArgumentDefaultsHelpFormatter)
  19. parser.add_argument("--mode", type=str, required=True, choices=["train", "predict"], help="Benchmarking mode")
  20. parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
  21. parser.add_argument("--dim", type=int, required=True, help="Dimension of UNet")
  22. parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
  23. parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
  24. parser.add_argument("--train_batches", type=int, default=80, help="Number of batches for training")
  25. parser.add_argument("--test_batches", type=int, default=80, help="Number of batches for inference")
  26. parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations before collecting statistics")
  27. parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
  28. parser.add_argument("--logname", type=str, default="perf.json", help="Name of dlloger output")
  29. parser.add_argument("--create_idx", action="store_true", help="Create index files for tfrecord")
  30. parser.add_argument("--profile", action="store_true", help="Enable dlprof profiling")
  31. if __name__ == "__main__":
  32. args = parser.parse_args()
  33. path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
  34. cmd = "python main.py --task 01 --benchmark --max_epochs 1 --min_epochs 1 "
  35. cmd += f"--results {args.results} "
  36. cmd += f"--logname {args.logname} "
  37. cmd += f"--exec_mode {args.mode} "
  38. cmd += f"--dim {args.dim} "
  39. cmd += f"--gpus {args.gpus} "
  40. cmd += f"--train_batches {args.train_batches} "
  41. cmd += f"--test_batches {args.test_batches} "
  42. cmd += f"--warmup {args.warmup} "
  43. cmd += "--amp " if args.amp else ""
  44. cmd += "--create_idx " if args.create_idx else ""
  45. cmd += "--profile " if args.profile else ""
  46. if args.mode == "train":
  47. cmd += f"--batch_size {args.batch_size} "
  48. else:
  49. cmd += f"--val_batch_size {args.batch_size} "
  50. call(cmd, shell=True)