train.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import copy
  16. import os
  17. import random
  18. import time
  19. import torch
  20. import numpy as np
  21. import torch.distributed as dist
  22. from contextlib import suppress as empty_context
  23. from common import helpers
  24. from common.dali.data_loader import DaliDataLoader
  25. from common.dataset import AudioDataset, get_data_loader
  26. from common.features import BaseFeatures, FilterbankFeatures
  27. from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
  28. process_evaluation_epoch)
  29. from common.optimizers import AdamW, lr_policy, Novograd
  30. from common.tb_dllogger import flush_log, init_log, log
  31. from common.utils import BenchmarkStats
  32. from jasper import config
  33. from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper
  34. def parse_args():
  35. parser = argparse.ArgumentParser(description='Jasper')
  36. training = parser.add_argument_group('training setup')
  37. training.add_argument('--epochs', default=400, type=int,
  38. help='Number of epochs for the entire training; influences the lr schedule')
  39. training.add_argument("--warmup_epochs", default=0, type=int,
  40. help='Initial epochs of increasing learning rate')
  41. training.add_argument("--hold_epochs", default=0, type=int,
  42. help='Constant max learning rate epochs after warmup')
  43. training.add_argument('--epochs_this_job', default=0, type=int,
  44. help=('Run for a number of epochs with no effect on the lr schedule.'
  45. 'Useful for re-starting the training.'))
  46. training.add_argument('--cudnn_benchmark', action='store_true', default=True,
  47. help='Enable cudnn benchmark')
  48. training.add_argument('--amp', '--fp16', action='store_true', default=False,
  49. help='Use pytorch native mixed precision training')
  50. training.add_argument('--seed', default=42, type=int, help='Random seed')
  51. training.add_argument('--local_rank', '--local-rank', default=os.getenv('LOCAL_RANK', 0),
  52. type=int, help='GPU id used for distributed training')
  53. training.add_argument('--pre_allocate_range', default=None, type=int, nargs=2,
  54. help='Warmup with batches of length [min, max] before training')
  55. optim = parser.add_argument_group('optimization setup')
  56. optim.add_argument('--batch_size', default=32, type=int,
  57. help='Global batch size')
  58. optim.add_argument('--lr', default=1e-3, type=float,
  59. help='Peak learning rate')
  60. optim.add_argument("--min_lr", default=1e-5, type=float,
  61. help='minimum learning rate')
  62. optim.add_argument("--lr_policy", default='exponential', type=str,
  63. choices=['exponential', 'legacy'], help='lr scheduler')
  64. optim.add_argument("--lr_exp_gamma", default=0.99, type=float,
  65. help='gamma factor for exponential lr scheduler')
  66. optim.add_argument('--weight_decay', default=1e-3, type=float,
  67. help='Weight decay for the optimizer')
  68. optim.add_argument('--grad_accumulation_steps', default=1, type=int,
  69. help='Number of accumulation steps')
  70. optim.add_argument('--optimizer', default='novograd', type=str,
  71. choices=['novograd', 'adamw'], help='Optimization algorithm')
  72. optim.add_argument('--ema', type=float, default=0.0,
  73. help='Discount factor for exp averaging of model weights')
  74. io = parser.add_argument_group('feature and checkpointing setup')
  75. io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
  76. default='gpu', help='Use DALI pipeline for fast data processing')
  77. io.add_argument('--resume', action='store_true',
  78. help='Try to resume from last saved checkpoint.')
  79. io.add_argument('--ckpt', default=None, type=str,
  80. help='Path to a checkpoint for resuming training')
  81. io.add_argument('--save_frequency', default=10, type=int,
  82. help='Checkpoint saving frequency in epochs')
  83. io.add_argument('--keep_milestones', default=[100, 200, 300], type=int, nargs='+',
  84. help='Milestone checkpoints to keep from removing')
  85. io.add_argument('--save_best_from', default=380, type=int,
  86. help='Epoch on which to begin tracking best checkpoint (dev WER)')
  87. io.add_argument('--eval_frequency', default=200, type=int,
  88. help='Number of steps between evaluations on dev set')
  89. io.add_argument('--log_frequency', default=25, type=int,
  90. help='Number of steps between printing training stats')
  91. io.add_argument('--prediction_frequency', default=100, type=int,
  92. help='Number of steps between printing sample decodings')
  93. io.add_argument('--model_config', type=str, required=True,
  94. help='Path of the model configuration file')
  95. io.add_argument('--train_manifests', type=str, required=True, nargs='+',
  96. help='Paths of the training dataset manifest file')
  97. io.add_argument('--val_manifests', type=str, required=True, nargs='+',
  98. help='Paths of the evaluation datasets manifest files')
  99. io.add_argument('--dataset_dir', required=True, type=str,
  100. help='Root dir of dataset')
  101. io.add_argument('--output_dir', type=str, required=True,
  102. help='Directory for logs and checkpoints')
  103. io.add_argument('--log_file', type=str, default=None,
  104. help='Path to save the training logfile.')
  105. io.add_argument('--benchmark_epochs_num', type=int, default=1,
  106. help='Number of epochs accounted in final average throughput.')
  107. io.add_argument('--override_config', type=str, action='append',
  108. help='Overrides a value from a config .yaml.'
  109. ' Syntax: `--override_config nested.config.key=val`.')
  110. return parser.parse_args()
  111. def reduce_tensor(tensor, num_gpus):
  112. rt = tensor.clone()
  113. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  114. return rt.true_divide(num_gpus)
  115. def apply_ema(model, ema_model, decay):
  116. if not decay:
  117. return
  118. sd = getattr(model, 'module', model).state_dict()
  119. for k, v in ema_model.state_dict().items():
  120. v.copy_(decay * v + (1 - decay) * sd[k])
