# train.py (22 KB)
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import itertools
  16. import os
  17. from functools import partial
  18. from itertools import islice
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. from torch.cuda import amp
  23. from torch.cuda.amp import autocast
  24. from torch.nn.parallel import DistributedDataParallel as DDP
  25. from torch.optim import AdamW
  26. from torch.optim.lr_scheduler import ExponentialLR
  27. from apex.optimizers import FusedAdam, FusedLAMB
  28. import models
  29. from common import tb_dllogger as logger, utils, gpu_affinity
  30. from common.utils import (Checkpointer, freeze, init_distributed, print_once,
  31. reduce_tensor, unfreeze, l2_promote)
  32. from hifigan.data_function import get_data_loader, mel_spectrogram
  33. from hifigan.logging import init_logger, Metrics
  34. from hifigan.models import (MultiPeriodDiscriminator, MultiScaleDiscriminator,
  35. feature_loss, generator_loss, discriminator_loss)
  36. def parse_args(parser):
  37. parser.add_argument('-o', '--output', type=str, required=True,
  38. help='Directory to save checkpoints')
  39. parser.add_argument('--log_file', type=str, default=None,
  40. help='Path to a DLLogger log file')
  41. train = parser.add_argument_group('training setup')
  42. train.add_argument('--epochs', type=int, required=True,
  43. help='Number of total epochs to run')
  44. train.add_argument('--epochs_this_job', type=int, default=None,
  45. help='Number of epochs in partial training run')
  46. train.add_argument('--keep_milestones', type=int, nargs='+',
  47. default=[1000, 2000, 3000, 4000, 5000, 6000],
  48. help='Milestone checkpoints to keep from removing')
  49. train.add_argument('--checkpoint_interval', type=int, default=50,
  50. help='Saving checkpoints frequency (in epochs)')
  51. train.add_argument('--step_logs_interval', default=1, type=int,
  52. help='Step logs dumping frequency (in steps)')
  53. train.add_argument('--validation_interval', default=10, type=int,
  54. help='Validation frequency (in epochs)')
  55. train.add_argument('--samples_interval', default=100, type=int,
  56. help='Dumping audio samples frequency (in epochs)')
  57. train.add_argument('--resume', action='store_true',
  58. help='Resume training from the last checkpoint')
  59. train.add_argument('--checkpoint_path_gen', type=str, default=None,
  60. help='Resume training from a selected checkpoint')
  61. train.add_argument('--checkpoint_path_discrim', type=str, default=None,
  62. help='Resume training from a selected checkpoint')
  63. train.add_argument('--seed', type=int, default=1234,
  64. help='Seed for PyTorch random number generators')
  65. train.add_argument('--amp', action='store_true',
  66. help='Enable AMP')
  67. train.add_argument('--autocast_spectrogram', action='store_true',
  68. help='Enable autocast while computing spectrograms')
  69. train.add_argument('--cuda', action='store_true',
  70. help='Run on GPU using CUDA')
  71. train.add_argument('--disable_cudnn_benchmark', action='store_true',
  72. help='Disable cudnn benchmark mode')
  73. train.add_argument('--ema_decay', type=float, default=0,
  74. help='Discounting factor for training weights EMA')
  75. train.add_argument('--grad_accumulation', type=int, default=1,
  76. help='Training steps to accumulate gradients for')
  77. train.add_argument('--num_workers', type=int, default=1,
  78. help='Data loader workers number')
  79. train.add_argument('--fine_tuning', action='store_true',
  80. help='Enable fine-tuning')
  81. train.add_argument('--input_mels_dir', type=str, default=None,
  82. help='Directory with mels for fine-tuning')
  83. train.add_argument('--benchmark_epochs_num', type=int, default=5)
  84. train.add_argument('--no_amp_grouped_conv', action='store_true',
  85. help='Disable AMP on certain convs for better perf')
  86. opt = parser.add_argument_group('optimization setup')
  87. opt.add_argument('--optimizer', type=str, default='adamw',
  88. help='Optimization algorithm')
  89. opt.add_argument('--lr_decay', type=float, default=0.9998,
  90. help='Learning rate decay')
  91. opt.add_argument('-lr', '--learning_rate', type=float, required=True,
  92. help='Learning rate')
  93. opt.add_argument('--fine_tune_lr_factor', type=float, default=1.,
  94. help='Learning rate multiplier for fine-tuning')
  95. opt.add_argument('--adam_betas', type=float, nargs=2, default=(0.8, 0.99),
  96. help='Adam Beta coefficients')
  97. opt.add_argument('--grad_clip_thresh', default=1000.0, type=float,
  98. help='Clip threshold for gradients')
  99. opt.add_argument('-bs', '--batch_size', type=int, required=True,
  100. help=('Batch size per training iter. '
  101. 'May be split into grad accumulation steps.'))
  102. opt.add_argument('--warmup_steps', type=int, default=1000,
  103. help='Number of steps for lr warmup')
  104. data = parser.add_argument_group('dataset parameters')
  105. data.add_argument('-d', '--dataset_path', default='data/LJSpeech-1.1',
  106. help='Path to dataset', type=str)
  107. data.add_argument('--training_files', type=str, required=True, nargs='+',
  108. help='Paths to training filelists.')
  109. data.add_argument('--validation_files', type=str, required=True, nargs='+',
  110. help='Paths to validation filelists.')
  111. audio = parser.add_argument_group('audio parameters')
  112. audio.add_argument('--max_wav_value', default=32768.0, type=float,
  113. help='Maximum audiowave value')
  114. audio.add_argument('--sampling_rate', default=22050, type=int,
  115. help='Sampling rate')
  116. audio.add_argument('--filter_length', default=1024, type=int,
  117. help='Filter length')
  118. audio.add_argument('--num_mels', default=80, type=int,
  119. help='number of Mel bands')
  120. audio.add_argument('--hop_length', default=256, type=int,
  121. help='Hop (stride) length')
  122. audio.add_argument('--win_length', default=1024, type=int,
  123. help='Window length')
  124. audio.add_argument('--mel_fmin', default=0.0, type=float,
  125. help='Minimum mel frequency')
  126. audio.add_argument('--mel_fmax', default=8000.0, type=float,
  127. help='Maximum mel frequency')
  128. audio.add_argument('--mel_fmax_loss', default=None, type=float,
  129. help='Maximum mel frequency used for computing loss')
  130. audio.add_argument('--segment_size', default=8192, type=int,
  131. help='Training segment size')
  132. dist = parser.add_argument_group('distributed setup')
  133. dist.add_argument(
  134. '--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0),
  135. help='Rank of the process for multiproc. Do not set manually.')
  136. dist.add_argument(
  137. '--world_size', type=int, default=os.getenv('WORLD_SIZE', 1),
  138. help='Number of processes for multiproc. Do not set manually.')
  139. dist.add_argument('--affinity', type=str,
  140. default='socket_unique_interleaved',
  141. choices=['socket', 'single', 'single_unique',
  142. 'socket_unique_interleaved',
  143. 'socket_unique_continuous',
  144. 'disabled'],
  145. help='type of CPU affinity')
  146. return parser
def validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics):
    """Run one full pass over the validation set, accumulating losses.

    Computes the same losses as the training loop — L1 mel-spectrogram loss,
    MPD/MSD discriminator losses, feature-matching and adversarial generator
    losses — and folds them into ``val_metrics`` (reduced across workers).
    The generator is put in eval mode for the pass and restored to train
    mode before returning; discriminators are left in their current mode.
    """
    gen.eval()
    val_metrics.start_val()
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            x, y, _, y_mel = batch
            x = x.cuda(non_blocking=True)
            # add channel dim: discriminators take (B, 1, T) waveforms
            y = y.cuda(non_blocking=True).unsqueeze(1)
            y_mel = y_mel.cuda(non_blocking=True)

            with autocast(enabled=args.amp):
                y_g_hat = gen(x)
            # Spectrogram is computed in fp32 unless --autocast_spectrogram
            # is set; note y_g_hat is upcast with .float() first.
            with autocast(enabled=args.amp and args.autocast_spectrogram):
                y_g_hat_mel = mel_spec(y_g_hat.float().squeeze(1),
                                       fmax=args.mel_fmax_loss)
            with autocast(enabled=args.amp):
                # val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item() * 45
                # NOTE: Scale by 45.0 to match train loss magnitude
                loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

                # MPD — discriminator loss on real vs. detached fake audio
                y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
                loss_disc_f = discriminator_loss(y_df_hat_r, y_df_hat_g)
                # MSD
                y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
                loss_disc_s = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

                # Generator-side losses: re-run discriminators without
                # detaching to obtain feature maps for feature matching.
                y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
                y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
                loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
                loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
                loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
                loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
                loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            val_metrics['loss_discrim'] = reduce_tensor(
                loss_disc_s + loss_disc_f, args.world_size)
            val_metrics['loss_gen'] = reduce_tensor(loss_gen_all,
                                                    args.world_size)
            val_metrics['loss_mel'] = reduce_tensor(loss_mel, args.world_size)
            val_metrics['frames'] = x.size(0) * x.size(1) * args.world_size
            val_metrics.accumulate(scopes=['val'])
    val_metrics.finish_val()
    gen.train()
  187. def main():
  188. parser = argparse.ArgumentParser(description='PyTorch HiFi-GAN Training',
  189. allow_abbrev=False)
  190. parser = models.parse_model_args('HiFi-GAN', parse_args(parser))
  191. args, unk_args = parser.parse_known_args()
  192. if len(unk_args) > 0:
  193. raise ValueError(f'Invalid options {unk_args}')
  194. if args.affinity != 'disabled':
  195. nproc_per_node = torch.cuda.device_count()
  196. print(nproc_per_node)
  197. affinity = gpu_affinity.set_affinity(
  198. args.local_rank,
  199. nproc_per_node,
  200. args.affinity
  201. )
  202. print(f'{args.local_rank}: thread affinity: {affinity}')
  203. # seeds, distributed init, logging, cuDNN
  204. distributed_run = args.world_size > 1
  205. torch.manual_seed(args.seed + args.local_rank)
  206. np.random.seed(args.seed + args.local_rank)
  207. if distributed_run:
  208. init_distributed(args, args.world_size, args.local_rank)
  209. metrics = Metrics(scopes=['train', 'train_avg'],
  210. benchmark_epochs=args.benchmark_epochs_num,
  211. cuda=args.cuda)
  212. val_metrics = Metrics(scopes=['val'], cuda=args.cuda)
  213. init_logger(args.output, args.log_file, args.ema_decay)
  214. logger.parameters(vars(args), tb_subset='train')
  215. l2_promote()
  216. torch.backends.cudnn.benchmark = not args.disable_cudnn_benchmark
  217. train_setup = models.get_model_train_setup('HiFi-GAN', args)
  218. gen_config = models.get_model_config('HiFi-GAN', args)
  219. gen = models.get_model('HiFi-GAN', gen_config, 'cuda')
  220. mpd = MultiPeriodDiscriminator(periods=args.mpd_periods,
  221. concat_fwd=args.concat_fwd).cuda()
  222. assert args.amp or not args.no_amp_grouped_conv, \
  223. "--no-amp-grouped-conv is applicable only when AMP is enabled"
  224. msd = MultiScaleDiscriminator(concat_fwd=args.concat_fwd,
  225. no_amp_grouped_conv=args.no_amp_grouped_conv)
  226. msd = msd.cuda()
  227. mel_spec = partial(mel_spectrogram, n_fft=args.filter_length,
  228. num_mels=args.num_mels,
  229. sampling_rate=args.sampling_rate,
  230. hop_size=args.hop_length, win_size=args.win_length,
  231. fmin=args.mel_fmin)
  232. kw = {'lr': args.learning_rate, 'betas': args.adam_betas}
  233. proto = {'adam': FusedAdam, 'lamb': FusedLAMB, 'adamw': AdamW
  234. }[args.optimizer]
  235. optim_g = proto(gen.parameters(), **kw)
  236. optim_d = proto(itertools.chain(msd.parameters(), mpd.parameters()), **kw)
  237. scaler_g = amp.GradScaler(enabled=args.amp)
  238. scaler_d = amp.GradScaler(enabled=args.amp)
  239. # setup EMA
  240. if args.ema_decay > 0:
  241. # burried import, requires apex
  242. from common.ema_utils import (apply_multi_tensor_ema,
  243. init_multi_tensor_ema)
  244. gen_ema = models.get_model('HiFi-GAN', gen_config, 'cuda').cuda()
  245. mpd_ema = MultiPeriodDiscriminator(
  246. periods=args.mpd_periods,
  247. concat_fwd=args.concat_fwd).cuda()
  248. msd_ema = MultiScaleDiscriminator(
  249. concat_fwd=args.concat_fwd,
  250. no_amp_grouped_conv=args.no_amp_grouped_conv).cuda()
  251. else:
  252. gen_ema, mpd_ema, msd_ema = None, None, None
  253. # setup DDP
  254. if distributed_run:
  255. kw = {'device_ids': [args.local_rank],
  256. 'output_device': args.local_rank}
  257. gen = DDP(gen, **kw)
  258. msd = DDP(msd, **kw)
  259. # DDP needs nonempty model
  260. mpd = DDP(mpd, **kw) if len(args.mpd_periods) else mpd
  261. # resume from last / load a checkpoint
  262. train_state = {}
  263. checkpointer = Checkpointer(args.output, args.keep_milestones)
  264. checkpointer.maybe_load(
  265. gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d, train_state, args,
  266. gen_ema=None, mpd_ema=None, msd_ema=None)
  267. iters_all = train_state.get('iters_all', 0)
  268. last_epoch = train_state['epoch'] + 1 if 'epoch' in train_state else -1
  269. sched_g = ExponentialLR(optim_g, gamma=args.lr_decay, last_epoch=last_epoch)
  270. sched_d = ExponentialLR(optim_d, gamma=args.lr_decay, last_epoch=last_epoch)
  271. if args.fine_tuning:
  272. print_once('Doing fine-tuning')
  273. train_loader = get_data_loader(args, distributed_run, train=True)
  274. val_loader = get_data_loader(args, distributed_run, train=False,
  275. val_kwargs=dict(repeat=5, split=True))
  276. val_samples_loader = get_data_loader(args, False, train=False,
  277. val_kwargs=dict(split=False),
  278. batch_size=1)
  279. if args.ema_decay > 0.0:
  280. gen_ema_params = init_multi_tensor_ema(gen, gen_ema)
  281. mpd_ema_params = init_multi_tensor_ema(mpd, mpd_ema)
  282. msd_ema_params = init_multi_tensor_ema(msd, msd_ema)
  283. epochs_done = 0
  284. for epoch in range(max(1, last_epoch), args.epochs + 1):
  285. metrics.start_epoch(epoch)
  286. if distributed_run:
  287. train_loader.sampler.set_epoch(epoch)
  288. gen.train()
  289. mpd.train()
  290. msd.train()
  291. iter_ = 0
  292. iters_num = len(train_loader) // args.grad_accumulation
  293. for step, batch in enumerate(train_loader):
  294. if step // args.grad_accumulation >= iters_num:
  295. break # only full effective batches
  296. is_first_accum_step = step % args.grad_accumulation == 0
  297. is_last_accum_step = (step + 1) % args.grad_accumulation == 0
  298. assert (args.grad_accumulation > 1
  299. or (is_first_accum_step and is_last_accum_step))
  300. if is_first_accum_step:
  301. iter_ += 1
  302. iters_all += 1
  303. metrics.start_iter(iter_)
  304. accum_batches = []
  305. optim_d.zero_grad(set_to_none=True)
  306. optim_g.zero_grad(set_to_none=True)
  307. x, y, _, y_mel = batch
  308. x = x.cuda(non_blocking=True)
  309. y = y.cuda(non_blocking=True).unsqueeze(1)
  310. y_mel = y_mel.cuda(non_blocking=True)
  311. accum_batches.append((x, y, y_mel))
  312. with torch.set_grad_enabled(is_last_accum_step), \
  313. autocast(enabled=args.amp):
  314. y_g_hat = gen(x)
  315. unfreeze(mpd)
  316. unfreeze(msd)
  317. with autocast(enabled=args.amp):
  318. # MPD
  319. y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
  320. loss_disc_f = discriminator_loss(y_df_hat_r, y_df_hat_g)
  321. # MSD
  322. y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
  323. loss_disc_s = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
  324. loss_disc_all = loss_disc_s + loss_disc_f
  325. metrics['loss_discrim'] = reduce_tensor(loss_disc_all, args.world_size)
  326. metrics['frames'] = x.size(0) * x.size(1) * args.world_size
  327. metrics.accumulate()
  328. loss_disc_all /= args.grad_accumulation
  329. scaler_d.scale(loss_disc_all).backward()
  330. if not is_last_accum_step:
  331. continue
  332. scaler_d.step(optim_d)
  333. scaler_d.update()
  334. # generator
  335. freeze(mpd)
  336. freeze(msd)
  337. for _i, (x, y, y_mel) in enumerate(reversed(accum_batches)):
  338. if _i != 0: # first `y_g_hat` can be reused
  339. with autocast(enabled=args.amp):
  340. y_g_hat = gen(x)
  341. with autocast(enabled=args.amp and args.autocast_spectrogram):
  342. y_g_hat_mel = mel_spec(y_g_hat.float().squeeze(1),
  343. fmax=args.mel_fmax_loss)
  344. # L1 mel-spectrogram Loss
  345. with autocast(enabled=args.amp):
  346. loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
  347. y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
  348. y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
  349. loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
  350. loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
  351. loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
  352. loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
  353. loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
  354. metrics['loss_gen'] = reduce_tensor(loss_gen_all, args.world_size)
  355. metrics['loss_mel'] = reduce_tensor(loss_mel, args.world_size)
  356. metrics.accumulate()
  357. loss_gen_all /= args.grad_accumulation
  358. scaler_g.scale(loss_gen_all).backward()
  359. scaler_g.step(optim_g)
  360. scaler_g.update()
  361. metrics['lrate_gen'] = optim_g.param_groups[0]['lr']
  362. metrics['lrate_discrim'] = optim_d.param_groups[0]['lr']
  363. metrics.accumulate()
  364. if args.ema_decay > 0.0:
  365. apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
  366. apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
  367. apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
  368. metrics.finish_iter() # done accumulating
  369. if iters_all % args.step_logs_interval == 0:
  370. logger.log((epoch, iter_, iters_num), metrics, scope='train',
  371. tb_iter=iters_all, flush_log=True)
  372. assert is_last_accum_step
  373. metrics.finish_epoch()
  374. logger.log((epoch,), metrics, scope='train_avg', flush_log=True)
  375. if epoch % args.validation_interval == 0:
  376. validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
  377. logger.log((epoch,), val_metrics, scope='val', tb_iter=iters_all,
  378. flush_log=True)
  379. # validation samples
  380. if epoch % args.samples_interval == 0 and args.local_rank == 0:
  381. gen.eval()
  382. with torch.no_grad():
  383. for i, batch in enumerate(islice(val_samples_loader, 5)):
  384. x, y, _, _ = batch
  385. x = x.cuda(non_blocking=True)
  386. y = y.cuda(non_blocking=True).unsqueeze(1)
  387. with autocast(enabled=args.amp):
  388. y_g_hat = gen(x)
  389. with autocast(enabled=args.amp and args.autocast_spectrogram):
  390. # args.fmax instead of args.max_for_inference
  391. y_hat_spec = mel_spec(y_g_hat.float().squeeze(1),
  392. fmax=args.mel_fmax)
  393. logger.log_samples_tb(iters_all, i, y_g_hat, y_hat_spec,
  394. args.sampling_rate)
  395. if epoch == args.samples_interval: # ground truth
  396. logger.log_samples_tb(0, i, y, x, args.sampling_rate)
  397. gen.train()
  398. train_state.update({'epoch': epoch, 'iters_all': iters_all})
  399. # save before making sched.step() for proper loading of LR
  400. checkpointer.maybe_save(
  401. gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d, epoch,
  402. train_state, args, gen_config, train_setup,
  403. gen_ema=gen_ema, mpd_ema=mpd_ema, msd_ema=msd_ema)
  404. logger.flush()
  405. sched_g.step()
  406. sched_d.step()
  407. epochs_done += 1
  408. if (args.epochs_this_job is not None
  409. and epochs_done == args.epochs_this_job):
  410. break
  411. # finished training
  412. if epochs_done > 0:
  413. logger.log((), metrics, scope='train_benchmark', flush_log=True)
  414. if epoch % args.validation_interval != 0: # val metrics are not up-to-date
  415. validate(args, gen, mel_spec, mpd, msd, val_loader, val_metrics)
  416. logger.log((), val_metrics, scope='val', flush_log=True)
  417. else:
  418. print_once(f'Finished without training after epoch {args.epochs}.')
# Script entry point; guarded so importing this module has no side effects.
if __name__ == '__main__':
    main()