@@ -27,8 +27,7 @@ import torch
 import numpy as np
 import torch.cuda.profiler as profiler
 import torch.distributed as dist
-from apex import amp
-from apex.parallel import DistributedDataParallel
+from contextlib import suppress as empty_context

 from common import helpers
 from common.dali.data_loader import DaliDataLoader
@@ -59,7 +58,7 @@ def parse_args():
     training.add_argument('--cudnn_benchmark', action='store_true', default=True,
                           help='Enable cudnn benchmark')
     training.add_argument('--amp', '--fp16', action='store_true', default=False,
-                          help='Use mixed precision training')
+                          help='Use PyTorch native mixed precision training')
     training.add_argument('--seed', default=42, type=int, help='Random seed')
     training.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0),
                           type=int, help='GPU id used for distributed training')
@@ -158,15 +157,16 @@ def evaluate(epoch, step, val_loader, val_feat_proc, labels, model,
                 # with DALI, the data is already on GPU
                 feat, feat_lens, txt, txt_lens = batch
                 if val_feat_proc is not None:
-                    feat, feat_lens = val_feat_proc(feat, feat_lens, use_amp)
+                    feat, feat_lens = val_feat_proc(feat, feat_lens)
             else:
                 batch = [t.cuda(non_blocking=True) for t in batch]
                 audio, audio_lens, txt, txt_lens = batch
-                feat, feat_lens = val_feat_proc(audio, audio_lens, use_amp)
+                feat, feat_lens = val_feat_proc(audio, audio_lens)

-            log_probs, enc_lens = model.forward(feat, feat_lens)
-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
-            pred = greedy_decoder(log_probs)
+            with torch.cuda.amp.autocast(enabled=use_amp):
+                log_probs, enc_lens = model(feat, feat_lens)
+                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+                pred = greedy_decoder(log_probs)

             agg['losses'] += helpers.gather_losses([loss])
             agg['preds'] += helpers.gather_predictions([pred], labels)
@@ -323,37 +323,34 @@ def main():
     else:
         raise ValueError(f'Invalid optimizer "{args.optimizer}"')

+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+
     adjust_lr = lambda step, epoch, optimizer: lr_policy(
         step, epoch, args.lr, optimizer, steps_per_epoch=steps_per_epoch,
         warmup_epochs=args.warmup_epochs, hold_epochs=args.hold_epochs,
         num_epochs=args.epochs, policy=args.lr_policy, min_lr=args.min_lr,
         exp_gamma=args.lr_exp_gamma)

-    if args.amp:
-        model, optimizer = amp.initialize(
-            min_loss_scale=1.0, models=model, optimizers=optimizer,
-            opt_level='O1', max_loss_scale=512.0)
-
     if args.ema > 0:
         ema_model = copy.deepcopy(model)
     else:
         ema_model = None

     if multi_gpu:
-        model = DistributedDataParallel(model)
-
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[args.local_rank], output_device=args.local_rank)
     if args.pyprof:
         pyprof.init(enable_function_stack=True)

     # load checkpoint
     meta = {'best_wer': 10**6, 'start_epoch': 0}
     checkpointer = Checkpointer(args.output_dir, 'Jasper',
-                                args.keep_milestones, args.amp)
+                                args.keep_milestones)
     if args.resume:
         args.ckpt = checkpointer.last_checkpoint() or args.ckpt

     if args.ckpt is not None:
-        checkpointer.load(args.ckpt, model, ema_model, optimizer, meta)
+        checkpointer.load(args.ckpt, model, ema_model, optimizer, scaler, meta)

     start_epoch = meta['start_epoch']
     best_wer = meta['best_wer']
@@ -380,11 +377,13 @@ def main():
             txt = torch.randint(high=len(symbols)-1, size=(batch_size, 100),
                                 device='cuda')
             txt_lens = torch.ones(batch_size, device='cuda').fill_(100)
-            log_probs, enc_lens = model(feat, feat_lens)
-            del feat
-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+            with torch.cuda.amp.autocast(enabled=args.amp):
+                log_probs, enc_lens = model(feat, feat_lens)
+                del feat
+                loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
             loss.backward()
             model.zero_grad()
+        torch.cuda.empty_cache()

     bmark_stats = BenchmarkStats()

@@ -396,12 +395,11 @@ def main():
         epoch_loss = 0
         accumulated_batches = 0
         epoch_start_time = time.time()
+        epoch_eval_time = 0

         for batch in train_loader:

             if accumulated_batches == 0:
-                adjust_lr(step, epoch, optimizer)
-                optimizer.zero_grad()
                 step_loss = 0
                 step_utts = 0
                 step_start_time = time.time()
@@ -410,37 +408,49 @@ def main():
                 # with DALI, the data is already on GPU
                 feat, feat_lens, txt, txt_lens = batch
                 if train_feat_proc is not None:
-                    feat, feat_lens = train_feat_proc(feat, feat_lens, args.amp)
+                    feat, feat_lens = train_feat_proc(feat, feat_lens)
             else:
                 batch = [t.cuda(non_blocking=True) for t in batch]
                 audio, audio_lens, txt, txt_lens = batch
-                feat, feat_lens = train_feat_proc(audio, audio_lens, args.amp)
+                feat, feat_lens = train_feat_proc(audio, audio_lens)

-            log_probs, enc_lens = model(feat, feat_lens)
+            # Use no_sync to skip redundant gradient all-reduces while accumulating
+            if (multi_gpu and accumulated_batches + 1 < args.grad_accumulation_steps):
+                ctx = model.no_sync()
+            else:
+                ctx = empty_context()

-            loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
-            loss /= args.grad_accumulation_steps
+            with ctx:
+                with torch.cuda.amp.autocast(enabled=args.amp):
+                    log_probs, enc_lens = model(feat, feat_lens)
+
+                    loss = ctc_loss(log_probs, txt, enc_lens, txt_lens)
+                    loss /= args.grad_accumulation_steps

-            if torch.isnan(loss).any():
-                print_once(f'WARNING: loss is NaN; skipping update')
-            else:
                 if multi_gpu:
-                    step_loss += reduce_tensor(loss.data, world_size).item()
+                    reduced_loss = reduce_tensor(loss.data, world_size)
                 else:
-                    step_loss += loss.item()
+                    reduced_loss = loss

-                if args.amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
+                if torch.isnan(reduced_loss).any():
+                    print_once(f'WARNING: loss is NaN; skipping update')
+                    continue
                 else:
-                    loss.backward()
-                step_utts += batch[0].size(0) * world_size
-                epoch_utts += batch[0].size(0) * world_size
-                accumulated_batches += 1
+                    step_loss += reduced_loss.item()
+                    step_utts += batch[0].size(0) * world_size
+                    epoch_utts += batch[0].size(0) * world_size
+                    accumulated_batches += 1
+
+                scaler.scale(loss).backward()

             if accumulated_batches % args.grad_accumulation_steps == 0:
                 epoch_loss += step_loss
-                optimizer.step()
+                scaler.step(optimizer)
+                scaler.update()
+
+                adjust_lr(step, epoch, optimizer)
+                optimizer.zero_grad()
+
                 apply_ema(model, ema_model, args.ema)

             if step % args.log_frequency == 0:
@@ -463,14 +473,16 @@ def main():
                 step_start_time = time.time()

             if step % args.eval_frequency == 0:
+                tik = time.time()
                 wer = evaluate(epoch, step, val_loader, val_feat_proc,
                                symbols, model, ema_model, ctc_loss,
                                greedy_decoder, args.amp, use_dali)

                 if wer < best_wer and epoch >= args.save_best_from:
-                    checkpointer.save(model, ema_model, optimizer, epoch,
-                                      step, best_wer, is_best=True)
+                    checkpointer.save(model, ema_model, optimizer, scaler,
+                                      epoch, step, best_wer, is_best=True)
                     best_wer = wer
+                epoch_eval_time += time.time() - tik

             step += 1
             accumulated_batches = 0
@@ -489,7 +501,8 @@ def main():
         bmark_stats.update(epoch_utts, epoch_time, epoch_loss)

         if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
-            checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
+            checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
+                              best_wer)

         if 0 < args.epochs_this_job <= epoch - start_epoch:
             print_once(f'Finished after {args.epochs_this_job} epochs.')
@@ -506,7 +519,8 @@ def main():
     evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
              ema_model, ctc_loss, greedy_decoder, args.amp, use_dali)

-    checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
+    checkpointer.save(model, ema_model, optimizer, scaler, epoch, step,
+                      best_wer)
     flush_log()
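
For reference, the update pattern this patch adopts (an autocast forward pass, a GradScaler-scaled backward pass, and DDP no_sync() while gradients accumulate) is sketched below in isolation. This is a minimal illustration with placeholder model, optimizer, and data, not code from the Jasper repository:

import torch
from contextlib import suppress as empty_context

# Placeholder setup; the real script builds Jasper, its optimizer, and DALI loaders
model = torch.nn.Linear(10, 2).cuda()            # stands in for Jasper
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=True)
ACCUM_STEPS = 2
multi_gpu = False  # True once model is wrapped in DistributedDataParallel

for i in range(8):                               # stands in for the data loader
    inputs = torch.randn(4, 10, device='cuda')
    targets = torch.randint(0, 2, (4,), device='cuda')

    last_micro_batch = (i + 1) % ACCUM_STEPS == 0
    # no_sync() skips DDP's gradient all-reduce on non-final micro-batches
    ctx = model.no_sync() if multi_gpu and not last_micro_batch else empty_context()
    with ctx:
        with torch.cuda.amp.autocast(enabled=True):
            loss = criterion(model(inputs), targets) / ACCUM_STEPS
        scaler.scale(loss).backward()   # scale to avoid fp16 gradient underflow
    if last_micro_batch:
        scaler.step(optimizer)          # unscales, checks for inf/NaN, then steps
        scaler.update()                 # adjusts the loss scale for the next window
        optimizer.zero_grad()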
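The GradScaler also keeps state (the current loss scale and its growth counters), which is why `scaler` is now threaded through Checkpointer.save and Checkpointer.load above. A minimal sketch of what saving and restoring that state involves, using hypothetical helper names rather than the repository's Checkpointer class:

import torch

def save_checkpoint(path, model, optimizer, scaler):
    # scaler.state_dict() holds the loss scale, so a resumed run does not
    # restart the scale-growth schedule from its default value
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict()}, path)

def load_checkpoint(path, model, optimizer, scaler):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    scaler.load_state_dict(state['scaler'])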