@torch.no_grad()
def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
             ema_model, ctc_loss, greedy_decoder, use_amp, use_dali=False):
    """Evaluate on the dev set with greedy CTC decoding.

    Runs `model` (logged as subset 'dev') and, when present, `ema_model`
    (logged as 'dev_ema'); aggregates loss/predictions/transcripts over
    the whole loader and logs loss and WER per subset.

    Returns:
        WER of the last subset evaluated — i.e. the EMA model's WER when
        an EMA model exists, otherwise the raw model's.
    """
    # NOTE: the loop variable intentionally shadows the `model` parameter.
    for model, subset in [(model, 'dev'), (ema_model, 'dev_ema')]:
        if model is None:
            continue

        model.eval()
        torch.cuda.synchronize()
        start_time = time.time()
        agg = {'losses': [], 'preds': [], 'txts': []}

        for batch in val_loader:
            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if val_feat_proc is not None:
                    feat, feat_lens = val_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = val_feat_proc(audio, audio_lens)

            # Mixed-precision forward pass; loss and decoding share it.
            with torch.cuda.amp.autocast(enabled=use_amp):
                log_probs, enc_lens = model(feat, feat_lens)
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                pred = greedy_decoder(log_probs)

            agg['losses'] += helpers.gather_losses([loss])
            agg['preds'] += helpers.gather_predictions([pred], labels)
            agg['txts'] += helpers.gather_transcripts([txt], [txt_lens], labels)

        wer, loss = process_evaluation_epoch(agg)
        torch.cuda.synchronize()
        # `epoch is None` marks the final post-training evaluation.
        log(() if epoch is None else (epoch,),
            step, subset, {'loss': loss, 'wer': 100.0 * wer,
                           'took': time.time() - start_time})
        model.train()
    return wer
