Jelajahi Sumber

[Jasper/PyT] Minor update in metrics and CLI params

Mikolaj Blaz 4 tahun lalu
induk
melakukan
0d4dd6b523

+ 20 - 0
PyTorch/SpeechRecognition/Jasper/common/utils.py

@@ -0,0 +1,20 @@
+import numpy as np
+
+
class BenchmarkStats:
    """Accumulates per-epoch counters used for benchmark reporting."""

    def __init__(self):
        # Parallel per-epoch histories: utterance counts, wall-clock
        # durations (seconds), and average losses.
        self.utts = []
        self.times = []
        self.losses = []

    def update(self, utts, times, losses):
        """Record the counters of one finished epoch."""
        for history, value in ((self.utts, utts),
                               (self.times, times),
                               (self.losses, losses)):
            history.append(value)

    def get(self, n_epochs):
        """Summarize the most recent ``n_epochs`` epochs.

        Returns a dict with aggregate throughput (total utterances over
        total time), the number of epochs averaged, and the mean loss.
        """
        recent_utts = self.utts[-n_epochs:]
        recent_times = self.times[-n_epochs:]
        return {
            'throughput': sum(recent_utts) / sum(recent_times),
            'benchmark_epochs_num': n_epochs,
            'loss': np.mean(self.losses[-n_epochs:]),
        }

+ 4 - 13
PyTorch/SpeechRecognition/Jasper/inference.py

@@ -57,10 +57,6 @@ def get_parser():
                         help='Relative path to evaluation dataset manifest files')
     parser.add_argument('--ckpt', default=None, type=str,
                         help='Path to model checkpoint')
-    parser.add_argument('--max_duration', default=None, type=float,
-                        help='Filter out longer inputs (in seconds)')
-    parser.add_argument('--pad_to_max_duration', action='store_true',
-                        help='Pads every batch to max_duration')
     parser.add_argument('--amp', '--fp16', action='store_true',
                         help='Use FP16 precision')
     parser.add_argument('--cudnn_benchmark', action='store_true',
@@ -92,6 +88,9 @@ def get_parser():
                     help='Evaluate with a TorchScripted model')
     io.add_argument('--torchscript_export', action='store_true',
                     help='Export the model with torch.jit to the output_dir')
+    io.add_argument('--override_config', type=str, action='append',
+                    help='Overrides a value from a config .yaml.'
+                         ' Syntax: `--override_config nested.config.key=val`.')
     return parser
 
 
@@ -193,15 +192,7 @@ def main():
         print_once(f'Inference with {distrib.get_world_size()} GPUs')
 
     cfg = config.load(args.model_config)
-
-    if args.max_duration is not None:
-        cfg['input_val']['audio_dataset']['max_duration'] = args.max_duration
-        cfg['input_val']['filterbank_features']['max_duration'] = args.max_duration
-
-    if args.pad_to_max_duration:
-        assert cfg['input_val']['audio_dataset']['max_duration'] > 0
-        cfg['input_val']['audio_dataset']['pad_to_max_duration'] = True
-        cfg['input_val']['filterbank_features']['pad_to_max_duration'] = True
+    config.apply_config_overrides(cfg, args)
 
     symbols = helpers.add_ctc_blank(cfg['labels'])
 

+ 24 - 9
PyTorch/SpeechRecognition/Jasper/jasper/config.py

@@ -1,5 +1,10 @@
 import copy
 import inspect
+import typing
+from ast import literal_eval
+from contextlib import suppress
+from numbers import Number
+
 import yaml
 
 from .model import JasperDecoderForCTC, JasperBlock, JasperEncoder
@@ -99,12 +104,22 @@ def decoder(conf, n_classes):
     return validate_and_fill(JasperDecoderForCTC, decoder_kw)
 
 
