|
|
@@ -38,6 +38,7 @@ from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
|
|
|
process_evaluation_epoch)
|
|
|
from common.optimizers import AdamW, lr_policy, Novograd
|
|
|
from common.tb_dllogger import flush_log, init_log, log
|
|
|
+from common.utils import BenchmarkStats
|
|
|
from jasper import config
|
|
|
from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper
|
|
|
|
|
|
@@ -111,16 +112,17 @@ def parse_args():
|
|
|
help='Paths of the training dataset manifest file')
|
|
|
io.add_argument('--val_manifests', type=str, required=True, nargs='+',
|
|
|
help='Paths of the evaluation datasets manifest files')
|
|
|
- io.add_argument('--max_duration', type=float,
|
|
|
- help='Discard samples longer than max_duration')
|
|
|
- io.add_argument('--pad_to_max_duration', action='store_true', default=False,
|
|
|
- help='Pad training sequences to max_duration')
|
|
|
io.add_argument('--dataset_dir', required=True, type=str,
|
|
|
help='Root dir of dataset')
|
|
|
io.add_argument('--output_dir', type=str, required=True,
|
|
|
help='Directory for logs and checkpoints')
|
|
|
io.add_argument('--log_file', type=str, default=None,
|
|
|
help='Path to save the training logfile.')
|
|
|
+ io.add_argument('--benchmark_epochs_num', type=int, default=1,
|
|
|
+ help='Number of epochs accounted in final average throughput.')
|
|
|
+ io.add_argument('--override_config', type=str, action='append',
|
|
|
+ help='Overrides a value from a config .yaml.'
|
|
|
+ ' Syntax: `--override_config nested.config.key=val`.')
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
@@ -202,7 +204,7 @@ def main():
|
|
|
init_log(args)
|
|
|
|
|
|
cfg = config.load(args.model_config)
|
|
|
- config.apply_duration_flags(cfg, args.max_duration, args.pad_to_max_duration)
|
|
|
+ config.apply_config_overrides(cfg, args)
|
|
|
|
|
|
symbols = helpers.add_ctc_blank(cfg['labels'])
|
|
|
|
|
|
@@ -384,11 +386,14 @@ def main():
|
|
|
loss.backward()
|
|
|
model.zero_grad()
|
|
|
|
|
|
+ bmark_stats = BenchmarkStats()
|
|
|
+
|
|
|
for epoch in range(start_epoch + 1, args.epochs + 1):
|
|
|
if multi_gpu and not use_dali:
|
|
|
train_loader.sampler.set_epoch(epoch)
|
|
|
|
|
|
epoch_utts = 0
|
|
|
+ epoch_loss = 0
|
|
|
accumulated_batches = 0
|
|
|
epoch_start_time = time.time()
|
|
|
|
|
|
@@ -434,6 +439,7 @@ def main():
|
|
|
accumulated_batches += 1
|
|
|
|
|
|
if accumulated_batches % args.grad_accumulation_steps == 0:
|
|
|
+ epoch_loss += step_loss
|
|
|
optimizer.step()
|
|
|
apply_ema(model, ema_model, args.ema)
|
|
|
|
|
|
@@ -476,8 +482,11 @@ def main():
|
|
|
break
|
|
|
|
|
|
epoch_time = time.time() - epoch_start_time
|
|
|
+ epoch_loss /= steps_per_epoch
|
|
|
log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
|
|
|
- 'took': epoch_time})
|
|
|
+ 'took': epoch_time,
|
|
|
+ 'loss': epoch_loss})
|
|
|
+ bmark_stats.update(epoch_utts, epoch_time, epoch_loss)
|
|
|
|
|
|
if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
|
|
|
checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
|
|
|
@@ -491,7 +500,7 @@ def main():
|
|
|
profiler.stop()
|
|
|
torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
|
|
|
|
|
|
- log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})
|
|
|
+ log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))
|
|
|
|
|
|
if epoch == args.epochs:
|
|
|
evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,
|