# main.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import ctypes
  15. import os
  16. from data_loading.data_module import DataModule
  17. from models.nn_unet import NNUnet
  18. from runtime.args import get_main_args
  19. from runtime.checkpoint import load_model
  20. from runtime.logging import get_logger
  21. from runtime.run import evaluate, export_model, predict, train
  22. from runtime.utils import hvd_init, set_seed, set_tf_flags
  23. def main(args):
  24. os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
  25. os.environ["TF_GPU_THREAD_COUNT"] = "1"
  26. _libcudart = ctypes.CDLL("libcudart.so")
  27. # Set device limit on the current device
  28. # cudaLimitMaxL2FetchGranularity = 0x05
  29. pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
  30. _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
  31. _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
  32. assert pValue.contents.value == 128
  33. hvd_init()
  34. if args.seed is not None:
  35. set_seed(args.seed)
  36. set_tf_flags(args)
  37. data = DataModule(args)
  38. data.setup()
  39. logger = get_logger(args)
  40. logger.log_hyperparams(vars(args))
  41. logger.log_metadata("dice_score", {"unit": None})
  42. logger.log_metadata("eval_dice_nobg", {"unit": None})
  43. logger.log_metadata("throughput_predict", {"unit": "images/s"})
  44. logger.log_metadata("throughput_train", {"unit": "images/s"})
  45. logger.log_metadata("latency_predict_mean", {"unit": "ms"})
  46. logger.log_metadata("latency_train_mean", {"unit": "ms"})
  47. if args.exec_mode == "train":
  48. model = NNUnet(args)
  49. train(args, model, data, logger)
  50. elif args.exec_mode == "evaluate":
  51. model = load_model(args)
  52. evaluate(args, model, data, logger)
  53. elif args.exec_mode == "predict":
  54. model = NNUnet(args) if args.benchmark else load_model(args)
  55. predict(args, model, data, logger)
  56. elif args.exec_mode == "export":
  57. # Export model
  58. model = load_model(args)
  59. export_model(args, model)
  60. suffix = "amp" if args.amp else "fp32"
  61. sm = f"{args.results}/saved_model_task_{args.task}_dim_{args.dim}_" + suffix
  62. trt = f"{args.results}/trt_saved_model_task_{args.task}_dim_{args.dim}_" + suffix
  63. args.saved_model_dir = sm if args.load_sm else trt
  64. args.exec_mode = "evaluate" if args.validate else "predict"
  65. # Run benchmarking
  66. model = load_model(args)
  67. data = DataModule(args)
  68. data.setup()
  69. if args.validate:
  70. evaluate(args, model, data, logger)
  71. else:
  72. predict(args, model, data, logger)
  73. else:
  74. raise NotImplementedError
  75. if __name__ == "__main__":
  76. args = get_main_args()
  77. main(args)