-def apply_duration_flags(cfg, max_duration, pad_to_max_duration):
-    if max_duration is not None:
-        cfg['input_train']['audio_dataset']['max_duration'] = max_duration
-        cfg['input_train']['filterbank_features']['max_duration'] = max_duration
-
-    if pad_to_max_duration:
-        assert cfg['input_train']['audio_dataset']['max_duration'] > 0
-        cfg['input_train']['audio_dataset']['pad_to_max_duration'] = True
-        cfg['input_train']['filterbank_features']['pad_to_max_duration'] = True
def apply_config_overrides(conf, args):
    """Apply ``--override_config`` CLI overrides to a loaded config dict.

    Each entry of ``args.override_config`` has the form
    ``nested.config.key=val``. ``val`` is parsed with
    ``ast.literal_eval`` when it is a valid Python literal (numbers,
    booleans, quoted strings, ...); otherwise it is kept as the raw
    string (e.g. ``val=relu`` stays ``'relu'``).

    Args:
        conf: nested dict config (mutated in place).
        args: argparse namespace with an ``override_config`` attribute
            (``None`` or a list of ``key=val`` strings).
    """
    if args.override_config is None:
        return
    for override_key_val in args.override_config:
        # maxsplit=1 so values containing '=' (e.g. key=a=b) survive intact
        key, val = override_key_val.split('=', 1)
        # literal_eval raises SyntaxError (not just ValueError) for many
        # non-literal strings such as 'a b' or '1.2.3'; swallow those too
        # and fall back to the raw string.
        with suppress(TypeError, ValueError, SyntaxError):
            val = literal_eval(val)
        apply_nested_config_override(conf, key, val)


def apply_nested_config_override(conf, key_str, val):
    """Set ``conf[k1][k2]...[kn] = val`` for ``key_str == 'k1.k2...kn'``.

    An existing value may only be replaced by one of the same type, or by
    another ``Number`` when both are numeric (so an int may override a
    float and vice versa). New keys may be added freely.

    Raises:
        ValueError: if the override's type conflicts with the existing
            value's type.
        KeyError: if an intermediate key in ``key_str`` does not exist.
    """
    fields = key_str.split('.')
    for f in fields[:-1]:
        conf = conf[f]
    f = fields[-1]
    # Explicit raise instead of assert: validation must survive python -O.
    if not (f not in conf
            or type(val) is type(conf[f])
            or (isinstance(val, Number) and isinstance(conf[f], Number))):
        raise ValueError(
            f'Type mismatch for config override {key_str}: '
            f'{type(val).__name__} cannot replace {type(conf[f]).__name__}')
    conf[f] = val

+ 4 - 2
PyTorch/SpeechRecognition/Jasper/scripts/inference.sh

@@ -55,7 +55,9 @@ ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
 [ -n "$PREDICTION_FILE" ] &&         ARGS+=" --save_prediction $PREDICTION_FILE"
 [ -n "$LOGITS_FILE" ] &&             ARGS+=" --logits_save_to $LOGITS_FILE"
 [ "$CPU" == "true" ] &&              ARGS+=" --cpu"
