train.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import copy
  16. import os
  17. import random
  18. import time
  19. import torch
  20. import amp_C
  21. import numpy as np
  22. import torch.distributed as dist
  23. from apex.optimizers import FusedLAMB, FusedNovoGrad
  24. from contextlib import suppress as empty_context
  25. from common import helpers
  26. from common.dali.data_loader import DaliDataLoader
  27. from common.dataset import AudioDataset, get_data_loader
  28. from common.features import BaseFeatures, FilterbankFeatures
  29. from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
  30. process_evaluation_epoch)
  31. from common.optimizers import AdamW, lr_policy, Novograd
  32. from common.tb_dllogger import flush_log, init_log, log
  33. from common.utils import BenchmarkStats
  34. from quartznet import config
  35. from quartznet.model import CTCLossNM, GreedyCTCDecoder, QuartzNet
  36. def parse_args():
  37. parser = argparse.ArgumentParser(description='QuartzNet')
  38. training = parser.add_argument_group('training setup')
  39. training.add_argument('--epochs', default=400, type=int,
  40. help='Number of epochs for the entire training; influences the lr schedule')
  41. training.add_argument("--warmup_epochs", default=0, type=int,
  42. help='Initial epochs of increasing learning rate')
  43. training.add_argument("--hold_epochs", default=0, type=int,
  44. help='Constant max learning rate epochs after warmup')
  45. training.add_argument('--epochs_this_job', default=0, type=int,
  46. help=('Run for a number of epochs with no effect on the lr schedule.'
  47. 'Useful for re-starting the training.'))
  48. training.add_argument('--cudnn_benchmark', action='store_true', default=True,
  49. help='Enable cudnn benchmark')
  50. training.add_argument('--amp', '--fp16', action='store_true', default=False,
  51. help='Use pytorch native mixed precision training')
  52. training.add_argument('--seed', default=None, type=int, help='Random seed')
  53. training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int,
  54. help='GPU id used for distributed training')
  55. training.add_argument('--pre_allocate_range', default=None, type=int, nargs=2,
  56. help='Warmup with batches of length [min, max] before training')
  57. optim = parser.add_argument_group('optimization setup')
  58. optim.add_argument('--gpu_batch_size', default=32, type=int,
  59. help='Batch size for a single forward/backward pass. '
  60. 'The Effective batch size is gpu_batch_size * grad_accumulation.')
  61. optim.add_argument('--lr', default=1e-3, type=float,
  62. help='Peak learning rate')
  63. optim.add_argument("--min_lr", default=1e-5, type=float,
  64. help='minimum learning rate')
  65. optim.add_argument("--lr_policy", default='exponential', type=str,
  66. choices=['exponential', 'legacy'], help='lr scheduler')
  67. optim.add_argument("--lr_exp_gamma", default=0.99, type=float,
  68. help='gamma factor for exponential lr scheduler')
  69. optim.add_argument('--weight_decay', default=1e-3, type=float,
  70. help='Weight decay for the optimizer')
  71. optim.add_argument('--grad_accumulation', '--update-freq', default=1, type=int,
  72. help='Number of accumulation steps')
  73. optim.add_argument('--optimizer', default='novograd', type=str,
  74. choices=['novograd', 'adamw', 'lamb98', 'fused_novograd'],
  75. help='Optimization algorithm')
  76. optim.add_argument('--ema', type=float, default=0.0,
  77. help='Discount factor for exp averaging of model weights')
  78. optim.add_argument('--multi_tensor_ema', action='store_true',
  79. help='Use multi_tensor_apply for EMA')
  80. io = parser.add_argument_group('feature and checkpointing setup')
  81. io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
  82. default='gpu', help='Use DALI pipeline for fast data processing')
  83. io.add_argument('--resume', action='store_true',
  84. help='Try to resume from last saved checkpoint.')
  85. io.add_argument('--ckpt', default=None, type=str,
  86. help='Path to a checkpoint for resuming training')
  87. io.add_argument('--save_frequency', default=10, type=int,
  88. help='Checkpoint saving frequency in epochs')
  89. io.add_argument('--keep_milestones', default=[100, 200, 300], type=int, nargs='+',
  90. help='Milestone checkpoints to keep from removing')
  91. io.add_argument('--save_best_from', default=380, type=int,
  92. help='Epoch on which to begin tracking best checkpoint (dev WER)')
  93. io.add_argument('--eval_frequency', default=200, type=int,
  94. help='Number of steps between evaluations on dev set')
  95. io.add_argument('--log_frequency', default=25, type=int,
  96. help='Number of steps between printing training stats')
  97. io.add_argument('--prediction_frequency', default=100, type=int,
  98. help='Number of steps between printing sample decodings')
  99. io.add_argument('--model_config', type=str, required=True,
  100. help='Path of the model configuration file')
  101. io.add_argument('--train_manifests', type=str, required=True, nargs='+',
  102. help='Paths of the training dataset manifest file')
  103. io.add_argument('--val_manifests', type=str, required=True, nargs='+',
  104. help='Paths of the evaluation datasets manifest files')
  105. io.add_argument('--dataset_dir', required=True, type=str,
  106. help='Root dir of dataset')
  107. io.add_argument('--output_dir', type=str, required=True,
  108. help='Directory for logs and checkpoints')
  109. io.add_argument('--log_file', type=str, default=None,
  110. help='Path to save the training logfile.')
  111. io.add_argument('--benchmark_epochs_num', type=int, default=1,
  112. help='Number of epochs accounted in final average throughput.')
  113. io.add_argument('--override_config', type=str, action='append',
  114. help='Overrides arbitrary config value.'
  115. ' Syntax: `--override_config nested.config.key=val`.')
  116. return parser.parse_args()
  117. def reduce_tensor(tensor, num_gpus):
  118. rt = tensor.clone()
  119. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  120. return rt.true_divide(num_gpus)
  121. def init_multi_tensor_ema(model, ema_model):
  122. model_weights = list(model.state_dict().values())
  123. ema_model_weights = list(ema_model.state_dict().values())
  124. ema_overflow_buf = torch.cuda.IntTensor([0])
  125. return model_weights, ema_model_weights, ema_overflow_buf
  126. def apply_multi_tensor_ema(decay, model_weights, ema_model_weights, overflow_buf):
  127. amp_C.multi_tensor_axpby(
  128. 65536, overflow_buf,
  129. [ema_model_weights, model_weights, ema_model_weights],
  130. decay, 1-decay, -1)
  131. def apply_ema(model, ema_model, decay):
  132. if not decay:
  133. return
  134. sd = getattr(model, 'module', model).state_dict()
  135. for k, v in ema_model.state_dict().items():
  136. v.copy_(decay * v + (1 - decay) * sd[k])
