# ctl_callbacks.py
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import dllogger

from callbacks.callbacks import Callback, CallbackContainer
from distributed_utils import is_main_process
from training.checkpoint_utils import save_checkpoint
from training.utils import round_dict
  20. class CTLCallbackContainer(CallbackContainer):
  21. """
  22. Base class for CTLTrainer callbacks storage.
  23. """
  24. def __init__(self, trainer, callbacks):
  25. self.callbacks = callbacks
  26. self.trainer = trainer
  27. self._init_trainers()
  28. self.logs = {}
  29. super().__init__()
  30. def _init_trainers(self):
  31. for callback in self.callbacks:
  32. callback.trainer = self.trainer
  33. def on_train_begin(self, logs=None):
  34. if logs is None:
  35. logs = {}
  36. for callback in self.callbacks:
  37. callback.on_train_begin(logs)
  38. def on_train_end(self, logs=None):
  39. if logs is None:
  40. logs = {}
  41. for callback in self.callbacks:
  42. callback.on_train_end(logs)
  43. def on_epoch_begin(self, epoch, logs=None):
  44. if logs is None:
  45. logs = {}
  46. for callback in self.callbacks:
  47. callback.on_epoch_begin(epoch, logs)
  48. def on_epoch_end(self, epoch, logs=None):
  49. if logs is None:
  50. logs = {}
  51. for callback in self.callbacks:
  52. callback.on_epoch_end(epoch, logs)
  53. def on_valid_begin(self, epoch, logs=None):
  54. if logs is None:
  55. logs = {}
  56. for callback in self.callbacks:
  57. callback.on_valid_begin(epoch, logs)
  58. def on_valid_end(self, epoch, logs=None):
  59. if logs is None:
  60. logs = {}
  61. for callback in self.callbacks:
  62. callback.on_valid_end(epoch, logs)
  63. def on_batch_begin(self, batch, logs=None):
  64. if logs is None:
  65. logs = {}
  66. for callback in self.callbacks:
  67. callback.on_batch_begin(batch, logs)
  68. def on_batch_end(self, batch, logs=None):
  69. if logs is None:
  70. logs = {}
  71. for callback in self.callbacks:
  72. callback.on_batch_end(batch, logs)
  73. def on_evaluate_end(self, logs=None):
  74. if logs is None:
  75. logs = {}
  76. for callback in self.callbacks:
  77. callback.on_evaluate_end(logs)
  78. def on_evaluate_begin(self, logs=None):
  79. if logs is None:
  80. logs = {}
  81. for callback in self.callbacks:
  82. callback.on_evaluate_begin(logs)
  83. class CTLCallback(Callback):
  84. """
  85. Base class for building new CTLTrainer callbacks.
  86. """
  87. def __init__(self):
  88. self.trainer = None
  89. super().__init__()
  90. @property
  91. def trainer(self):
  92. return self._trainer
  93. @trainer.setter
  94. def trainer(self, trainer):
  95. self._trainer = trainer
  96. def on_train_begin(self, logs=None):
  97. pass
  98. def on_train_end(self, logs=None):
  99. pass
  100. def on_epoch_begin(self, epoch, logs=None):
  101. pass
  102. def on_epoch_end(self, epoch, logs=None):
  103. pass
  104. def on_valid_begin(self, epoch, logs=None):
  105. pass
  106. def on_valid_end(self, epoch, logs=None):
  107. pass
  108. def on_batch_begin(self, batch, logs=None):
  109. pass
  110. def on_batch_end(self, batch, logs=None):
  111. pass
  112. def on_evaluate_begin(self, logs=None):
  113. pass
  114. def on_evaluate_end(self, logs=None):
  115. pass
  116. class LoggingCallback(CTLCallback):
  117. def on_train_begin(self, logs=None):
  118. self.trainer.logger.log(
  119. step='event',
  120. data={"String": "Training with {} epochs".format(self.trainer.config.get("num_epochs", 1))},
  121. verbosity=dllogger.Verbosity.DEFAULT,
  122. )
  123. def on_train_end(self, logs=None):
  124. self.trainer.logger.log(step='event', data={"String": "Training Stopped"}, verbosity=dllogger.Verbosity.DEFAULT)
  125. def on_epoch_begin(self, epoch, logs=None):
  126. self.trainer.logger.log(step='event', data={"String": "Epoch {}".format(epoch)}, verbosity=dllogger.Verbosity.DEFAULT)
  127. def on_batch_end(self, batch, logs=None):
  128. if self.trainer.config.log_interval > 0 and self.trainer.global_step % self.trainer.config.log_interval == 0:
  129. self.trainer.logger.flush()
  130. def on_valid_begin(self, epoch, logs=None):
  131. self.trainer.logger.log(
  132. step='event', data={"String": "Calculating Validation Metrics"}, verbosity=dllogger.Verbosity.DEFAULT
  133. )
  134. def on_valid_end(self, epoch, logs=None):
  135. self.trainer.logger.log(
  136. step='event',
  137. data={"String": "Epoch {} Validation Metrics: {}".format(epoch, round_dict(logs))},
  138. verbosity=dllogger.Verbosity.DEFAULT,
  139. )
  140. def on_epoch_end(self, epoch, logs=None):
  141. self.trainer.logger.flush()
  142. def on_evaluate_begin(self, logs=None):
  143. self.trainer.logger.log(
  144. step='event', data={"String": "Beginning Metric Evaluation"}, verbosity=dllogger.Verbosity.DEFAULT
  145. )
  146. def on_evaluate_end(self, logs=None):
  147. self.trainer.logger.log(
  148. step='event', data={"String": "Evaluation Metrics: {}".format(round_dict(logs))}, verbosity=dllogger.Verbosity.DEFAULT
  149. )
  150. self.trainer.logger.log(step=[], data=logs, verbosity=dllogger.Verbosity.DEFAULT)
  151. class EarlyStopping(CTLCallback):
  152. def __init__(self, metric="val_loss", min_delta=0, patience=5, max_divergence=None, divergence_patience=1):
  153. self.metric = metric
  154. self.min_delta = min_delta
  155. self.patience = patience
  156. self.max_divergence = max_divergence
  157. self.divergence_patience = divergence_patience
  158. self.divergence_stopped_epochs = 0
  159. self.stopped_epochs = 0
  160. self.best_loss = None
  161. super().__init__()
  162. def on_epoch_end(self, epoch, logs=None):
  163. epoch_loss = logs.get(self.metric, None)
  164. if epoch_loss is None:
  165. return
  166. if self.best_loss is None:
  167. self.best_loss = epoch_loss
  168. return
  169. if self.max_divergence and ((epoch_loss - self.best_loss) > self.max_divergence):
  170. self.divergence_stopped_epochs += 1
  171. self.stopped_epochs += 1
  172. if self.divergence_stopped_epochs >= self.divergence_patience:
  173. self.trainer._stop_training = True
  174. self.trainer.logger.log(
  175. step='event', data={"String": f"Applying early stopping as divergence threshold reached"}, verbosity=dllogger.Verbosity.DEFAULT
  176. )
  177. elif (epoch_loss + self.min_delta) < self.best_loss:
  178. self.best_loss = epoch_loss
  179. self.stopped_epochs = 0
  180. self.divergence_stopped_epochs = 0
  181. else:
  182. self.stopped_epochs += 1
  183. self.divergence_stopped_epochs = 0
  184. if self.stopped_epochs >= self.patience:
  185. self.trainer._stop_training = True
  186. self.trainer.logger.log(
  187. step='event', data={"String": f"Applying early stopping"}, verbosity=dllogger.Verbosity.DEFAULT
  188. )
  189. class SaveBestCheckpoint(CTLCallback):
  190. def __init__(self, metric="val_loss"):
  191. self.metric = metric
  192. self.best_loss = None
  193. super().__init__()
  194. def on_epoch_end(self, epoch, logs=None):
  195. epoch_loss = logs.get(self.metric, None)
  196. if epoch_loss is None:
  197. return
  198. if self.best_loss is None or epoch_loss < self.best_loss:
  199. self.best_loss = epoch_loss
  200. if is_main_process():
  201. save_checkpoint(self.trainer, checkpoint_dir=self.trainer.log_path, filename="best_checkpoint.zip")
  202. class SaveCheckpoint(CTLCallback):
  203. def __init__(self):
  204. super().__init__()
  205. def on_epoch_end(self, epoch, logs=None):
  206. if is_main_process():
  207. save_checkpoint(self.trainer, checkpoint_dir=self.trainer.log_path, filename="last_checkpoint.zip")
  208. class MeanAccumulator:
  209. def __init__(self):
  210. self.sum = 0
  211. self.count = 0
  212. def consume(self, value):
  213. self.sum += value
  214. self.count += 1
  215. @property
  216. def value(self):
  217. if self.count == 0:
  218. return 0
  219. return self.sum / self.count
  220. class ThroughputBenchmark(CTLCallback):
  221. def __init__(self, warmup_epochs=0):
  222. self.warmup_epochs = warmup_epochs
  223. self.train_throughput = MeanAccumulator()
  224. self.valid_throughput = MeanAccumulator()
  225. self.epoch_train_start = None
  226. self.epoch_train_end = None
  227. super().__init__()
  228. def on_train_end(self, logs=None):
  229. if self.train_throughput.value > 0:
  230. logs["Train it/s"] = self.train_throughput.value
  231. logs["Valid it/s"] = self.valid_throughput.value
  232. def on_epoch_begin(self, epoch, logs=None):
  233. self.epoch_train_start = time.time()
  234. def on_valid_end(self, epoch, logs=None):
  235. if epoch >= self.warmup_epochs:
  236. train_epoch_time = self.epoch_train_end - self.epoch_train_start
  237. valid_epoch_time = time.time() - self.epoch_train_end
  238. train_iter_per_sec = self.trainer.train_dataset_len / train_epoch_time
  239. valid_iter_per_sec = self.trainer.valid_dataset_len / valid_epoch_time
  240. logs["Train epoch it/s"] = train_iter_per_sec
  241. logs["Valid epoch it/s"] = valid_iter_per_sec
  242. self.train_throughput.consume(train_iter_per_sec)
  243. self.valid_throughput.consume(valid_iter_per_sec)
  244. def on_valid_begin(self, batch, logs=None):
  245. self.epoch_train_end = time.time()