logging.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import dllogger
  15. import pathlib
  16. from abc import ABC, abstractmethod
  17. from dllogger import Verbosity
  18. from typing import Dict, Any, Callable, Optional
  19. from runtime.utils import rank_zero_only
  20. class Logger(ABC):
  21. @rank_zero_only
  22. @abstractmethod
  23. def log_hyperparams(self, params):
  24. pass
  25. @rank_zero_only
  26. @abstractmethod
  27. def log_metadata(self, metric, metadata):
  28. pass
  29. @rank_zero_only
  30. @abstractmethod
  31. def log_metrics(self, metrics, step=None):
  32. pass
  33. @staticmethod
  34. def _sanitize_params(params):
  35. def _sanitize(val):
  36. if isinstance(val, Callable):
  37. try:
  38. _val = val()
  39. if isinstance(_val, Callable):
  40. return val.__name__
  41. return _val
  42. except Exception:
  43. return getattr(val, "__name__", None)
  44. elif isinstance(val, pathlib.Path):
  45. return str(val)
  46. return val
  47. return {key: _sanitize(val) for key, val in params.items()}
  48. @rank_zero_only
  49. def flush(self):
  50. pass
  51. class LoggerCollection(Logger):
  52. def __init__(self, loggers):
  53. super().__init__()
  54. self.loggers = loggers
  55. def __getitem__(self, index):
  56. return [logger for logger in self.loggers][index]
  57. @rank_zero_only
  58. def log_metrics(self, metrics, step=None):
  59. for logger in self.loggers:
  60. logger.log_metrics(metrics, step)
  61. @rank_zero_only
  62. def log_hyperparams(self, params):
  63. for logger in self.loggers:
  64. logger.log_hyperparams(params)
  65. @rank_zero_only
  66. def log_metadata(self, metric, metadata):
  67. for logger in self.loggers:
  68. logger.log_metadata(metric, metadata)
  69. @rank_zero_only
  70. def flush(self):
  71. for logger in self.loggers:
  72. logger.flush()
  73. class DLLogger(Logger):
  74. def __init__(self, save_dir: pathlib.Path, filename: str, append: bool, quiet: bool):
  75. super().__init__()
  76. self._initialize_dllogger(save_dir, filename, append, quiet)
  77. @rank_zero_only
  78. def _initialize_dllogger(self, save_dir, filename, append, quiet):
  79. save_dir.mkdir(parents=True, exist_ok=True)
  80. backends = [
  81. dllogger.JSONStreamBackend(
  82. Verbosity.DEFAULT, str(save_dir / filename), append=append
  83. ),
  84. ]
  85. if not quiet:
  86. backends.append(
  87. dllogger.StdOutBackend(
  88. Verbosity.VERBOSE, step_format=lambda step: f"Step: {step} "
  89. )
  90. )
  91. dllogger.init(backends=backends)
  92. @rank_zero_only
  93. def log_hyperparams(self, params):
  94. params = self._sanitize_params(params)
  95. dllogger.log(step="PARAMETER", data=params)
  96. @rank_zero_only
  97. def log_metadata(self, metric, metadata):
  98. dllogger.metadata(metric, metadata)
  99. @rank_zero_only
  100. def log_metrics(self, metrics, step=None):
  101. if step is None:
  102. step = tuple()
  103. dllogger.log(step=step, data=metrics)
  104. @rank_zero_only
  105. def flush(self):
  106. dllogger.flush()
  107. def get_logger(args):
  108. loggers = []
  109. if args.use_dllogger:
  110. loggers.append(
  111. DLLogger(save_dir=args.results, filename=args.logname, append=args.resume_training, quiet=args.quiet)
  112. )
  113. return LoggerCollection(loggers)