# trainer.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  22. import abc
  23. import glob
  24. import pathlib
  25. import numpy as np
  26. import torch
  27. from tensorboardX import SummaryWriter
  28. import time
  29. import os
  30. import matplotlib.pyplot as plt
  31. from torch import nn
  32. from fastspeech.utils.logging import tprint
  33. from fastspeech.utils.pytorch import to_device_async
  34. from fastspeech.utils.nvtx import Nvtx
  35. from fastspeech.utils.fp16 import cast_model_to_half
  36. import torch.cuda.profiler as profiler
  37. from fastspeech.utils.logging import tprint
  38. from fastspeech.utils.time import TimeElapsed
  39. plt.switch_backend('Agg')
  40. class Trainer(object):
  41. """
  42. set seed
  43. set n_epochs, n_steps
  44. save/load model
  45. validation
  46. logging
  47. distributed
  48. """
  49. def __init__(self, data_loader, model_name, model, optimizer_fn, final_steps, lr_scheduler_fn=None, step=0, ckpt_path=None, log_path=None, n_epochs=None, save_steps=None, log_steps=10, device='cuda', use_amp=False, nvprof_iter_start=None, nvprof_iter_end=None, pyprof_enabled=False, detect_anomaly=False, seed=None):
  50. self.data_loader = data_loader
  51. self.model_name = model_name
  52. self.model = model
  53. self.n_epochs = n_epochs
  54. self.save_steps = save_steps
  55. self.log_steps = log_steps
  56. self.ckpt_path = ckpt_path
  57. self.log_path = log_path
  58. self.final_steps = final_steps
  59. self.step = step
  60. self.device = device
  61. self.use_amp = use_amp
  62. self.nvprof_iter_start = nvprof_iter_start
  63. self.nvprof_iter_end = nvprof_iter_end
  64. self.pyprof_enabled = pyprof_enabled
  65. self.detect_anomaly = detect_anomaly
  66. # model
  67. self.model.train()
  68. to_device_async(self.model, self.device)
  69. num_param = sum(param.numel() for param in model.parameters())
  70. tprint('The number of {} parameters: {}'.format(
  71. self.model_name, num_param))
  72. # optimizer
  73. self.optimizer = optimizer_fn(model)
  74. # lr scheduler
  75. if lr_scheduler_fn:
  76. self.lr_scheduler = lr_scheduler_fn(self.optimizer)
  77. else:
  78. self.lr_scheduler = None
  79. # automatic mixed precision
  80. if self.use_amp:
  81. from apex import amp
  82. self.model, self.optimizer = amp.initialize(self.model,
  83. self.optimizer,
  84. opt_level='O1')
  85. # profile
  86. if nvprof_iter_start and nvprof_iter_end is not None and pyprof_enabled:
  87. from apex import pyprof
  88. pyprof.nvtx.init()
  89. # data parallel
  90. self.model = nn.DataParallel(self.model)
  91. # set seed
  92. if seed is None:
  93. seed = np.random.randint(2**16)
  94. np.random.seed(seed)
  95. torch.manual_seed(seed)
  96. # data loader
  97. self.data_loader_iter = self.repeat(self.data_loader, n_epochs)
  98. # logging
  99. if log_path:
  100. # tensorboard log path : {log_path}/YYYYMMDD-HHMMMSS
  101. log_path = os.path.join(log_path, time.strftime('%Y%m%d-%H%M%S'))
  102. self.tbwriter = SummaryWriter(log_dir=log_path, flush_secs=10)
  103. # checkpoint path
  104. if self.ckpt_path:
  105. self.ckpt_path = os.path.join(self.ckpt_path, self.model_name)
  106. pathlib.Path(self.ckpt_path).mkdir(parents=True, exist_ok=True)
  107. # load checkpoint
  108. self.load()
  109. def train(self):
  110. try:
  111. with torch.autograd.profiler.emit_nvtx(enabled=self.pyprof_enabled):
  112. for i in range(self.step+1, self.final_steps + 1):
  113. self.step = i
  114. tprint("------------- TRAIN step : {} -------------".format(i))
  115. if self.nvprof_iter_start and i == self.nvprof_iter_start:
  116. profiler.start()
  117. timer = TimeElapsed(name="Training time during profiling", format=":.6f")
  118. timer.start()
  119. with Nvtx("step #{}".format(self.step)):
  120. loss, meta = self.do_step()
  121. if self.nvprof_iter_end and i == self.nvprof_iter_end:
  122. profiler.stop()
  123. timer.end()
  124. if self.lr_scheduler:
  125. for param_group in self.optimizer.param_groups:
  126. tprint("lr: {:06f}".format(param_group['lr']))
  127. self.lr_scheduler.step(self.step)
  128. if self.step % self.log_steps == 0:
  129. self.log(loss, meta)
  130. if self.ckpt_path and self.save_steps and i % self.save_steps == 0:
  131. self.save()
  132. tprint("Training has been done.")
  133. except StopIteration: # done by n_epochs
  134. tprint("Training has been done. (by n_epochs)")
  135. except KeyboardInterrupt:
  136. tprint("Training has been canceled.")
  137. @abc.abstractmethod
  138. def loss(self, inputs, model):
  139. raise NotImplemented
  140. def do_step(self):
  141. with Nvtx("data load", enabled=False):
  142. data = next(self.data_loader_iter)
  143. with torch.autograd.set_detect_anomaly(mode=self.detect_anomaly):
  144. with Nvtx("forward"):
  145. loss, meta = self.loss(data, self.model)
  146. self.optimizer.zero_grad()
  147. with Nvtx("backward"):
  148. if self.use_amp:
  149. from apex import amp
  150. with amp.scale_loss(loss, self.optimizer) as scaled_loss:
  151. scaled_loss.backward()
  152. else:
  153. loss.backward()
  154. with Nvtx("weight update"):
  155. self.optimizer.step()
  156. return loss, meta
  157. def log(self, loss, meta):
  158. self.console_log('train', loss, meta)
  159. if self.log_path:
  160. self.tensorboard_log('train', loss)
  161. def save(self):
  162. state_dict = {
  163. 'step': self.step,
  164. 'model': self.model.state_dict(),
  165. 'optim': self.optimizer.state_dict(),
  166. }
  167. torch.save(state_dict, self.ckpt_path +
  168. '/checkpoint_{:06d}.pt'.format(self.step))
  169. tprint('[Save] Model "{}". Step={}.'.format(
  170. self.model_name, self.step))
  171. def load(self, load_optim=True):
  172. files_exist = glob.glob(os.path.join(self.ckpt_path, '*'))
  173. if files_exist:
  174. # load the latest created file.
  175. latest_file = max(files_exist, key=os.path.getctime)
  176. state_dict = torch.load(latest_file)
  177. self.step = state_dict['step']
  178. self.model.load_state_dict(state_dict['model'])
  179. if load_optim:
  180. self.optimizer.load_state_dict(state_dict['optim'])
  181. tprint('[Load] Checkpoint \'{}\'. Step={}'.format(
  182. latest_file, self.step))
  183. else:
  184. tprint('No checkpoints in {}. Load skipped.'.format(self.ckpt_path))
  185. def console_log(self, tag, loss, meta):
  186. # console logging
  187. msg = 'loss: {:.6f}'.format(loss)
  188. for key, value in meta.items():
  189. msg += ',\t{}: {:.4f}'.format(key, value)
  190. tprint(msg)
  191. def tensorboard_log(self, tag, loss):
  192. self.tbwriter.add_scalar(
  193. '{}/loss'.format(tag), loss, global_step=self.step)
  194. @staticmethod
  195. def repeat(iterable, n_repeat=None):
  196. cnt = 0
  197. while n_repeat is None or cnt < n_repeat:
  198. for x in iterable:
  199. yield x
  200. cnt += 1
  201. return StopIteration()