-[ -n "$MAX_DURATION" ] &&            ARGS+=" --max_duration $MAX_DURATION"
-[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --pad_to_max_duration"
+[ -n "$MAX_DURATION" ] &&            ARGS+=" --override_config input_val.audio_dataset.max_duration=$MAX_DURATION" \
+                                     ARGS+=" --override_config input_val.filterbank_features.max_duration=$MAX_DURATION"
+[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --override_config input_val.audio_dataset.pad_to_max_duration=True" \
+                                     ARGS+=" --override_config input_val.filterbank_features.pad_to_max_duration=True"
 
 python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS inference.py $ARGS

+ 4 - 1
PyTorch/SpeechRecognition/Jasper/scripts/train.sh

@@ -78,7 +78,10 @@ ARGS+=" --dali_device=$DALI_DEVICE"
 [ "$AMP" = true ] &&                 ARGS+=" --amp"
 [ "$RESUME" = true ] &&              ARGS+=" --resume"
 [ "$CUDNN_BENCHMARK" = true ] &&     ARGS+=" --cudnn_benchmark"
-[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --pad_to_max_duration"
+[ -n "$MAX_DURATION" ] &&            ARGS+=" --override_config input_train.audio_dataset.max_duration=$MAX_DURATION" \
+                                     ARGS+=" --override_config input_train.filterbank_features.max_duration=$MAX_DURATION"
+[ "$PAD_TO_MAX_DURATION" = true ] && ARGS+=" --override_config input_train.audio_dataset.pad_to_max_duration=True" \
+                                     ARGS+=" --override_config input_train.filterbank_features.pad_to_max_duration=True"
 [ -n "$CHECKPOINT" ] &&              ARGS+=" --ckpt=$CHECKPOINT"
 [ -n "$LOG_FILE" ] &&                ARGS+=" --log_file $LOG_FILE"
 [ -n "$PRE_ALLOCATE" ] &&            ARGS+=" --pre_allocate_range $PRE_ALLOCATE"

+ 16 - 7
PyTorch/SpeechRecognition/Jasper/train.py

@@ -38,6 +38,7 @@ from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
                             process_evaluation_epoch)
 from common.optimizers import AdamW, lr_policy, Novograd
 from common.tb_dllogger import flush_log, init_log, log
+from common.utils import BenchmarkStats
 from jasper import config
 from jasper.model import CTCLossNM, GreedyCTCDecoder, Jasper
 
@@ -111,16 +112,17 @@ def parse_args():
                     help='Paths of the training dataset manifest file')
     io.add_argument('--val_manifests', type=str, required=True, nargs='+',
                     help='Paths of the evaluation datasets manifest files')
-    io.add_argument('--max_duration', type=float,
-                    help='Discard samples longer than max_duration')
-    io.add_argument('--pad_to_max_duration', action='store_true', default=False,
-                    help='Pad training sequences to max_duration')
     io.add_argument('--dataset_dir', required=True, type=str,
                     help='Root dir of dataset')
     io.add_argument('--output_dir', type=str, required=True,
                     help='Directory for logs and checkpoints')
     io.add_argument('--log_file', type=str, default=None,
                     help='Path to save the training logfile.')
+    io.add_argument('--benchmark_epochs_num', type=int, default=1,
+                    help='Number of epochs accounted in final average throughput.')
+    io.add_argument('--override_config', type=str, action='append',
+                    help='Overrides a value from a config .yaml.'
+                         ' Syntax: `--override_config nested.config.key=val`.')
     return parser.parse_args()
 
 
@@ -202,7 +204,7 @@ def main():
     init_log(args)
 
     cfg = config.load(args.model_config)
-    config.apply_duration_flags(cfg, args.max_duration, args.pad_to_max_duration)
+    config.apply_config_overrides(cfg, args)
 
     symbols = helpers.add_ctc_blank(cfg['labels'])
 
@@ -384,11 +386,14 @@ def main():
             loss.backward()
             model.zero_grad()
 
+    bmark_stats = BenchmarkStats()
+
     for epoch in range(start_epoch + 1, args.epochs + 1):
         if multi_gpu and not use_dali:
             train_loader.sampler.set_epoch(epoch)
 
         epoch_utts = 0
+        epoch_loss = 0
         accumulated_batches = 0
         epoch_start_time = time.time()
 
@@ -434,6 +439,7 @@ def main():
                 accumulated_batches += 1
 
             if accumulated_batches % args.grad_accumulation_steps == 0:
+                epoch_loss += step_loss
                 optimizer.step()
                 apply_ema(model, ema_model, args.ema)
 
@@ -476,8 +482,11 @@ def main():
                 break
 
         epoch_time = time.time() - epoch_start_time
+        epoch_loss /= steps_per_epoch
         log((epoch,), None, 'train_avg', {'throughput': epoch_utts / epoch_time,
-                                          'took': epoch_time})
+                                          'took': epoch_time,
+                                          'loss': epoch_loss})
+        bmark_stats.update(epoch_utts, epoch_time, epoch_loss)
 
         if epoch % args.save_frequency == 0 or epoch in args.keep_milestones:
             checkpointer.save(model, ema_model, optimizer, epoch, step, best_wer)
@@ -491,7 +500,7 @@ def main():
         profiler.stop()
         torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
 
-    log((), None, 'train_avg', {'throughput': epoch_utts / epoch_time})
+    log((), None, 'train_avg', bmark_stats.get(args.benchmark_epochs_num))
 
     if epoch == args.epochs:
         evaluate(epoch, step, val_loader, val_feat_proc, symbols, model,