train.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. # *****************************************************************************
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import argparse
  28. import copy
  29. import os
  30. import time
  31. from collections import defaultdict, OrderedDict
  32. from itertools import cycle
  33. import numpy as np
  34. import torch
  35. import torch.distributed as dist
  36. import amp_C
  37. from apex.optimizers import FusedAdam, FusedLAMB
  38. from torch.nn.parallel import DistributedDataParallel
  39. from torch.utils.data import DataLoader
  40. from torch.utils.data.distributed import DistributedSampler
  41. import common.tb_dllogger as logger
  42. import models
  43. from common.tb_dllogger import log
  44. from common.repeated_dataloader import (RepeatedDataLoader,
  45. RepeatedDistributedSampler)
  46. from common.text import cmudict
  47. from common.utils import (BenchmarkStats, Checkpointer,
  48. load_pretrained_weights, prepare_tmp)
  49. from fastpitch.attn_loss_function import AttentionBinarizationLoss
  50. from fastpitch.data_function import batch_to_gpu, ensure_disjoint, TTSCollate, TTSDataset
  51. from fastpitch.loss_function import FastPitchLoss
  52. def parse_args(parser):
  53. parser.add_argument('-o', '--output', type=str, required=True,
  54. help='Directory to save checkpoints')
  55. parser.add_argument('-d', '--dataset-path', type=str, default='./',
  56. help='Path to dataset')
  57. parser.add_argument('--log-file', type=str, default=None,
  58. help='Path to a DLLogger log file')
  59. train = parser.add_argument_group('training setup')
  60. train.add_argument('--epochs', type=int, required=True,
  61. help='Number of total epochs to run')
  62. train.add_argument('--epochs-per-checkpoint', type=int, default=50,
  63. help='Number of epochs per checkpoint')
  64. train.add_argument('--checkpoint-path', type=str, default=None,
  65. help='Checkpoint path to resume training')
  66. train.add_argument('--keep-milestones', default=list(range(100, 1000, 100)),
  67. type=int, nargs='+',
  68. help='Milestone checkpoints to keep from removing')
  69. train.add_argument('--resume', action='store_true',
  70. help='Resume training from the last checkpoint')
  71. train.add_argument('--seed', type=int, default=1234,
  72. help='Seed for PyTorch random number generators')
  73. train.add_argument('--amp', action='store_true',
  74. help='Enable AMP')
  75. train.add_argument('--cuda', action='store_true',
  76. help='Run on GPU using CUDA')
  77. train.add_argument('--cudnn-benchmark', action='store_true',
  78. help='Enable cudnn benchmark mode')
  79. train.add_argument('--ema-decay', type=float, default=0,
  80. help='Discounting factor for training weights EMA')
  81. train.add_argument('--grad-accumulation', type=int, default=1,
  82. help='Training steps to accumulate gradients for')
  83. train.add_argument('--kl-loss-start-epoch', type=int, default=250,
  84. help='Start adding the hard attention loss term')
  85. train.add_argument('--kl-loss-warmup-epochs', type=int, default=100,
  86. help='Gradually increase the hard attention loss term')
  87. train.add_argument('--kl-loss-weight', type=float, default=1.0,
  88. help='Gradually increase the hard attention loss term')
  89. train.add_argument('--benchmark-epochs-num', type=int, default=20,
  90. help='Number of epochs for calculating final stats')
  91. train.add_argument('--validation-freq', type=int, default=1,
  92. help='Validate every N epochs to use less compute')
  93. train.add_argument('--init-from-checkpoint', type=str, default=None,
  94. help='Initialize model weights with a pre-trained ckpt')
  95. opt = parser.add_argument_group('optimization setup')
  96. opt.add_argument('--optimizer', type=str, default='lamb',
  97. help='Optimization algorithm')
  98. opt.add_argument('-lr', '--learning-rate', type=float, required=True,
  99. help='Learing rate')
  100. opt.add_argument('--weight-decay', default=1e-6, type=float,
  101. help='Weight decay')
  102. opt.add_argument('--grad-clip-thresh', default=1000.0, type=float,
  103. help='Clip threshold for gradients')
  104. opt.add_argument('-bs', '--batch-size', type=int, required=True,
  105. help='Batch size per GPU')
  106. opt.add_argument('--warmup-steps', type=int, default=1000,
  107. help='Number of steps for lr warmup')
  108. opt.add_argument('--dur-predictor-loss-scale', type=float,
  109. default=1.0, help='Rescale duration predictor loss')
  110. opt.add_argument('--pitch-predictor-loss-scale', type=float,
  111. default=1.0, help='Rescale pitch predictor loss')
  112. opt.add_argument('--attn-loss-scale', type=float,
  113. default=1.0, help='Rescale alignment loss')
  114. data = parser.add_argument_group('dataset parameters')
  115. data.add_argument('--training-files', type=str, nargs='*', required=True,
  116. help='Paths to training filelists.')
  117. data.add_argument('--validation-files', type=str, nargs='*',
  118. required=True, help='Paths to validation filelists')
  119. data.add_argument('--text-cleaners', nargs='*',
  120. default=['english_cleaners'], type=str,
  121. help='Type of text cleaners for input text')
  122. data.add_argument('--symbol-set', type=str, default='english_basic',
  123. help='Define symbol set for input text')
  124. data.add_argument('--p-arpabet', type=float, default=0.0,
  125. help='Probability of using arpabets instead of graphemes '
  126. 'for each word; set 0 for pure grapheme training')
  127. data.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
  128. help='Path to the list of heteronyms')
  129. data.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
  130. help='Path to the pronouncing dictionary')
  131. data.add_argument('--prepend-space-to-text', action='store_true',
  132. help='Capture leading silence with a space token')
  133. data.add_argument('--append-space-to-text', action='store_true',
  134. help='Capture trailing silence with a space token')
  135. data.add_argument('--num-workers', type=int, default=6,
  136. help='Subprocesses for train and val DataLoaders')
  137. data.add_argument('--trainloader-repeats', type=int, default=100,
  138. help='Repeats the dataset to prolong epochs')
  139. cond = parser.add_argument_group('data for conditioning')
  140. cond.add_argument('--n-speakers', type=int, default=1,
  141. help='Number of speakers in the dataset. '
  142. 'n_speakers > 1 enables speaker embeddings')
  143. cond.add_argument('--load-pitch-from-disk', action='store_true',
  144. help='Use pitch cached on disk with prepare_dataset.py')
  145. cond.add_argument('--pitch-online-method', default='pyin',
  146. choices=['pyin'],
  147. help='Calculate pitch on the fly during trainig')
  148. cond.add_argument('--pitch-online-dir', type=str, default=None,
  149. help='A directory for storing pitch calculated on-line')
  150. cond.add_argument('--pitch-mean', type=float, default=214.72203,
  151. help='Normalization value for pitch')
  152. cond.add_argument('--pitch-std', type=float, default=65.72038,
  153. help='Normalization value for pitch')
  154. cond.add_argument('--load-mel-from-disk', action='store_true',
  155. help='Use mel-spectrograms cache on the disk') # XXX
  156. audio = parser.add_argument_group('audio parameters')
  157. audio.add_argument('--max-wav-value', default=32768.0, type=float,
  158. help='Maximum audiowave value')
  159. audio.add_argument('--sampling-rate', default=22050, type=int,
  160. help='Sampling rate')
  161. audio.add_argument('--filter-length', default=1024, type=int,
  162. help='Filter length')
  163. audio.add_argument('--hop-length', default=256, type=int,
  164. help='Hop (stride) length')
  165. audio.add_argument('--win-length', default=1024, type=int,
  166. help='Window length')
  167. audio.add_argument('--mel-fmin', default=0.0, type=float,
  168. help='Minimum mel frequency')
  169. audio.add_argument('--mel-fmax', default=8000.0, type=float,
  170. help='Maximum mel frequency')
  171. dist = parser.add_argument_group('distributed setup')
  172. dist.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
  173. help='Rank of the process for multiproc; do not set manually')
  174. dist.add_argument('--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
  175. help='Number of processes for multiproc; do not set manually')
  176. return parser
  177. def reduce_tensor(tensor, num_gpus):
  178. rt = tensor.clone()
  179. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  180. return rt.true_divide(num_gpus)
  181. def init_distributed(args, world_size, rank):
  182. assert torch.cuda.is_available(), "Distributed mode requires CUDA."
  183. print("Initializing distributed training")
  184. # Set cuda device so everything is done on the right GPU.
  185. torch.cuda.set_device(rank % torch.cuda.device_count())
  186. # Initialize distributed communication
  187. dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
  188. init_method='env://')
  189. print("Done initializing distributed training")
  190. def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
  191. batch_to_gpu, ema=False):
  192. was_training = model.training
  193. model.eval()
  194. tik = time.perf_counter()
  195. with torch.no_grad():
  196. val_meta = defaultdict(float)
  197. val_num_frames = 0
  198. for i, batch in enumerate(val_loader):
  199. x, y, num_frames = batch_to_gpu(batch)
  200. y_pred = model(x)
  201. loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
  202. if distributed_run:
  203. for k, v in meta.items():
  204. val_meta[k] += reduce_tensor(v, 1)
  205. val_num_frames += reduce_tensor(num_frames.data, 1).item()
  206. else:
  207. for k, v in meta.items():
  208. val_meta[k] += v
  209. val_num_frames += num_frames.item()
  210. val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
  211. val_meta['took'] = time.perf_counter() - tik
  212. log((epoch,) if epoch is not None else (), tb_total_steps=total_iter,
  213. subset='val_ema' if ema else 'val',
  214. data=OrderedDict([
  215. ('loss', val_meta['loss'].item()),
  216. ('mel_loss', val_meta['mel_loss'].item()),
  217. ('frames/s', val_num_frames / val_meta['took']),
  218. ('took', val_meta['took'])]),
  219. )
  220. if was_training:
  221. model.train()
  222. return val_meta
  223. def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
  224. if warmup_iters == 0:
  225. scale = 1.0
  226. elif total_iter > warmup_iters:
  227. scale = 1. / (total_iter ** 0.5)
  228. else:
  229. scale = total_iter / (warmup_iters ** 1.5)
  230. for param_group in opt.param_groups:
  231. param_group['lr'] = learning_rate * scale
  232. def apply_ema_decay(model, ema_model, decay):
  233. if not decay:
  234. return
  235. st = model.state_dict()
  236. add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
  237. for k, v in ema_model.state_dict().items():
  238. if add_module and not k.startswith('module.'):
  239. k = 'module.' + k
  240. v.copy_(decay * v + (1 - decay) * st[k])
  241. def init_multi_tensor_ema(model, ema_model):
  242. model_weights = list(model.state_dict().values())
  243. ema_model_weights = list(ema_model.state_dict().values())
  244. ema_overflow_buf = torch.cuda.IntTensor([0])
  245. return model_weights, ema_model_weights, ema_overflow_buf
  246. def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
  247. amp_C.multi_tensor_axpby(
  248. 65536, overflow_buf, [ema_weights, model_weights, ema_weights],
  249. decay, 1-decay, -1)
  250. def main():
  251. parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
  252. allow_abbrev=False)
  253. parser = parse_args(parser)
  254. args, _ = parser.parse_known_args()
  255. if args.p_arpabet > 0.0:
  256. cmudict.initialize(args.cmudict_path, args.heteronyms_path)
  257. distributed_run = args.world_size > 1
  258. torch.manual_seed(args.seed + args.local_rank)
  259. np.random.seed(args.seed + args.local_rank)
  260. if args.local_rank == 0:
  261. if not os.path.exists(args.output):
  262. os.makedirs(args.output)
  263. log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
  264. tb_subsets = ['train', 'val']
  265. if args.ema_decay > 0.0:
  266. tb_subsets.append('val_ema')
  267. logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
  268. tb_subsets=tb_subsets)
  269. logger.parameters(vars(args), tb_subset='train')
  270. parser = models.parse_model_args('FastPitch', parser)
  271. args, unk_args = parser.parse_known_args()
  272. if len(unk_args) > 0:
  273. raise ValueError(f'Invalid options {unk_args}')
  274. torch.backends.cudnn.benchmark = args.cudnn_benchmark
  275. if distributed_run:
  276. init_distributed(args, args.world_size, args.local_rank)
  277. else:
  278. if args.trainloader_repeats > 1:
  279. print('WARNING: Disabled --trainloader-repeats, supported only for'
  280. ' multi-GPU data loading.')
  281. args.trainloader_repeats = 1
  282. device = torch.device('cuda' if args.cuda else 'cpu')
  283. model_config = models.get_model_config('FastPitch', args)
  284. model = models.get_model('FastPitch', model_config, device)
  285. if args.init_from_checkpoint is not None:
  286. load_pretrained_weights(model, args.init_from_checkpoint)
  287. attention_kl_loss = AttentionBinarizationLoss()
  288. # Store pitch mean/std as params to translate from Hz during inference
  289. model.pitch_mean[0] = args.pitch_mean
  290. model.pitch_std[0] = args.pitch_std
  291. kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
  292. weight_decay=args.weight_decay)
  293. if args.optimizer == 'adam':
  294. optimizer = FusedAdam(model.parameters(), **kw)
  295. elif args.optimizer == 'lamb':
  296. optimizer = FusedLAMB(model.parameters(), **kw)
  297. else:
  298. raise ValueError
  299. scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
  300. if args.ema_decay > 0:
  301. ema_model = copy.deepcopy(model)
  302. else:
  303. ema_model = None
  304. if distributed_run:
  305. model = DistributedDataParallel(
  306. model, device_ids=[args.local_rank], output_device=args.local_rank,
  307. find_unused_parameters=True)
  308. train_state = {'epoch': 1, 'total_iter': 1}
  309. checkpointer = Checkpointer(args.output, args.keep_milestones)
  310. checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
  311. ema_model)
  312. start_epoch = train_state['epoch']
  313. total_iter = train_state['total_iter']
  314. criterion = FastPitchLoss(
  315. dur_predictor_loss_scale=args.dur_predictor_loss_scale,
  316. pitch_predictor_loss_scale=args.pitch_predictor_loss_scale,
  317. attn_loss_scale=args.attn_loss_scale)
  318. collate_fn = TTSCollate()
  319. if args.local_rank == 0:
  320. prepare_tmp(args.pitch_online_dir)
  321. trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
  322. valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
  323. ensure_disjoint(trainset, valset)
  324. if distributed_run:
  325. train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
  326. trainset, drop_last=True)
  327. val_sampler = DistributedSampler(valset)
  328. shuffle = False
  329. else:
  330. train_sampler, val_sampler, shuffle = None, None, True
  331. # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
  332. kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
  333. 'collate_fn': collate_fn}
  334. train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
  335. shuffle=shuffle, drop_last=True,
  336. sampler=train_sampler, pin_memory=True,
  337. persistent_workers=True, **kw)
  338. val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
  339. pin_memory=False, **kw)
  340. if args.ema_decay:
  341. mt_ema_params = init_multi_tensor_ema(model, ema_model)
  342. model.train()
  343. bmark_stats = BenchmarkStats()
  344. torch.cuda.synchronize()
  345. for epoch in range(start_epoch, args.epochs + 1):
  346. epoch_start_time = time.perf_counter()
  347. epoch_loss = 0.0
  348. epoch_mel_loss = 0.0
  349. epoch_num_frames = 0
  350. epoch_frames_per_sec = 0.0
  351. if distributed_run:
  352. train_loader.sampler.set_epoch(epoch)
  353. iter_loss = 0
  354. iter_num_frames = 0
  355. iter_meta = {}
  356. iter_start_time = time.perf_counter()
  357. epoch_iter = 1
  358. for batch, accum_step in zip(train_loader,
  359. cycle(range(1, args.grad_accumulation + 1))):
  360. if accum_step == 1:
  361. adjust_learning_rate(total_iter, optimizer, args.learning_rate,
  362. args.warmup_steps)
  363. model.zero_grad(set_to_none=True)
  364. x, y, num_frames = batch_to_gpu(batch)
  365. with torch.cuda.amp.autocast(enabled=args.amp):
  366. y_pred = model(x)
  367. loss, meta = criterion(y_pred, y)
  368. if (args.kl_loss_start_epoch is not None
  369. and epoch >= args.kl_loss_start_epoch):
  370. if args.kl_loss_start_epoch == epoch and epoch_iter == 1:
  371. print('Begin hard_attn loss')
  372. _, _, _, _, _, _, _, _, attn_soft, attn_hard, _, _ = y_pred
  373. binarization_loss = attention_kl_loss(attn_hard, attn_soft)
  374. kl_weight = min((epoch - args.kl_loss_start_epoch) / args.kl_loss_warmup_epochs, 1.0) * args.kl_loss_weight
  375. meta['kl_loss'] = binarization_loss.clone().detach() * kl_weight
  376. loss += kl_weight * binarization_loss
  377. else:
  378. meta['kl_loss'] = torch.zeros_like(loss)
  379. kl_weight = 0
  380. binarization_loss = 0
  381. loss /= args.grad_accumulation
  382. meta = {k: v / args.grad_accumulation
  383. for k, v in meta.items()}
  384. if args.amp:
  385. scaler.scale(loss).backward()
  386. else:
  387. loss.backward()
  388. if distributed_run:
  389. reduced_loss = reduce_tensor(loss.data, args.world_size).item()
  390. reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
  391. meta = {k: reduce_tensor(v, args.world_size) for k, v in meta.items()}
  392. else:
  393. reduced_loss = loss.item()
  394. reduced_num_frames = num_frames.item()
  395. if np.isnan(reduced_loss):
  396. raise Exception("loss is NaN")
  397. iter_loss += reduced_loss
  398. iter_num_frames += reduced_num_frames
  399. iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
  400. if accum_step % args.grad_accumulation == 0:
  401. logger.log_grads_tb(total_iter, model)
  402. if args.amp:
  403. scaler.unscale_(optimizer)
  404. torch.nn.utils.clip_grad_norm_(
  405. model.parameters(), args.grad_clip_thresh)
  406. scaler.step(optimizer)
  407. scaler.update()
  408. else:
  409. torch.nn.utils.clip_grad_norm_(
  410. model.parameters(), args.grad_clip_thresh)
  411. optimizer.step()
  412. if args.ema_decay > 0.0:
  413. apply_multi_tensor_ema(args.ema_decay, *mt_ema_params)
  414. iter_mel_loss = iter_meta['mel_loss'].item()
  415. iter_kl_loss = iter_meta['kl_loss'].item()
  416. iter_time = time.perf_counter() - iter_start_time
  417. epoch_frames_per_sec += iter_num_frames / iter_time
  418. epoch_loss += iter_loss
  419. epoch_num_frames += iter_num_frames
  420. epoch_mel_loss += iter_mel_loss
  421. num_iters = len(train_loader) // args.grad_accumulation
  422. log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
  423. subset='train', data=OrderedDict([
  424. ('loss', iter_loss),
  425. ('mel_loss', iter_mel_loss),
  426. ('kl_loss', iter_kl_loss),
  427. ('kl_weight', kl_weight),
  428. ('frames/s', iter_num_frames / iter_time),
  429. ('took', iter_time),
  430. ('lrate', optimizer.param_groups[0]['lr'])]),
  431. )
  432. iter_loss = 0
  433. iter_num_frames = 0
  434. iter_meta = {}
  435. iter_start_time = time.perf_counter()
  436. if epoch_iter == num_iters:
  437. break
  438. epoch_iter += 1
  439. total_iter += 1
  440. # Finished epoch
  441. epoch_loss /= epoch_iter
  442. epoch_mel_loss /= epoch_iter
  443. epoch_time = time.perf_counter() - epoch_start_time
  444. log((epoch,), tb_total_steps=None, subset='train_avg',
  445. data=OrderedDict([
  446. ('loss', epoch_loss),
  447. ('mel_loss', epoch_mel_loss),
  448. ('frames/s', epoch_num_frames / epoch_time),
  449. ('took', epoch_time)]),
  450. )
  451. bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
  452. epoch_time)
  453. if epoch % args.validation_freq == 0:
  454. validate(model, epoch, total_iter, criterion, val_loader,
  455. distributed_run, batch_to_gpu)
  456. if args.ema_decay > 0:
  457. validate(ema_model, epoch, total_iter, criterion, val_loader,
  458. distributed_run, batch_to_gpu, ema=True)
  459. # save before making sched.step() for proper loading of LR
  460. checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
  461. epoch, total_iter, model_config)
  462. logger.flush()
  463. # Finished training
  464. if len(bmark_stats) > 0:
  465. log((), tb_total_steps=None, subset='train_avg',
  466. data=bmark_stats.get(args.benchmark_epochs_num))
  467. validate(model, None, total_iter, criterion, val_loader, distributed_run,
  468. batch_to_gpu)
  469. if __name__ == '__main__':
  470. main()