setup.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ Utils for setting up different parts of the execution """
import multiprocessing
import os
import shutil

import numpy as np
import tensorflow as tf
import horovod.tensorflow as hvd

import dllogger as logger
from dllogger import StdOutBackend, Verbosity, JSONStreamBackend
  22. def set_flags():
  23. """ Set necessary flags for execution """
  24. tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
  25. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  26. os.environ['CUDA_CACHE_DISABLE'] = '1'
  27. os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'
  28. os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
  29. os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '0'
  30. os.environ['TF_ADJUST_HUE_FUSED'] = '1'
  31. os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
  32. os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
  33. os.environ['TF_SYNC_ON_FINISH'] = '0'
  34. os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'
  35. def prepare_model_dir(params):
  36. """ Prepare the directory where checkpoints are stored
  37. :param params: Dict with additional parameters
  38. :return: Path to model dir
  39. """
  40. model_dir = os.path.join(params.model_dir, "model_checkpoint")
  41. model_dir = model_dir if (hvd.rank() == 0 and not params.benchmark) else None
  42. if model_dir is not None:
  43. os.makedirs(model_dir, exist_ok=True)
  44. if ('train' in params.exec_mode) and (not params.resume_training):
  45. os.system('rm -rf {}/*'.format(model_dir))
  46. return model_dir
  47. def build_estimator(params, model_fn):
  48. """ Build estimator
  49. :param params: Dict with additional parameters
  50. :param model_fn: Model graph
  51. :return: Estimator
  52. """
  53. np.random.seed(params.seed)
  54. tf.compat.v1.random.set_random_seed(params.seed)
  55. model_dir = prepare_model_dir(params)
  56. config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(), allow_soft_placement=True)
  57. if params.use_xla:
  58. config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
  59. config.gpu_options.allow_growth = True
  60. config.gpu_options.visible_device_list = str(hvd.local_rank())
  61. config.intra_op_parallelism_threads = 1
  62. config.inter_op_parallelism_threads = max(2, (multiprocessing.cpu_count() // hvd.size()) - 2)
  63. if params.use_amp:
  64. config.graph_options.rewrite_options.auto_mixed_precision = 1
  65. checkpoint_steps = (params.max_steps // hvd.size()) if hvd.rank() == 0 else None
  66. checkpoint_steps = checkpoint_steps if not params.benchmark else None
  67. run_config = tf.estimator.RunConfig(
  68. save_summary_steps=params.max_steps,
  69. tf_random_seed=params.seed,
  70. session_config=config,
  71. save_checkpoints_steps=checkpoint_steps,
  72. keep_checkpoint_max=1)
  73. return tf.estimator.Estimator(model_fn=model_fn,
  74. model_dir=model_dir,
  75. config=run_config,
  76. params=params)
  77. def get_logger(params):
  78. """ Get logger object
  79. :param params: Dict with additional parameters
  80. :return: logger
  81. """
  82. backends = []
  83. if hvd.rank() == 0:
  84. backends += [StdOutBackend(Verbosity.VERBOSE)]
  85. if params.log_dir:
  86. backends += [JSONStreamBackend(Verbosity.VERBOSE, params.log_dir)]
  87. logger.init(backends=backends)
  88. logger.metadata("whole_tumor", {"unit": None})
  89. logger.metadata("throughput_test", {"unit": "volumes/s"})
  90. logger.metadata("throughput_train", {"unit": "volumes/s"})
  91. return logger