# ctl_callbacks.py
# Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
import time

import dllogger

from callbacks.callbacks import Callback, CallbackContainer
from distributed_utils import is_main_process
from training.checkpoint_utils import save_checkpoint
from training.utils import round_dict
  21. class CTLCallbackContainer(CallbackContainer):
  22. """
  23. Base class for CTLTrainer callbacks storage.
  24. """
  25. def __init__(self, trainer, callbacks):
  26. self.callbacks = callbacks
  27. self.trainer = trainer
  28. self._init_trainers()
  29. self.logs = {}
  30. super().__init__()
  31. def _init_trainers(self):
  32. for callback in self.callbacks:
  33. callback.trainer = self.trainer
  34. def on_train_begin(self, logs=None):
  35. if logs is None:
  36. logs = {}
  37. for callback in self.callbacks:
  38. callback.on_train_begin(logs)
  39. def on_train_end(self, logs=None):
  40. if logs is None:
  41. logs = {}
  42. for callback in self.callbacks:
  43. callback.on_train_end(logs)
  44. def on_epoch_begin(self, epoch, logs=None):
  45. if logs is None:
  46. logs = {}
  47. for callback in self.callbacks:
  48. callback.on_epoch_begin(epoch, logs)
  49. def on_epoch_end(self, epoch, logs=None):
  50. if logs is None:
  51. logs = {}
  52. for callback in self.callbacks:
  53. callback.on_epoch_end(epoch, logs)
  54. def on_valid_begin(self, epoch, logs=None):
  55. if logs is None:
  56. logs = {}
  57. for callback in self.callbacks:
  58. callback.on_valid_begin(epoch, logs)
  59. def on_valid_end(self, epoch, logs=None):
  60. if logs is None:
  61. logs = {}
  62. for callback in self.callbacks:
  63. callback.on_valid_end(epoch, logs)
  64. def on_batch_begin(self, batch, logs=None):
  65. if logs is None:
  66. logs = {}
  67. for callback in self.callbacks:
  68. callback.on_batch_begin(batch, logs)
  69. def on_batch_end(self, batch, logs=None):
  70. if logs is None:
  71. logs = {}
  72. for callback in self.callbacks:
  73. callback.on_batch_end(batch, logs)
  74. def on_evaluate_end(self, logs=None):
  75. if logs is None:
  76. logs = {}
  77. for callback in self.callbacks:
  78. callback.on_evaluate_end(logs)
  79. def on_evaluate_begin(self, logs=None):
  80. if logs is None:
  81. logs = {}
  82. for callback in self.callbacks:
  83. callback.on_evaluate_begin(logs)
  84. class CTLCallback(Callback):
  85. """
  86. Base class for building new CTLTrainer callbacks.
  87. """
  88. def __init__(self):
  89. self.trainer = None
  90. super().__init__()
  91. @property
  92. def trainer(self):
  93. return self._trainer
  94. @trainer.setter
  95. def trainer(self, trainer):
  96. self._trainer = trainer
  97. def on_train_begin(self, logs=None):
  98. pass
  99. def on_train_end(self, logs=None):
  100. pass
  101. def on_epoch_begin(self, epoch, logs=None):
  102. pass
  103. def on_epoch_end(self, epoch, logs=None):
  104. pass
  105. def on_valid_begin(self, epoch, logs=None):
  106. pass
  107. def on_valid_end(self, epoch, logs=None):
  108. pass
  109. def on_batch_begin(self, batch, logs=None):
  110. pass
  111. def on_batch_end(self, batch, logs=None):
  112. pass
  113. def on_evaluate_begin(self, logs=None):
  114. pass
  115. def on_evaluate_end(self, logs=None):
  116. pass
  117. class LoggingCallback(CTLCallback):
  118. def on_train_begin(self, logs=None):
  119. self.trainer.logger.log(
  120. step='event',
  121. data={"String": "Training with {} epochs".format(self.trainer.config.get("num_epochs", 1))},
  122. verbosity=dllogger.Verbosity.DEFAULT,
  123. )
  124. def on_train_end(self, logs=None):
  125. self.trainer.logger.log(step='event', data={"String": "Training Stopped"}, verbosity=dllogger.Verbosity.DEFAULT)
  126. def on_epoch_begin(self, epoch, logs=None):
  127. self.trainer.logger.log(step='event', data={"String": "Epoch {}".format(epoch)}, verbosity=dllogger.Verbosity.DEFAULT)
  128. def on_batch_end(self, batch, logs=None):
  129. if self.trainer.config.log_interval > 0 and self.trainer.global_step % self.trainer.config.log_interval == 0:
  130. self.trainer.logger.flush()
  131. def on_valid_begin(self, epoch, logs=None):
  132. self.trainer.logger.log(
  133. step='event', data={"String": "Calculating Validation Metrics"}, verbosity=dllogger.Verbosity.DEFAULT
  134. )
  135. def on_valid_end(self, epoch, logs=None):
  136. self.trainer.logger.log(
  137. step='event',
  138. data={"String": "Epoch {} Validation Metrics: {}".format(epoch, round_dict(logs))},
  139. verbosity=dllogger.Verbosity.DEFAULT,
  140. )
  141. def on_epoch_end(self, epoch, logs=None):
  142. self.trainer.logger.flush()
  143. def on_evaluate_begin(self, logs=None):
  144. self.trainer.logger.log(
  145. step='event', data={"String": "Beginning Metric Evaluation"}, verbosity=dllogger.Verbosity.DEFAULT
  146. )
  147. def on_evaluate_end(self, logs=None):
  148. self.trainer.logger.log(
  149. step='event', data={"String": "Evaluation Metrics: {}".format(round_dict(logs))}, verbosity=dllogger.Verbosity.DEFAULT
  150. )
  151. self.trainer.logger.log(step=[], data=logs, verbosity=dllogger.Verbosity.DEFAULT)
  152. class EarlyStopping(CTLCallback):
  153. def __init__(self, metric="val_loss", min_delta=0, patience=5, max_divergence=None, divergence_patience=1):
  154. self.metric = metric
  155. self.min_delta = min_delta
  156. self.patience = patience
  157. self.max_divergence = max_divergence
  158. self.divergence_patience = divergence_patience
  159. self.divergence_stopped_epochs = 0
  160. self.stopped_epochs = 0
  161. self.best_loss = None
  162. super().__init__()
  163. def on_epoch_end(self, epoch, logs=None):
  164. epoch_loss = logs.get(self.metric, None)
  165. if epoch_loss is None:
  166. return
  167. if self.best_loss is None:
  168. self.best_loss = epoch_loss
  169. return
  170. if self.max_divergence and ((epoch_loss - self.best_loss) > self.max_divergence):
  171. self.divergence_stopped_epochs += 1
  172. self.stopped_epochs += 1
  173. if self.divergence_stopped_epochs >= self.divergence_patience:
  174. self.trainer._stop_training = True
  175. self.trainer.logger.log(
  176. step='event', data={"String": f"Applying early stopping as divergence threshold reached"}, verbosity=dllogger.Verbosity.DEFAULT
  177. )
  178. elif (epoch_loss + self.min_delta) < self.best_loss:
  179. self.best_loss = epoch_loss
  180. self.stopped_epochs = 0
  181. self.divergence_stopped_epochs = 0
  182. else:
  183. self.stopped_epochs += 1
  184. self.divergence_stopped_epochs = 0
  185. if self.stopped_epochs >= self.patience:
  186. self.trainer._stop_training = True
  187. self.trainer.logger.log(
  188. step='event', data={"String": f"Applying early stopping"}, verbosity=dllogger.Verbosity.DEFAULT
  189. )
  190. class SaveBestCheckpoint(CTLCallback):
  191. def __init__(self, metric="val_loss"):
  192. self.metric = metric
  193. self.best_loss = None
  194. super().__init__()
  195. def on_epoch_end(self, epoch, logs=None):
  196. epoch_loss = logs.get(self.metric, None)
  197. if epoch_loss is None:
  198. return
  199. if self.best_loss is None or epoch_loss < self.best_loss:
  200. self.best_loss = epoch_loss
  201. if is_main_process():
  202. save_checkpoint(self.trainer, checkpoint_dir=self.trainer.log_path, filename="best_checkpoint.zip")
  203. class SaveCheckpoint(CTLCallback):
  204. def __init__(self):
  205. super().__init__()
  206. def on_epoch_end(self, epoch, logs=None):
  207. if is_main_process():
  208. save_checkpoint(self.trainer, checkpoint_dir=self.trainer.log_path, filename="last_checkpoint.zip")
  209. class MeanAccumulator:
  210. def __init__(self):
  211. self.sum = 0
  212. self.count = 0
  213. def consume(self, value):
  214. self.sum += value
  215. self.count += 1
  216. @property
  217. def value(self):
  218. if self.count == 0:
  219. return 0
  220. return self.sum / self.count
  221. class ThroughputBenchmark(CTLCallback):
  222. def __init__(self, warmup_epochs=0):
  223. self.warmup_epochs = warmup_epochs
  224. self.train_throughput = MeanAccumulator()
  225. self.valid_throughput = MeanAccumulator()
  226. self.epoch_train_start = None
  227. self.epoch_train_end = None
  228. super().__init__()
  229. def on_train_end(self, logs=None):
  230. if self.train_throughput.value > 0:
  231. logs["Train it/s"] = self.train_throughput.value
  232. logs["Valid it/s"] = self.valid_throughput.value
  233. def on_epoch_begin(self, epoch, logs=None):
  234. self.epoch_train_start = time.time()
  235. def on_valid_end(self, epoch, logs=None):
  236. if epoch >= self.warmup_epochs:
  237. train_epoch_time = self.epoch_train_end - self.epoch_train_start
  238. valid_epoch_time = time.time() - self.epoch_train_end
  239. train_iter_per_sec = self.trainer.train_dataset_len / train_epoch_time
  240. valid_iter_per_sec = self.trainer.valid_dataset_len / valid_epoch_time
  241. logs["Train epoch it/s"] = train_iter_per_sec
  242. logs["Valid epoch it/s"] = valid_iter_per_sec
  243. self.train_throughput.consume(train_iter_per_sec)
  244. self.valid_throughput.consume(valid_iter_per_sec)
  245. def on_valid_begin(self, batch, logs=None):
  246. self.epoch_train_end = time.time()