profile.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import logging
  16. from contextlib import contextmanager
  17. from utils.cuda_bind import cuda_profile_start, cuda_profile_stop
  18. from utils.cuda_bind import cuda_nvtx_range_push, cuda_nvtx_range_pop
  19. class Profiler:
  20. def __init__(self):
  21. super().__init__()
  22. self._enable_profile = int(os.environ.get('ENABLE_PROFILE', 0))
  23. self._start_step = int(os.environ.get('PROFILE_START_STEP', 0))
  24. self._stop_step = int(os.environ.get('PROFILE_STOP_STEP', 0))
  25. if self._enable_profile:
  26. log_msg = f"Profiling start at {self._start_step}-th and stop at {self._stop_step}-th iteration"
  27. logging.info(log_msg)
  28. def profile_setup(self, step):
  29. """
  30. Setup profiling related status.
  31. Args:
  32. step (int): the index of iteration.
  33. Return:
  34. stop (bool): a signal to indicate whether profiling should stop or not.
  35. """
  36. if self._enable_profile and step == self._start_step:
  37. cuda_profile_start()
  38. logging.info("Profiling start at %d-th iteration",
  39. self._start_step)
  40. if self._enable_profile and step == self._stop_step:
  41. cuda_profile_stop()
  42. logging.info("Profiling stop at %d-th iteration", self._stop_step)
  43. return True
  44. return False
  45. def profile_tag_push(self, step, msg):
  46. if self._enable_profile and \
  47. step >= self._start_step and \
  48. step < self._stop_step:
  49. tag_msg = f"Iter-{step}-{msg}"
  50. cuda_nvtx_range_push(tag_msg)
  51. def profile_tag_pop(self):
  52. if self._enable_profile:
  53. cuda_nvtx_range_pop()
  54. @contextmanager
  55. def profile_tag(self, step, msg):
  56. self.profile_tag_push(step, msg)
  57. yield
  58. self.profile_tag_pop()