def main():
    """Entry point: set up data, model and optimizer, then run training.

    Drives the full Jasper training loop: distributed init, dataset/DALI
    setup, optimizer and LR schedule, optional EMA model, checkpoint
    resume, gradient-accumulated training with AMP, periodic logging,
    evaluation and checkpointing, and a final evaluation/save.
    """
    args = parse_args()

    assert(torch.cuda.is_available())
    # Prediction printing piggybacks on the logging step (see training loop),
    # so the frequencies must nest evenly.
    assert args.prediction_frequency % args.log_frequency == 0

    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # set up distributed training
    multi_gpu = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')
        world_size = dist.get_world_size()
        print_once(f'Distributed training with {world_size} GPUs\n')
    else:
        world_size = 1

    # Per-rank seed offset so workers don't draw identical random streams.
    torch.manual_seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    random.seed(args.seed + args.local_rank)

    init_log(args)

    cfg = config.load(args.model_config)
    config.apply_config_overrides(cfg, args)

    symbols = helpers.add_ctc_blank(cfg['labels'])

    # `batch_size` below is the per-step micro-batch; the global batch is
    # reassembled via gradient accumulation.
    assert args.grad_accumulation_steps >= 1
    assert args.batch_size % args.grad_accumulation_steps == 0
    batch_size = args.batch_size // args.grad_accumulation_steps

    print_once('Setting up datasets...')
    train_dataset_kw, train_features_kw = config.input(cfg, 'train')
    val_dataset_kw, val_features_kw = config.input(cfg, 'val')

    use_dali = args.dali_device in ('cpu', 'gpu')
    if use_dali:
        assert train_dataset_kw['ignore_offline_speed_perturbation'], \
            "DALI doesn't support offline speed perturbation"

        # pad_to_max_duration is not supported by DALI - have simple padders
        if train_features_kw['pad_to_max_duration']:
            train_feat_proc = BaseFeatures(
                pad_align=train_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=train_features_kw['max_duration'],
                sample_rate=train_features_kw['sample_rate'],
                window_size=train_features_kw['window_size'],
                window_stride=train_features_kw['window_stride'])
            train_features_kw['pad_to_max_duration'] = False
        else:
            train_feat_proc = None

        if val_features_kw['pad_to_max_duration']:
            val_feat_proc = BaseFeatures(
                pad_align=val_features_kw['pad_align'],
                pad_to_max_duration=True,
                max_duration=val_features_kw['max_duration'],
                sample_rate=val_features_kw['sample_rate'],
                window_size=val_features_kw['window_size'],
                window_stride=val_features_kw['window_stride'])
            val_features_kw['pad_to_max_duration'] = False
        else:
            val_feat_proc = None

        train_loader = DaliDataLoader(gpu_id=args.local_rank,
                                      dataset_path=args.dataset_dir,
                                      config_data=train_dataset_kw,
                                      config_features=train_features_kw,
                                      json_names=args.train_manifests,
                                      batch_size=batch_size,
                                      grad_accumulation_steps=args.grad_accumulation_steps,
                                      pipeline_type="train",
                                      device_type=args.dali_device,
                                      symbols=symbols)

        val_loader = DaliDataLoader(gpu_id=args.local_rank,
                                    dataset_path=args.dataset_dir,
                                    config_data=val_dataset_kw,
                                    config_features=val_features_kw,
                                    json_names=args.val_manifests,
                                    batch_size=batch_size,
                                    pipeline_type="val",
                                    device_type=args.dali_device,
                                    symbols=symbols)
    else:
        train_dataset_kw, train_features_kw = config.input(cfg, 'train')
        train_dataset = AudioDataset(args.dataset_dir,
                                     args.train_manifests,
                                     symbols,
                                     **train_dataset_kw)
        train_loader = get_data_loader(train_dataset,
                                       batch_size,
                                       multi_gpu=multi_gpu,
                                       shuffle=True,
                                       num_workers=4)

        train_feat_proc = FilterbankFeatures(**train_features_kw)

        val_dataset_kw, val_features_kw = config.input(cfg, 'val')
        val_dataset = AudioDataset(args.dataset_dir,
                                   args.val_manifests,
                                   symbols,
                                   **val_dataset_kw)
        val_loader = get_data_loader(val_dataset,
                                     batch_size,
                                     multi_gpu=multi_gpu,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)

        val_feat_proc = FilterbankFeatures(**val_features_kw)

        dur = train_dataset.duration / 3600
        dur_f = train_dataset.duration_filtered / 3600
        nsampl = len(train_dataset)
        print_once(f'Training samples: {nsampl} ({dur:.1f}h, '
                   f'filtered {dur_f:.1f}h)')

    if train_feat_proc is not None:
        train_feat_proc.cuda()
    if val_feat_proc is not None:
        val_feat_proc.cuda()

    steps_per_epoch = len(train_loader) // args.grad_accumulation_steps

    # set up the model
    model = Jasper(encoder_kw=config.encoder(cfg),
                   decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
    model.cuda()
    ctc_loss = CTCLossNM(n_classes=len(symbols))
    greedy_decoder = GreedyCTCDecoder()

    print_once(f'Model size: {num_weights(model) / 10**6:.1f}M params\n')

    # optimization
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if args.optimizer == "novograd":
        optimizer = Novograd(model.parameters(), **kw)
    elif args.optimizer == "adamw":
        optimizer = AdamW(model.parameters(), **kw)
    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    adjust_lr = lambda step, epoch, optimizer: lr_policy(
        step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
        warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
        num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
        exp_gamma=args.lr_exp_gamma)

    if args.ema > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    # DDP wrapping happens *after* the EMA copy so `ema_model` stays unwrapped.
    if multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # load checkpoint
    meta = {'best_wer': 10**6, 'start_epoch': 0}
    checkpointer = Checkpointer(args.output_dir, 'Jasper',
                                args.keep_milestones)
    if args.resume:
        args.ckpt = checkpointer.last_checkpoint() or args.ckpt

    if args.ckpt is not None:
        # `meta` is updated in place with the checkpoint's epoch/best-WER.
        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

    start_epoch = meta['start_epoch']
    best_wer = meta['best_wer']
    epoch = 1
    step = start_epoch * steps_per_epoch + 1

    # training loop
    model.train()

    # pre-allocate
    # Runs dummy forward/backward passes over the requested length range so
    # CUDA caching-allocator blocks are grabbed up front.
    if args.pre_allocate_range is not None:
        n_feats = train_features_kw['n_filt']
        pad_align = train_features_kw['pad_align']
        a, b = args.pre_allocate_range
        for n_frames in range(a, b + pad_align, pad_align):
            print_once(f'Pre-allocation ({batch_size}x{n_feats}x{n_frames})...')

            feat = torch.randn(batch_size, n_feats, n_frames, device='cuda')
            feat_lens = torch.ones(batch_size, device='cuda').fill_(n_frames)
            txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                device='cuda')
            txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
            with torch.cuda.amp.autocast(enabled=args.amp):
                log_probs, enc_lens = model(feat, feat_lens)
                del feat
                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                loss.backward()
                model.zero_grad()
        torch.cuda.empty_cache()

    bmark_stats = BenchmarkStats()

    for epoch in range(start_epoch + 1, args.epochs + 1):
        if multi_gpu and not use_dali:
            # Reshuffle the distributed sampler each epoch.
            train_loader.sampler.set_epoch(epoch)

        torch.cuda.synchronize()
        epoch_start_time = time.time()
        epoch_utts = 0
        epoch_loss = 0
        accumulated_batches = 0

        for batch in train_loader:
            if accumulated_batches == 0:
                step_loss = 0
                step_utts = 0
                step_start_time = time.time()

            if use_dali:
                # with DALI, the data is already on GPU
                feat, feat_lens, txt, txt_lens = batch
                if train_feat_proc is not None:
                    feat, feat_lens = train_feat_proc(feat, feat_lens)
            else:
                batch = [t.cuda(non_blocking=True) for t in batch]
                audio, audio_lens, txt, txt_lens = batch
                feat, feat_lens = train_feat_proc(audio, audio_lens)

            # Use context manager to prevent redundant accumulation of gradients
            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation_steps):
                ctx = model.no_sync()
            else:
                ctx = empty_context()

            with ctx:
                with torch.cuda.amp.autocast(enabled=args.amp):
                    log_probs, enc_lens = model(feat, feat_lens)
                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
                    loss /= args.grad_accumulation_steps

                if multi_gpu:
                    reduced_loss = reduce_tensor(loss.data, world_size)
                else:
                    reduced_loss = loss

                if torch.isnan(reduced_loss).any():
                    # NOTE(review): skipping here leaves `accumulated_batches`
                    # unchanged, so the optimizer step stretches over extra
                    # micro-batches after a NaN — appears intentional.
                    print_once(f'WARNING: loss is NaN; skipping update')
                    continue
                else:
                    step_loss += reduced_loss.item()
                    step_utts += batch[0].size(0) * world_size
                    epoch_utts += batch[0].size(0) * world_size
                    accumulated_batches += 1

                scaler.scale(loss).backward()

            if accumulated_batches % args.grad_accumulation_steps == 0:
                epoch_loss += step_loss
                scaler.step(optimizer)
                scaler.update()
                adjust_lr(step, epoch, optimizer)
                optimizer.zero_grad()
                apply_ema(model, ema_model, args.ema)

                if step % args.log_frequency == 0:
                    preds = greedy_decoder(log_probs)
                    wer, pred_utt, ref = greedy_wer(preds, txt, txt_lens, symbols)

                    if step % args.prediction_frequency == 0:
                        print_once(f' Decoded: {pred_utt[:90]}')
                        print_once(f' Reference: {ref[:90]}')

                    step_time = time.time() - step_start_time
                    log((epoch, step % steps_per_epoch or steps_per_epoch, steps_per_epoch),
                        step, 'train',
                        {'loss': step_loss,
                         'wer': 100.0 * wer,
                         'throughput': step_utts / step_time,
                         'took': step_time,
                         'lrate': optimizer.param_groups[0]['lr']})
                    step_start_time = time.time()

                if step % args.eval_frequency == 0:
                    wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                   symbols, model, ema_model, ctc_loss,
                                   greedy_decoder, args.amp, use_dali)

                    if wer < best_wer and epoch >= args.save_best_from:
                        checkpointer.save(model, ema_model, optimizer, scaler,
                                          epoch, step, best_wer, is_best=True)
                        best_wer = wer

                step += 1
                accumulated_batches = 0
                # end of step

            # DALI iterator need to be exhausted;
            # if not using DALI, simulate drop_last=True with grad accumulation
            if not use_dali and step > steps_per_epoch * epoch:
                break

        torch.cuda.synchronize()
        epoch_time = time.time() - epoch_start_time
        epoch_loss /= steps_per_epoch
        log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
                                          'took': epoch_time,
                                          'loss': epoch_loss})
        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

        if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                              best_wer)

        if 0 < args.epochs_this_job <= epoch - start_epoch:
            print_once(f'Finished after {args.epochs_this_job} epochs.')
            break
        # end of epoch

    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))

    # Final evaluation (epoch=None → logged without an epoch tag) and save.
    evaluate(None, step, val_loader, val_feat_proc, symbols, model,
             ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

    if epoch == args.epochs:
        checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
                          best_wer)
    flush_log()
# Script entry point.
if __name__ == "__main__":
    main()