train.py

# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
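
# Example invocation (the paths and hyperparameter values below are
# illustrative only, not defaults shipped with this script):
#
#   python train.py --cuda --cudnn-enabled --amp-run \
#       -o ./output -d ./dataset \
#       --training-files filelists/train.txt \
#       --validation-files filelists/val.txt \
#       --epochs 100 -lr 0.1 -bs 32 --optimizer lamb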

import argparse
import copy
import glob
import json
import os
import re
import time
from collections import defaultdict, OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import dllogger as DLLogger
from apex import amp
from apex.optimizers import FusedAdam, FusedLAMB
from apex.parallel import DistributedDataParallel as DDP

import common
import data_functions
import loss_functions
import models
from common.log_helper import init_dllogger, TBLogger


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, required=True,
                          help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=50,
                          help='Number of epochs per checkpoint')
    training.add_argument('--checkpoint-path', type=str, default=None,
                          help='Checkpoint path to resume training')
    training.add_argument('--checkpoint-resume', action='store_true',
                          help='Resume training from the last available checkpoint')
    training.add_argument('--seed', type=int, default=1234,
                          help='Seed for PyTorch random number generators')
    training.add_argument('--amp-run', action='store_true',
                          help='Enable AMP')
    training.add_argument('--cuda', action='store_true',
                          help='Run on GPU using CUDA')
    training.add_argument('--cudnn-enabled', action='store_true',
                          help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', action='store_true',
                          help='Run cudnn benchmark')
    training.add_argument('--ema-decay', type=float, default=0,
                          help='Discounting factor for training weights EMA')
    training.add_argument('--gradient-accumulation-steps', type=int, default=1,
                          help='Training steps to accumulate gradients for')

    optimization = parser.add_argument_group('optimization setup')
    optimization.add_argument('--optimizer', type=str, default='lamb',
                              help='Optimization algorithm')
    optimization.add_argument('-lr', '--learning-rate', type=float,
                              required=True, help='Learning rate')
    optimization.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay')
    optimization.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                              help='Clip threshold for gradients')
    optimization.add_argument('-bs', '--batch-size', type=int, required=True,
                              help='Batch size per GPU')
    optimization.add_argument('--warmup-steps', type=int, default=1000,
                              help='Number of steps for lr warmup')
    optimization.add_argument('--dur-predictor-loss-scale', type=float,
                              default=1.0, help='Rescale duration predictor loss')
    optimization.add_argument('--pitch-predictor-loss-scale', type=float,
                              default=1.0, help='Rescale pitch predictor loss')

    dataset = parser.add_argument_group('dataset parameters')
    dataset.add_argument('--training-files', type=str, required=True,
                         help='Path to training filelist')
    dataset.add_argument('--validation-files', type=str, required=True,
                         help='Path to validation filelist')
    dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
                         help='Path to pitch stats to be stored in the model')
    dataset.add_argument('--text-cleaners', nargs='*',
                         default=['english_cleaners'], type=str,
                         help='Type of text cleaners for input text')

    distributed = parser.add_argument_group('distributed setup')
    distributed.add_argument('--rank', default=0, type=int,
                             help='Rank of the process for multiproc. Do not set manually.')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='Number of processes for multiproc. Do not set manually.')
    distributed.add_argument('--dist-url', type=str, default='tcp://localhost:23456',
                             help='URL used to set up distributed training')
    distributed.add_argument('--group-name', type=str, default='group_name',
                             required=False, help='Distributed group name')
    distributed.add_argument('--dist-backend', default='nccl', type=str,
                             choices={'nccl'}, help='Distributed run backend')
    return parser


def reduce_tensor(tensor, num_gpus):
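    """Sum a tensor across workers, divided by num_gpus (pass 1 for a sum)."""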
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt


def init_distributed(args, world_size, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=world_size, rank=rank, group_name=group_name)
    print("Done initializing distributed training")


def last_checkpoint(output):
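    """Return the path of the newest loadable checkpoint in `output`.

    If the newest checkpoint fails to load (e.g. an interrupted save left it
    truncated), fall back to the one before it; return None if nothing usable
    is found.
    """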

    def corrupted(fpath):
        try:
            torch.load(fpath, map_location='cpu')
            return False
        except Exception:
            print(f'WARNING: Cannot load {fpath}')
            return True

    saved = sorted(
        glob.glob(f'{output}/FastPitch_checkpoint_*.pt'),
        key=lambda f: int(re.search(r'_(\d+)\.pt', f).group(1)))

    if len(saved) >= 1 and not corrupted(saved[-1]):
        return saved[-1]
    elif len(saved) >= 2:
        return saved[-2]
    else:
        return None


def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, config,
                    amp_run, filepath):
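    """Write model/EMA/optimizer (and AMP, if enabled) state; rank 0 only."""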
    if local_rank != 0:
        return
    print(f"Saving model and optimizer state at epoch {epoch} to {filepath}")
    ema_dict = None if ema_model is None else ema_model.state_dict()
    checkpoint = {'epoch': epoch,
                  'config': config,
                  'state_dict': model.state_dict(),
                  'ema_state_dict': ema_dict,
                  'optimizer': optimizer.state_dict()}
    if amp_run:
        checkpoint['amp'] = amp.state_dict()
    torch.save(checkpoint, filepath)


def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, config,
                    amp_run, filepath, world_size):
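    # `epoch` is a one-element list acting as an output parameter: the epoch
    # to resume from is written into epoch[0].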
    if local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    checkpoint = torch.load(filepath, map_location='cpu')
    epoch[0] = checkpoint['epoch'] + 1
    config = checkpoint['config']

    sd = {k.replace('module.', ''): v
          for k, v in checkpoint['state_dict'].items()}
    getattr(model, 'module', model).load_state_dict(sd)
    optimizer.load_state_dict(checkpoint['optimizer'])

    if amp_run:
        amp.load_state_dict(checkpoint['amp'])

    if ema_model is not None:
        ema_model.load_state_dict(checkpoint['ema_state_dict'])


def validate(model, criterion, valset, batch_size, world_size, collate_fn,
             distributed_run, rank, batch_to_gpu, use_gt_durations=False):
    """Handles all the validation scoring and printing"""
    was_training = model.training
    model.eval()

    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=8, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn)

        val_meta = defaultdict(float)
        val_num_frames = 0
        for batch in val_loader:
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=use_gt_durations)
            loss, meta = criterion(y_pred, y, is_training=False,
                                   meta_agg='sum')

            if distributed_run:
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                # Accumulate (not overwrite) so frames/s covers all batches.
                val_num_frames += num_frames.item()

        val_meta = {k: v / len(valset) for k, v in val_meta.items()}

    val_loss = val_meta['loss']
    if was_training:
        model.train()
    return val_loss.item(), val_meta, val_num_frames


def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
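    # Transformer-style schedule: scale the base rate linearly over
    # `warmup_iters` steps, then decay with the inverse square root of the
    # iteration count (warmup_iters == 0 disables scaling entirely).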
    if warmup_iters == 0:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1. / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)

    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale


def apply_ema_decay(model, ema_model, decay):
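    """In-place EMA update: ema = decay * ema + (1 - decay) * model."""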
    if not decay:
        return
    st = model.state_dict()
    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for k, v in ema_model.state_dict().items():
        if add_module and not k.startswith('module.'):
            k = 'module.' + k
        v.copy_(decay * v + (1 - decay) * st[k])


def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1
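
    # Offset the RNG seed by rank so workers do not draw identical randomness.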
    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        init_dllogger(args.log_file)
    else:
        init_dllogger(dummy=True)

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    fpath = common.utils.stats_filename(
        args.dataset_path, args.training_files, 'pitch_char')
    # Fall back to the dataset's precomputed stats file when no explicit
    # --pitch-mean-std-file is given (its default is None).
    with open(args.pitch_mean_std_file or fpath, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None
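
    # The EMA copy is taken before DDP wrapping; apply_ema_decay reconciles
    # the resulting 'module.' prefix mismatch between the two state dicts.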
    if distributed_run:
        model = DDP(model)

    start_epoch = [1]

    assert args.checkpoint_path is None or args.checkpoint_resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.checkpoint_resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        model_config, args.amp_run, ch_fpath, world_size)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler, batch_size=args.batch_size,
                              pin_memory=False, drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    total_iter = 0
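
    # Drain any pending CUDA work so the first epoch timer starts cleanly.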
    torch.cuda.synchronize()

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()
        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
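
        # One logical iteration spans `gradient_accumulation_steps` batches;
        # the `break` below drops trailing batches that cannot fill a cycle.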
        for batch in train_loader:

            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            loss /= args.gradient_accumulation_steps
            meta = {k: v / args.gradient_accumulation_steps
                    for k, v in meta.items()}

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, world_size)
                        for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
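
            # A full accumulation cycle is complete: clip, step, update EMA,
            # log, then reset the per-iteration accumulators.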
            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp_run:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log((epoch, epoch_iter, num_iters), OrderedDict([
                    ('train_loss', iter_loss),
                    ('train_mel_loss', iter_mel_loss),
                    ('train_frames/s', frames_per_sec),
                    ('took', iter_time),
                    ('lrate_change', dllog_lrate_change)
                ]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch,), data=OrderedDict([
            ('avg_train_loss', epoch_loss / epoch_iter),
            ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
            ('avg_train_frames/s', epoch_num_frames / epoch_time),
            ('took', epoch_time)
        ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(
            model, criterion, valset, args.batch_size, world_size, collate_fn,
            distributed_run, local_rank, batch_to_gpu, use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch,), data=OrderedDict([
            ('val_loss', val_loss),
            ('val_mel_loss', meta['mel_loss'].item()),
            ('val_frames/s', num_frames / (tok - tik)),
            ('took', tok - tik),
        ]))
        val_tblogger.log_meta(total_iter, meta)

        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(
                ema_model, criterion, valset, args.batch_size, world_size,
                collate_fn, distributed_run, local_rank, batch_to_gpu,
                use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log((epoch,), data=OrderedDict([
                ('val_ema_loss', val_loss_e),
                ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                ('took', tok_e - tik_e),
            ]))
            # Log the EMA model's own metrics (meta_e), not the base model's.
            val_ema_tblogger.log_meta(total_iter, meta_e)

        if (epoch > 0 and args.epochs_per_checkpoint > 0 and
                (epoch % args.epochs_per_checkpoint == 0) and local_rank == 0):
            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            model_config, args.amp_run, checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    # Finished training
    DLLogger.log((), data=OrderedDict([
        ('avg_train_loss', epoch_loss / epoch_iter),
        ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
        ('avg_train_frames/s', epoch_num_frames / epoch_time),
    ]))
    DLLogger.log((), data=OrderedDict([
        ('val_loss', val_loss),
        ('val_mel_loss', meta['mel_loss'].item()),
        ('val_frames/s', num_frames / (tok - tik)),
    ]))
    if local_rank == 0:
        DLLogger.flush()


if __name__ == '__main__':
    main()
  469. main()