|
|
@@ -26,6 +26,7 @@ import dllogger
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import torch.distributed as distrib
|
|
|
+import torch.nn.functional as F
|
|
|
from apex import amp
|
|
|
from apex.parallel import DistributedDataParallel
|
|
|
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
|
|
|
@@ -57,8 +58,9 @@ def get_parser():
|
|
|
help='Relative path to evaluation dataset manifest files')
|
|
|
parser.add_argument('--ckpt', default=None, type=str,
|
|
|
help='Path to model checkpoint')
|
|
|
- parser.add_argument("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
|
|
|
- parser.add_argument("--pad_to_max_duration", action='store_true', help='pad to maximum duration of sequences')
|
|
|
+ parser.add_argument('--pad_leading', type=int, default=16,
|
|
|
+ help='Pads every batch with leading zeros '
|
|
|
+ 'to counteract conv shifts of the field of view')
|
|
|
parser.add_argument('--amp', '--fp16', action='store_true',
|
|
|
help='Use FP16 precision')
|
|
|
parser.add_argument('--cudnn_benchmark', action='store_true',
|
|
|
@@ -210,7 +212,6 @@ def main():
|
|
|
print("DALI supported only with input .json files; disabling")
|
|
|
use_dali = False
|
|
|
|
|
|
- assert not args.pad_to_max_duration
|
|
|
assert not (args.transcribe_wav and args.transcribe_filelist)
|
|
|
|
|
|
if args.transcribe_wav:
|
|
|
@@ -226,6 +227,7 @@ def main():
|
|
|
drop_last=(True if measure_perf else False))
|
|
|
|
|
|
_, features_kw = config.input(cfg, 'val')
|
|
|
+ assert not features_kw['pad_to_max_duration']
|
|
|
feat_proc = FilterbankFeatures(**features_kw)
|
|
|
|
|
|
elif use_dali:
|
|
|
@@ -327,6 +329,9 @@ def main():
|
|
|
if args.amp:
|
|
|
feats = feats.half()
|
|
|
|
|
|
+ feats = F.pad(feats, (args.pad_leading, 0))
|
|
|
+ feat_lens += args.pad_leading
|
|
|
+
|
|
|
if model.encoder.use_conv_masks:
|
|
|
log_probs, log_prob_lens = model(feats, feat_lens)
|
|
|
else:
|