Explorar el Código

[Jasper/PyT] Clean up inference flags

Mikolaj Blaz hace 4 años
padre
commit
706ef498c9

+ 2 - 0
PyTorch/SpeechRecognition/Jasper/README.md

@@ -439,6 +439,7 @@ LOG_FILE:            path to the DLLogger .json logfile. (default: '')
 CUDNN_BENCHMARK:     enable cudnn benchmark mode for using more optimized kernels. (default: false)
 MAX_DURATION:        filter out recordings shorter than MAX_DURATION seconds. (default: "")
 PAD_TO_MAX_DURATION: pad all sequences with zeros to maximum length. (default: false)
+PAD_LEADING:         pad every batch with leading zeros to counteract conv shifts of the field of view. (default: 16)
 NUM_GPUS:            number of GPUs to use. Note that with > 1 GPUs WER results might be inaccurate due to the batching policy. (default: 1)
 NUM_STEPS:           number of batches to evaluate, loop the dataset if necessary. (default: 0)
 NUM_WARMUP_STEPS:    number of initial steps before measuring performance. (default: 0)
@@ -464,6 +465,7 @@ BATCH_SIZE_SEQ:      batch sizes to measure on. (default: "1 2 4 8 16")
 MAX_DURATION_SEQ:    input durations (in seconds) to measure on (default: "2 7 16.7")
 CUDNN_BENCHMARK:     (default: true)
 PAD_TO_MAX_DURATION: (default: true)
+PAD_LEADING:         (default: 0)
 NUM_WARMUP_STEPS:    (default: 10)
 NUM_STEPS:           (default: 500)
 DALI_DEVICE:         (default: cpu)

+ 8 - 3
PyTorch/SpeechRecognition/Jasper/inference.py

@@ -26,6 +26,7 @@ import dllogger
 import torch
 import numpy as np
 import torch.distributed as distrib
+import torch.nn.functional as F
 from apex import amp
 from apex.parallel import DistributedDataParallel
 from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
@@ -57,8 +58,9 @@ def get_parser():
                         help='Relative path to evaluation dataset manifest files')
     parser.add_argument('--ckpt', default=None, type=str,
                         help='Path to model checkpoint')
-    parser.add_argument("--max_duration", default=None, type=float, help='maximum duration of sequences. if None uses attribute from model configuration file')
-    parser.add_argument("--pad_to_max_duration", action='store_true', help='pad to maximum duration of sequences')
+    parser.add_argument('--pad_leading', type=int, default=16,
+                        help='Pads every batch with leading zeros '
+                             'to counteract conv shifts of the field of view')
     parser.add_argument('--amp', '--fp16', action='store_true',
                         help='Use FP16 precision')
     parser.add_argument('--cudnn_benchmark', action='store_true',
@@ -210,7 +212,6 @@ def main():
             print("DALI supported only with input .json files; disabling")
             use_dali = False
 
-        assert not args.pad_to_max_duration
         assert not (args.transcribe_wav and args.transcribe_filelist)
 
         if args.transcribe_wav:
@@ -226,6 +227,7 @@ def main():
                                       drop_last=(True if measure_perf else False))
 
         _, features_kw = config.input(cfg, 'val')
+        assert not features_kw['pad_to_max_duration']
         feat_proc = FilterbankFeatures(**features_kw)
 
     elif use_dali:
@@ -327,6 +329,9 @@ def main():
             if args.amp:
                 feats = feats.half()
 
+            feats = F.pad(feats, (args.pad_leading, 0))
+            feat_lens += args.pad_leading
+
             if model.encoder.use_conv_masks:
                 log_probs, log_prob_lens = model(feats, feat_lens)
             else:

+ 2 - 0
PyTorch/SpeechRecognition/Jasper/scripts/inference.sh

@@ -23,6 +23,7 @@
 : ${CUDNN_BENCHMARK:=false}
 : ${MAX_DURATION:=""}
 : ${PAD_TO_MAX_DURATION:=false}
+: ${PAD_LEADING:=16}
 : ${NUM_GPUS:=1}
 : ${NUM_STEPS:=0}
 : ${NUM_WARMUP_STEPS:=0}
@@ -46,6 +47,7 @@ ARGS+=" --seed=$SEED"
 ARGS+=" --dali_device=$DALI_DEVICE"
 ARGS+=" --steps $NUM_STEPS"
 ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
+ARGS+=" --pad_leading $PAD_LEADING"
 
 [ "$AMP" = true ] &&                 ARGS+=" --amp"
 [ "$EMA" = true ] &&                 ARGS+=" --ema"

+ 1 - 0
PyTorch/SpeechRecognition/Jasper/scripts/inference_benchmark.sh

@@ -19,6 +19,7 @@ set -a
 : ${OUTPUT_DIR:=${3:-"/results"}}
 : ${CUDNN_BENCHMARK:=true}
 : ${PAD_TO_MAX_DURATION:=true}
+: ${PAD_LEADING:=0}
 : ${NUM_WARMUP_STEPS:=10}
 : ${NUM_STEPS:=500}
 

+ 3 - 25
PyTorch/SpeechRecognition/Jasper/triton/jasper_module.py

@@ -74,14 +74,7 @@ def get_dataloader(model_args_list):
         return None
 
     cfg = config.load(args.model_config)
-
-    if args.max_duration is not None:
-        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
-        cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
-
-    if args.pad_to_max_duration:
-        assert cfg['input_train']['audio_dataset']['max_duration'] > 0
-        cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
+    config.apply_config_overrides(cfg, args)
 
     symbols = add_ctc_blank(cfg['labels'])
 
@@ -108,15 +101,7 @@ def init_feature_extractor(args):
     from common.features import FilterbankFeatures
 
     cfg = config.load(args.model_config)
-
-    if args.max_duration is not None:
-        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
-        cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
-
-    if args.pad_to_max_duration:
-        assert cfg['input_train']['audio_dataset']['max_duration'] > 0
-        cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
-
+    config.apply_config_overrides(cfg, args)
     _, features_kw = config.input(cfg, 'val')
 
     feature_proc = FilterbankFeatures(**features_kw)
@@ -131,14 +116,7 @@ def init_acoustic_model(args):
     from jasper import config
 
     cfg = config.load(args.model_config)
-
-    if args.max_duration is not None:
-        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
-        cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
-
-    if args.pad_to_max_duration:
-        assert cfg['input_train']['audio_dataset']['max_duration'] > 0
-        cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
+    config.apply_config_overrides(cfg, args)
 
     if cfg['jasper']['encoder']['use_conv_masks'] == True:
         print("[Jasper module]: Warning: setting 'use_conv_masks' \