@torch.no_grad()
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):
    """Evaluate the model (and its EMA copy, if any) on the dev set.

    Runs greedy decoding over ``val_loader``, aggregates losses,
    predictions and reference transcripts, and logs loss/WER per subset
    ('dev' for the raw model, 'dev_ema' for the EMA model).

    Returns:
        WER of the last subset evaluated (the EMA model's WER when
        ``ema_model`` is not None, otherwise the raw model's).
    """
    # NOTE(review): the loop variable deliberately shadows the `model`
    # parameter so that one code path serves both the raw and EMA models.
    for model, subset in [(model, 'dev'), (ema_model, 'dev_ema')]:
        if model is None:
            continue

        model.eval()
        torch.cuda.synchronize()
        start_time = time.time()
        agg = {'losses': [], 'preds': [], 'txts': []}

        for batch in val_loader:
            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens)

            with torch.cuda.amp.autocast(enabled=use_amp):
                log_probs, enc_lens = model(feat, feat_lens)
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                pred = greedy_decoder(log_probs)

            agg['losses'] += helpers.gather_losses([loss])
            agg['preds'] += helpers.gather_predictions([pred], labels)
            agg['txts'] += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(agg)
        torch.cuda.synchronize()
        # epoch is None for the final post-training evaluation
        log(() if epoch is None else (epoch,),
            step, subset, {'loss': loss, 'wer': 100.0 * wer,
                           'took': time.time() - start_time})
        model.train()
    return wer
def main():
    """Entry point: set up data, model and optimizer, then run training.

    Handles single- and multi-GPU (NCCL) training, optional DALI data
    loading, native AMP, gradient accumulation, EMA weight averaging,
    periodic evaluation and checkpointing.
    """
    args = parse_args()

    assert(torch.cuda.is_available())
    # prediction logging piggybacks on the regular stats logging cadence
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # offset the seed per rank so workers don't draw identical randomness
    if args.seed is not None:
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    assert args.grad_accumulation >= 1
    batch_size = args.gpu_batch_size

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                      dataset_path=args.dataset_dir,
                                      config_data=train_dataset_kw,
                                      config_features=train_features_kw,
                                      json_names=args.train_manifests,
                                      batch_size=batch_size,
                                      grad_accumulation_steps=args.grad_accumulation,
                                      pipeline_type="train",
                                      device_type=args.dali_device,
                                      symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        # plain PyTorch data pipeline (no DALI)
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)
        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)
        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation

    # set up the model
    model = QuartzNet(encoder_kw=config.encoder(cfg),
                      decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    elif args.optimizer == 'lamb98':
        optimizer = FusedLAMB(model.parameters(), betas=(0.9, 0.98), eps=1e-9,
                              **kw)
    elif args.optimizer == 'fused_novograd':
        optimizer = FusedNovoGrad(model.parameters(), betas=(0.95, 0),
                                  bias_correction=False, reg_inside_moment=True,
                                  grad_averaging=False, **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'QuartzNet',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1  # ensures `epoch` is defined even if the loop body never runs
    step = start_epoch * steps_per_epoch + 1

    # training loop
    model.train()

    if args.ema > 0.0:
        mt_ema_params = init_multi_tensor_ema(model, ema_model)
        # ema_model_weight_list, model_weight_list, overflow_buf_for_ema = ema_

    # pre-allocate
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                loss.backward()
                model.zero_grad()
        torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            # reshuffle shards per epoch for the DistributedSampler
            train_loader.sampler.set_epoch(epoch)

        torch.cuda.synchronize()
        epoch_start_time = time.time()
        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0

        for batch in train_loader:
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)
                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    # scale for gradient accumulation so the summed gradient
                    # matches a single large-batch step
                    loss /= args.grad_accumulation

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    print_once(f'WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                    scaler.scale(loss).backward()

            # optimizer step only once per grad_accumulation micro-batches
            if accumulated_batches % args.grad_accumulation == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()

                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()

                if args.ema > 0.0:
                    apply_multi_tensor_ema(args.ema, *mt_ema_params)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f' Decoded: {pred_utt[:90]}')
                        print_once(f' Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})
                    step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model, ema_model, optimizer, scaler,
                                          epoch, step, best_wer, is_best=True)
                        best_wer = wer

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        torch.cuda.synchronize()
        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time,
                                          'loss': epoch_loss})
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # final evaluation (epoch=None marks it as post-training)
    evaluate(None, step, val_loader, val_feat_proc, symbols, model,
             ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

    if epoch == args.epochs:
        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()