
Merge pull request #708 from NVIDIA/gh/release

[FastPitch/PyT] Updating for 20.08
nv-kkudrynski, 5 years ago · commit 0b27e359a5
25 changed files with 367 additions and 308 deletions
  1. +1 -1    PyTorch/SpeechSynthesis/FastPitch/.dockerignore
  2. +0 -1    PyTorch/SpeechSynthesis/FastPitch/.gitignore
  3. +1 -1    PyTorch/SpeechSynthesis/FastPitch/Dockerfile
  4. +5 -5    PyTorch/SpeechSynthesis/FastPitch/README.md
  5. +0 -121  PyTorch/SpeechSynthesis/FastPitch/common/log_helper.py
  6. +169 -0  PyTorch/SpeechSynthesis/FastPitch/common/tb_dllogger.py
  7. +1 -0    PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py
  8. +1 -1    PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py
  9. +50 -35  PyTorch/SpeechSynthesis/FastPitch/inference.py
  10. +34 -0  PyTorch/SpeechSynthesis/FastPitch/pitch_transform.py
  11. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_1GPU.sh
  12. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_4GPU.sh
  13. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_8GPU.sh
  14. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_1GPU.sh
  15. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_4GPU.sh
  16. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_8GPU.sh
  17. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_1GPU.sh
  18. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_4GPU.sh
  19. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_8GPU.sh
  20. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_1GPU.sh
  21. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_4GPU.sh
  22. +1 -2   PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_8GPU.sh
  23. +2 -3   PyTorch/SpeechSynthesis/FastPitch/scripts/inference_benchmark.sh
  24. +2 -1   PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh
  25. +89 -115 PyTorch/SpeechSynthesis/FastPitch/train.py

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/.dockerignore

@@ -1,7 +1,7 @@
 *~
 *.pyc
 __pycache__
-output
+output*
 LJSpeech-1.1*
 runs*
 pretrained_models

+ 0 - 1
PyTorch/SpeechSynthesis/FastPitch/.gitignore

@@ -4,6 +4,5 @@
 __pycache__
 scripts_joc/
 runs*/
-notebooks/
 LJSpeech-1.1/
 output*

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.06-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.08-py3
 FROM ${FROM_IMAGE_NAME}
 
 ADD requirements.txt .

+ 5 - 5
PyTorch/SpeechSynthesis/FastPitch/README.md

@@ -114,7 +114,7 @@ To speed-up FastPitch training,
 reference mel-spectrograms, character durations, and pitch cues
 are generated during the pre-processing step and read
 directly from the disk during training. For more information on data pre-processing refer to [Dataset guidelines
-](#dataset-guidelines) and the [paper](#).
+](#dataset-guidelines) and the [paper](https://arxiv.org/abs/2006.06873).
 
 ### Feature support matrix
 
@@ -196,11 +196,11 @@ called `losses`):
         ```
 #### Enabling TF32
 
-TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](#https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs. 
+TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs.
 
 TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations.
 
-For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](#https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
+For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
 
 TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
 
@@ -290,7 +290,7 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
    ```
    The training will produce a FastPitch model capable of generating mel-spectrograms from raw text.
    It will be serialized as a single `.pt` checkpoint file, along with a series of intermediate checkpoints.
-   The script is configured for 8x GPU with at least 16GB of memory. Consult [Training process](#training-process) and [example configs](#-training-performance-benchmark) to adjust to a different configuration or enable Automatic Mixed Precision.
+   The script is configured for 8x GPU with at least 16GB of memory. Consult [Training process](#training-process) and [example configs](#training-performance-benchmark) to adjust to a different configuration or enable Automatic Mixed Precision.
 
 5. Start validation/evaluation.
 
@@ -492,7 +492,7 @@ In a single accumulated step, there are `batch_size x gradient_accumulation_step
     ```
 With automatic mixed precision (AMP), a larger batch size fits in 16GB of memory:
     ```bash
-    NGPU=4 GRAD_ACC=1 BS=64 AMP=true bash scripta/train.sh
+    NGPU=4 GRAD_ACC=1 BS=64 AMP=true bash scripts/train.sh
     ```
 
 ### Inference process
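The accumulated-step arithmetic described in this README section (samples per optimizer update = `batch_size x gradient_accumulation_steps`, scaled by the number of GPUs) can be sketched in a few lines; the function name and the second configuration are illustrative, not from the scripts:

```python
def samples_per_step(ngpu: int, batch_size: int, grad_acc: int) -> int:
    """Samples contributing to a single optimizer update."""
    return ngpu * batch_size * grad_acc

# The AMP configuration quoted above: 4 GPUs x batch 64 x 1 accumulation step.
print(samples_per_step(4, 64, 1))   # 256

# A hypothetical FP32 config can match it by trading batch size for accumulation.
print(samples_per_step(4, 32, 2))   # 256
```

Keeping this product constant is what makes runs with different per-GPU batch sizes comparable.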

+ 0 - 121
PyTorch/SpeechSynthesis/FastPitch/common/log_helper.py

@@ -1,121 +0,0 @@
-import atexit
-import glob
-import os
-import re
-import numpy as np
-
-from tensorboardX import SummaryWriter
-
-import dllogger as DLLogger
-from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
-
-
-def unique_dllogger_fpath(log_fpath):
-
-    if not os.path.isfile(log_fpath):
-        return log_fpath
-
-    # Avoid overwriting old logs
-    saved = sorted([int(re.search('\.(\d+)', f).group(1))
-                    for f in glob.glob(f'{log_fpath}.*')])
-
-    log_num = (saved[-1] if saved else 0) + 1
-    return f'{log_fpath}.{log_num}'
-
-
-def stdout_step_format(step):
-    if isinstance(step, str):
-        return step
-    fields = []
-    if len(step) > 0:
-        fields.append("epoch {:>4}".format(step[0]))
-    if len(step) > 1:
-        fields.append("iter {:>3}".format(step[1]))
-    if len(step) > 2:
-        fields[-1] += "/{}".format(step[2])
-    return " | ".join(fields)
-
-
-def stdout_metric_format(metric, metadata, value):
-    name = metadata["name"] if "name" in metadata.keys() else metric + " : "
-    unit = metadata["unit"] if "unit" in metadata.keys() else None
-    format = "{" + metadata["format"] + "}" if "format" in metadata.keys() else "{}"
-    fields = [name, format.format(value) if value is not None else value, unit]
-    fields = filter(lambda f: f is not None, fields)
-    return "| " + " ".join(fields)
-
-
-def init_dllogger(log_fpath=None, dummy=False):
-    if dummy:
-        DLLogger.init(backends=[])
-        return
-    DLLogger.init(backends=[
-        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
-        StdOutBackend(Verbosity.VERBOSE, step_format=stdout_step_format,
-                      metric_format=stdout_metric_format)
-        ]
-    )
-    DLLogger.metadata("train_loss", {"name": "loss", "format": ":>5.2f"})
-    DLLogger.metadata("train_mel_loss", {"name": "mel loss", "format": ":>5.2f"})
-    DLLogger.metadata("avg_train_loss", {"name": "avg train loss", "format": ":>5.2f"})
-    DLLogger.metadata("avg_train_mel_loss", {"name": "avg train mel loss", "format": ":>5.2f"})
-    DLLogger.metadata("val_loss", {"name": "  avg val loss", "format": ":>5.2f"})
-    DLLogger.metadata("val_mel_loss", {"name": "  avg val mel loss", "format": ":>5.2f"})
-    DLLogger.metadata(
-        "val_ema_loss",
-        {"name": "  EMA val loss", "format": ":>5.2f"})
-    DLLogger.metadata(
-        "val_ema_mel_loss",
-        {"name": "  EMA val mel loss", "format": ":>5.2f"})
-    DLLogger.metadata(
-        "train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
-    DLLogger.metadata(
-        "avg_train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
-    DLLogger.metadata(
-        "val_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
-    DLLogger.metadata(
-        "val_ema_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"})
-    DLLogger.metadata(
-        "took", {"name": "took", "unit": "s", "format": ":>3.2f"})
-    DLLogger.metadata("lrate_change", {"name": "lrate"})
-
-
-class TBLogger(object):
-    """
-    xyz_dummies: stretch the screen with empty plots so the legend would
-                 always fit for other plots
-    """
-    def __init__(self, local_rank, log_dir, name, interval=1, dummies=False):
-        self.enabled = (local_rank == 0)
-        self.interval = interval
-        self.cache = {}
-        if local_rank == 0:
-            self.summary_writer = SummaryWriter(
-                log_dir=os.path.join(log_dir, name),
-                flush_secs=120, max_queue=200)
-            atexit.register(self.summary_writer.close)
-            if dummies:
-                for key in ('aaa', 'zzz'):
-                    self.summary_writer.add_scalar(key, 0.0, 1)
-
-    def log_value(self, step, key, val, stat='mean'):
-        if self.enabled:
-            if key not in self.cache:
-                self.cache[key] = []
-            self.cache[key].append(val)
-            if len(self.cache[key]) == self.interval:
-                agg_val = getattr(np, stat)(self.cache[key])
-                self.summary_writer.add_scalar(key, agg_val, step)
-                del self.cache[key]
-
-    def log_meta(self, step, meta):
-        for k, v in meta.items():
-            self.log_value(step, k, v.item())
-
-    def log_grads(self, step, model):
-        if self.enabled:
-            norms = [p.grad.norm().item() for p in model.parameters()
-                     if p.grad is not None]
-            for stat in ('max', 'min', 'mean'):
-                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
-                               stat=stat)

+ 169 - 0
PyTorch/SpeechSynthesis/FastPitch/common/tb_dllogger.py

@@ -0,0 +1,169 @@
+import atexit
+import glob
+import os
+import re
+import numpy as np
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+import dllogger
+from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
+
+
+tb_loggers = {}
+
+
+class TBLogger:
+    """
+    xyz_dummies: stretch the screen with empty plots so the legend would
+                 always fit for other plots
+    """
+    def __init__(self, enabled, log_dir, name, interval=1, dummies=True):
+        self.enabled = enabled
+        self.interval = interval
+        self.cache = {}
+        if self.enabled:
+            self.summary_writer = SummaryWriter(
+                log_dir=os.path.join(log_dir, name),
+                flush_secs=120, max_queue=200)
+            atexit.register(self.summary_writer.close)
+            if dummies:
+                for key in ('aaa', 'zzz'):
+                    self.summary_writer.add_scalar(key, 0.0, 1)
+
+    def log(self, step, data):
+        for k, v in data.items():
+            self.log_value(step, k, v.item() if type(v) is torch.Tensor else v)
+
+    def log_value(self, step, key, val, stat='mean'):
+        if self.enabled:
+            if key not in self.cache:
+                self.cache[key] = []
+            self.cache[key].append(val)
+            if len(self.cache[key]) == self.interval:
+                agg_val = getattr(np, stat)(self.cache[key])
+                self.summary_writer.add_scalar(key, agg_val, step)
+                del self.cache[key]
+
+    def log_grads(self, step, model):
+        if self.enabled:
+            norms = [p.grad.norm().item() for p in model.parameters()
+                     if p.grad is not None]
+            for stat in ('max', 'min', 'mean'):
+                self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
+                               stat=stat)
+
+
+def unique_log_fpath(log_fpath):
+
+    if not os.path.isfile(log_fpath):
+        return log_fpath
+
+    # Avoid overwriting old logs
+    saved = sorted([int(re.search(r'\.(\d+)', f).group(1))
+                    for f in glob.glob(f'{log_fpath}.*')])
+
+    log_num = (saved[-1] if saved else 0) + 1
+    return f'{log_fpath}.{log_num}'
+
+
+def stdout_step_format(step):
+    if isinstance(step, str):
+        return step
+    fields = []
+    if len(step) > 0:
+        fields.append("epoch {:>4}".format(step[0]))
+    if len(step) > 1:
+        fields.append("iter {:>3}".format(step[1]))
+    if len(step) > 2:
+        fields[-1] += "/{}".format(step[2])
+    return " | ".join(fields)
+
+
+def stdout_metric_format(metric, metadata, value):
+    name = metadata.get("name", metric + " : ")
+    unit = metadata.get("unit", None)
+    format = f'{{{metadata.get("format", "")}}}'
+    fields = [name, format.format(value) if value is not None else value, unit]
+    fields = [f for f in fields if f is not None]
+    return "| " + " ".join(fields)
+
+
+def init(log_fpath, log_dir, enabled=True, tb_subsets=[], **tb_kw):
+
+    if enabled:
+        backends = [JSONStreamBackend(Verbosity.DEFAULT,
+                                      unique_log_fpath(log_fpath)),
+                    StdOutBackend(Verbosity.VERBOSE,
+                                  step_format=stdout_step_format,
+                                  metric_format=stdout_metric_format)]
+    else:
+        backends = []
+
+    dllogger.init(backends=backends)
+    dllogger.metadata("train_lrate", {"name": "lrate", "format": ":>3.2e"})
+
+    for id_, pref in [('train', ''), ('train_avg', 'avg train '),
+                      ('val', '  avg val '), ('val_ema', '  EMA val ')]:
+
+        dllogger.metadata(f"{id_}_loss",
+                          {"name": f"{pref}loss", "format": ":>5.2f"})
+        dllogger.metadata(f"{id_}_mel_loss",
+                          {"name": f"{pref}mel loss", "format": ":>5.2f"})
+
+        dllogger.metadata(f"{id_}_frames/s",
+                          {"name": None, "unit": "frames/s", "format": ":>10.2f"})
+        dllogger.metadata(f"{id_}_took",
+                          {"name": "took", "unit": "s", "format": ":>3.2f"})
+
+    global tb_loggers
+    tb_loggers = {s: TBLogger(enabled, log_dir, name=s, **tb_kw)
+                  for s in tb_subsets}
+
+
+def init_inference_metadata():
+
+    modalities = [('latency', 's', ':>10.5f'), ('RTF', 'x', ':>10.2f'),
+                  ('frames/s', None, ':>10.2f'), ('samples/s', None, ':>10.2f'),
+                  ('letters/s', None, ':>10.2f')]
+
+    for perc in ['', 'avg', '90%', '95%', '99%']:
+        for model in ['fastpitch', 'waveglow', '']:
+            for mod, unit, format in modalities:
+
+                name = f'{perc} {model} {mod}'.strip().replace('  ', ' ')
+
+                dllogger.metadata(
+                    name.replace(' ', '_'),
+                    {'name': f'{name: <26}', 'unit': unit, 'format': format})
+
+
+def log(step, tb_total_steps=None, data={}, subset='train'):
+    if tb_total_steps is not None:
+        tb_loggers[subset].log(tb_total_steps, data)
+
+    if subset != '':
+        data = {f'{subset}_{key}': v for key,v in data.items()}
+    dllogger.log(step, data=data)
+
+
+def log_grads_tb(tb_total_steps, grads, tb_subset='train'):
+    tb_loggers[tb_subset].log_grads(tb_total_steps, grads)
+
+
+def parameters(data, verbosity=0, tb_subset=None):
+    for k,v in data.items():
+        dllogger.log(step="PARAMETER", data={k:v}, verbosity=verbosity)
+
+    if tb_subset is not None and tb_loggers[tb_subset].enabled:
+        tb_data = {k:v for k,v in data.items()
+                   if type(v) in (str, bool, int, float)}
+        tb_loggers[tb_subset].summary_writer.add_hparams(tb_data, {})
+
+
+def flush():
+    dllogger.flush()
+    for tbl in tb_loggers.values():
+        if tbl.enabled:
+            tbl.summary_writer.flush()
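The log-rotation logic in the new `unique_log_fpath` can be exercised standalone. This is a minimal sketch of the same algorithm (raw-string regex used to avoid an invalid-escape warning), checked against a temporary directory:

```python
import glob
import os
import re
import tempfile


def unique_log_fpath(log_fpath):
    # First run: the plain path is free, use it as-is.
    if not os.path.isfile(log_fpath):
        return log_fpath
    # Later runs: find existing ".<n>" suffixes and take the next number.
    saved = sorted(int(re.search(r'\.(\d+)', f).group(1))
                   for f in glob.glob(f'{log_fpath}.*'))
    log_num = (saved[-1] if saved else 0) + 1
    return f'{log_fpath}.{log_num}'


with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, 'nvlog.json')
    print(unique_log_fpath(p) == p)          # True: no log yet
    open(p, 'w').close()
    print(unique_log_fpath(p) == p + '.1')   # True: plain path is taken
```

Old logs are never overwritten; each training run gets a fresh numbered file.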

+ 1 - 0
PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py

@@ -82,6 +82,7 @@ class TextMelAliCollate():
             dur = batch[ids_sorted_decreasing[i]][3]
             dur_padded[i, :dur.shape[0]] = dur
             dur_lens[i] = dur.shape[0]
+            assert dur_lens[i] == input_lengths[i]
 
         # Right zero-pad mel-spec
         num_mels = batch[0][1].size(0)
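The new assertion ties each padded duration row back to its text length: there must be exactly one duration per input symbol. A minimal NumPy stand-in for that collate step (the sample durations and shapes are illustrative):

```python
import numpy as np

# Per-utterance character durations, variable length, as the collate sees them.
durs = [np.array([2, 3, 1]), np.array([4, 2])]
input_lengths = np.array([3, 2])            # symbols per utterance

max_len = max(d.shape[0] for d in durs)
dur_padded = np.zeros((len(durs), max_len), dtype=np.int64)
dur_lens = np.zeros(len(durs), dtype=np.int64)

for i, d in enumerate(durs):
    dur_padded[i, :d.shape[0]] = d          # right zero-pad
    dur_lens[i] = d.shape[0]
    # The check added in this diff: one duration per input symbol.
    assert dur_lens[i] == input_lengths[i]

print(dur_padded.tolist())                  # [[2, 3, 1], [4, 2, 0]]
```

A mismatch here would indicate corrupt pre-processed durations, so failing fast in the collate is cheaper than debugging a silent misalignment later.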

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py

@@ -190,7 +190,7 @@ class FastPitch(nn.Module):
                 mean, std = 218.14, 67.24
             else:
                 mean, std = self.pitch_mean[0], self.pitch_std[0]
-            pitch_pred = pitch_transform(pitch_pred, mean, std)
+            pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
 
         if pitch_tgt is None:
             pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
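The updated call passes per-utterance lengths recovered from the encoder mask: summing a `(bs, T, 1)` boolean mask over its last two dimensions yields exactly the number of real (non-padded) positions. A NumPy sketch of that reduction, with an illustrative mask:

```python
import numpy as np

# enc_mask: True where a position holds a real symbol, False where padded.
enc_mask = np.array([[[True], [True], [True], [False]],
                     [[True], [True], [False], [False]]])

lengths = enc_mask.sum(axis=(1, 2))   # mirrors enc_mask.sum(dim=(1, 2))
print(lengths.tolist())               # [3, 2]
```

These lengths let a pitch transform distinguish real positions from padding, which the old one-argument signature could not.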

+ 50 - 35
PyTorch/SpeechSynthesis/FastPitch/inference.py

@@ -43,8 +43,10 @@ import dllogger as DLLogger
 from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
 
 from common import utils
-from common.log_helper import unique_dllogger_fpath
+from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
+                                unique_log_fpath)
 from common.text import text_to_sequence
+from pitch_transform import pitch_transform_custom
 from waveglow import model as glow
 from waveglow.denoiser import Denoiser
 
@@ -63,6 +65,8 @@ def parse_args(parser):
                         help='Path to a DLLogger log file')
     parser.add_argument('--cuda', action='store_true',
                         help='Run inference on a GPU using CUDA')
+    parser.add_argument('--cudnn-benchmark', action='store_true',
+                        help='Enable cudnn benchmark mode')
     parser.add_argument('--fastpitch', type=str,
                         help='Full path to the generator checkpoint file (skip to use ground truth mels)')
     parser.add_argument('--waveglow', type=str,
@@ -77,7 +81,7 @@ def parse_args(parser):
                         help='STFT hop length for estimating audio length from mel size')
     parser.add_argument('--amp', action='store_true',
                         help='Inference with AMP')
-    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('-bs', '--batch-size', type=int, default=64)
     parser.add_argument('--include-warmup', action='store_true',
                         help='Include warmup')
     parser.add_argument('--repeats', type=int, default=1,
@@ -98,10 +102,12 @@ def parse_args(parser):
                            help='Flatten the pitch')
     transform.add_argument('--pitch-transform-invert', action='store_true',
                            help='Invert the pitch wrt mean value')
-    transform.add_argument('--pitch-transform-amplify', action='store_true',
-                           help='Amplify the pitch variability')
+    transform.add_argument('--pitch-transform-amplify', type=float, default=1.0,
+                           help='Amplify pitch variability, typical values are in the range (1.0, 3.0).')
     transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
                            help='Raise/lower the pitch by <hz>')
+    transform.add_argument('--pitch-transform-custom', action='store_true',
+                           help='Apply the transform from pitch_transform.py')
     return parser
 
 
@@ -200,17 +206,25 @@ def prepare_input_sequence(fields, device, batch_size=128, dataset=None,
 
 
 def build_pitch_transformation(args):
+
+    if args.pitch_transform_custom:
+        def custom_(pitch, pitch_lens, mean, std):
+            return (pitch_transform_custom(pitch * std + mean, pitch_lens)
+                    - mean) / std
+        return custom_
+
     fun = 'pitch'
     if args.pitch_transform_flatten:
         fun = f'({fun}) * 0.0'
     if args.pitch_transform_invert:
         fun = f'({fun}) * -1.0'
     if args.pitch_transform_amplify:
-        fun = f'({fun}) * 2.0'
+        ampl = args.pitch_transform_amplify
+        fun = f'({fun}) * {ampl}'
     if args.pitch_transform_shift != 0.0:
         hz = args.pitch_transform_shift
         fun = f'({fun}) + {hz} / std'
-    return eval(f'lambda pitch, mean, std: {fun}')
+    return eval(f'lambda pitch, pitch_lens, mean, std: {fun}')
 
 
 class MeasureTime(list):
@@ -232,26 +246,27 @@ def main():
     Launches text to speech (inference).
     Inference is executed on a single GPU.
     """
-
-    torch.backends.cudnn.benchmark = True
-
     parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
                                      allow_abbrev=False)
     parser = parse_args(parser)
     args, unk_args = parser.parse_known_args()
 
+    torch.backends.cudnn.benchmark = args.cudnn_benchmark
+
     if args.output is not None:
         Path(args.output).mkdir(parents=False, exist_ok=True)
 
     log_fpath = args.log_file or str(Path(args.output, 'nvlog_infer.json'))
-    log_fpath = unique_dllogger_fpath(log_fpath)
+    log_fpath = unique_log_fpath(log_fpath)
     DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
-                            StdOutBackend(Verbosity.VERBOSE)])
+                            StdOutBackend(Verbosity.VERBOSE,
+                                          metric_format=stdout_metric_format)])
+    init_inference_metadata()
     [DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
 
     device = torch.device('cuda' if args.cuda else 'cpu')
 
-    if args.fastpitch is not None:
+    if args.fastpitch != 'SKIP':
         generator = load_and_setup_model(
             'FastPitch', parser, args.fastpitch, args.amp, device,
             unk_args=unk_args, forward_is_infer=True, ema=args.ema,
@@ -262,7 +277,7 @@ def main():
     else:
         generator = None
 
-    if args.waveglow is not None:
+    if args.waveglow != 'SKIP':
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             waveglow = load_and_setup_model(
@@ -325,7 +340,7 @@ def main():
                 gen_infer_perf = mel.size(0) * mel.size(2) / gen_measures[-1]
                 all_letters += b['text_lens'].sum().item()
                 all_frames += mel.size(0) * mel.size(2)
-                log(rep, {"fastpitch_frames_per_sec": gen_infer_perf})
+                log(rep, {"fastpitch_frames/s": gen_infer_perf})
                 log(rep, {"fastpitch_latency": gen_measures[-1]})
 
             if waveglow is not None:
@@ -340,7 +355,7 @@ def main():
                 waveglow_infer_perf = (
                     audios.size(0) * audios.size(1) / waveglow_measures[-1])
 
-                log(rep, {"waveglow_samples_per_sec": waveglow_infer_perf})
+                log(rep, {"waveglow_samples/s": waveglow_infer_perf})
                 log(rep, {"waveglow_latency": waveglow_measures[-1]})
 
                 if args.output is not None and reps == 1:
@@ -364,32 +379,32 @@ def main():
     if generator is not None:
         gm = np.sort(np.asarray(gen_measures))
         rtf = all_samples / (all_utterances * gm.mean() * args.sampling_rate)
-        log('avg', {"fastpitch letters/s": all_letters / gm.sum()})
-        log('avg', {"fastpitch_frames/s": all_frames / gm.sum()})
-        log('avg', {"fastpitch_latency": gm.mean()})
-        log('avg', {"fastpitch RTF": rtf})
-        log('90%', {"fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.90) / 2) * gm.std()})
-        log('95%', {"fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.95) / 2) * gm.std()})
-        log('99%', {"fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.99) / 2) * gm.std()})
+        log((), {"avg_fastpitch_letters/s": all_letters / gm.sum()})
+        log((), {"avg_fastpitch_frames/s": all_frames / gm.sum()})
+        log((), {"avg_fastpitch_latency": gm.mean()})
+        log((), {"avg_fastpitch_RTF": rtf})
+        log((), {"90%_fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.90) / 2) * gm.std()})
+        log((), {"95%_fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.95) / 2) * gm.std()})
+        log((), {"99%_fastpitch_latency": gm.mean() + norm.ppf((1.0 + 0.99) / 2) * gm.std()})
     if waveglow is not None:
         wm = np.sort(np.asarray(waveglow_measures))
         rtf = all_samples / (all_utterances * wm.mean() * args.sampling_rate)
-        log('avg', {"waveglow_samples/s": all_samples / wm.sum()})
-        log('avg', {"waveglow_latency": wm.mean()})
-        log('avg', {"waveglow RTF": rtf})
-        log('90%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.90) / 2) * wm.std()})
-        log('95%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.95) / 2) * wm.std()})
-        log('99%', {"waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.99) / 2) * wm.std()})
+        log((), {"avg_waveglow_samples/s": all_samples / wm.sum()})
+        log((), {"avg_waveglow_latency": wm.mean()})
+        log((), {"avg_waveglow_RTF": rtf})
+        log((), {"90%_waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.90) / 2) * wm.std()})
+        log((), {"95%_waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.95) / 2) * wm.std()})
+        log((), {"99%_waveglow_latency": wm.mean() + norm.ppf((1.0 + 0.99) / 2) * wm.std()})
     if generator is not None and waveglow is not None:
         m = gm + wm
         rtf = all_samples / (all_utterances * m.mean() * args.sampling_rate)
-        log('avg', {"samples/s": all_samples / m.sum()})
-        log('avg', {"letters/s": all_letters / m.sum()})
-        log('avg', {"latency": m.mean()})
-        log('avg', {"RTF": rtf})
-        log('90%', {"latency": m.mean() + norm.ppf((1.0 + 0.90) / 2) * m.std()})
-        log('95%', {"latency": m.mean() + norm.ppf((1.0 + 0.95) / 2) * m.std()})
-        log('99%', {"latency": m.mean() + norm.ppf((1.0 + 0.99) / 2) * m.std()})
+        log((), {"avg_samples/s": all_samples / m.sum()})
+        log((), {"avg_letters/s": all_letters / m.sum()})
+        log((), {"avg_latency": m.mean()})
+        log((), {"avg_RTF": rtf})
+        log((), {"90%_latency": m.mean() + norm.ppf((1.0 + 0.90) / 2) * m.std()})
+        log((), {"95%_latency": m.mean() + norm.ppf((1.0 + 0.95) / 2) * m.std()})
+        log((), {"99%_latency": m.mean() + norm.ppf((1.0 + 0.99) / 2) * m.std()})
     DLLogger.flush()
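The percentile latencies logged above are estimated from a normal fit: `mean + z * std`, with `z = norm.ppf((1 + p) / 2)` from SciPy. The same quantile is available in the standard library via `statistics.NormalDist`; a sketch with made-up timings (note this uses the sample standard deviation, whereas `np.std` defaults to the population one, so the figures differ slightly):

```python
from statistics import NormalDist, mean, stdev

latencies = [0.11, 0.12, 0.10, 0.13, 0.12, 0.11, 0.14, 0.12]  # illustrative

m, s = mean(latencies), stdev(latencies)

def latency_percentile(p):
    # inv_cdf((1 + p) / 2) mirrors scipy's norm.ppf((1.0 + p) / 2).
    return m + NormalDist().inv_cdf((1.0 + p) / 2) * s

for p in (0.90, 0.95, 0.99):
    print(f'{int(p * 100)}%_latency: {latency_percentile(p):.4f}')
```

For p = 0.90 the z-score is about 1.645, so the "90%" figure is roughly the mean plus 1.6 standard deviations.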
 
 

+ 34 - 0
PyTorch/SpeechSynthesis/FastPitch/pitch_transform.py

@@ -0,0 +1,34 @@
+
+import torch
+
+
+def pitch_transform_custom(pitch, pitch_lens):
+    """Apply a custom pitch transformation to predicted pitch values.
+
+    This sample modification linearly increases the pitch throughout
+    the utterance from 0.5 of predicted pitch to 1.5 of predicted pitch.
+    In other words, it starts low and ends high.
+
+    PARAMS
+    ------
+    pitch: torch.Tensor (bs, max_len)
+        Predicted pitch values for each lexical unit, padded to max_len (in Hz).
+    pitch_lens: torch.Tensor (bs, )
+        Number of lexical units in each utterance.
+
+    RETURNS
+    -------
+    pitch: torch.Tensor
+        Modified pitch (in Hz).
+    """
+
+    weights = torch.arange(pitch.size(1), dtype=torch.float32, device=pitch.device)
+
+    # The weights increase linearly from 0.0 to 1.0 in every i-th row
+    # in the range (0, pitch_lens[i])
+    weights = weights.unsqueeze(0) / pitch_lens.unsqueeze(1)
+
+    # Shift the range from (0.0, 1.0) to (0.5, 1.5)
+    weights += 0.5
+
+    return pitch * weights
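The weight ramp in `pitch_transform_custom` can be checked with a small NumPy stand-in (same arithmetic, torch swapped for numpy; a flat 200 Hz pitch makes the ramp easy to read):

```python
import numpy as np

pitch = np.full((2, 4), 200.0)        # (bs, max_len), flat 200 Hz
pitch_lens = np.array([4, 2])         # real lexical units per utterance

weights = np.arange(pitch.shape[1], dtype=np.float32)
# Rises linearly from 0.0 toward 1.0 over each utterance's true length.
weights = weights[None, :] / pitch_lens[:, None]
weights += 0.5                        # shift the range to (0.5, 1.5)

out = pitch * weights
print(out[0].tolist())                # [100.0, 150.0, 200.0, 250.0]
```

Positions past `pitch_lens[i]` get weights above 1.5, but those are padding and are masked out downstream.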

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_1GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_4GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 4 train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_AMP_8GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 8 train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_1GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_4GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 4 train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGX1_FastPitch_FP32_8GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 8 train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_1GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_4GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 4 train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_AMP_8GPU.sh

@@ -4,13 +4,12 @@ mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 8 train.py \
     --amp \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_1GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_4GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 4 train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 1 - 2
PyTorch/SpeechSynthesis/FastPitch/platform/DGXA100_FastPitch_TF32_8GPU.sh

@@ -3,13 +3,12 @@
 mkdir -p output
 python -m torch.distributed.launch --nproc_per_node 8 train.py \
     --cuda \
-    --cudnn-enabled \
     -o ./output/ \
     --log-file output/nvlog.json \
     --dataset-path LJSpeech-1.1 \
     --training-files filelists/ljs_mel_dur_pitch_text_train_filelist.txt \
     --validation-files filelists/ljs_mel_dur_pitch_text_test_filelist.txt \
-    --pitch-mean-std LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
+    --pitch-mean-std-file LJSpeech-1.1/pitch_char_stats__ljs_audio_text_train_filelist.json \
     --epochs 1500 \
     --epochs-per-checkpoint 100 \
     --warmup-steps 1000 \

+ 2 - 3
PyTorch/SpeechSynthesis/FastPitch/scripts/inference_benchmark.sh

@@ -7,13 +7,12 @@
 [ ! -n "$PHRASES" ] && PHRASES="phrases/benchmark_8_128.tsv"
 [ ! -n "$OUTPUT_DIR" ] && OUTPUT_DIR="./output/audio_$(basename ${PHRASES} .tsv)"
 [ "$AMP" == "true" ] && AMP_FLAG="--amp" || AMP=false
-[ "$SET_AFFINITY" == "true" ] && SET_AFFINITY_FLAG="--set-affinity"
 
 for BS in $BS_SEQ ; do
 
   echo -e "\nAMP: ${AMP}, batch size: ${BS}\n"
 
-  python inference.py --cuda \
+  python inference.py --cuda --cudnn-benchmark \
                       -i ${PHRASES} \
                       -o ${OUTPUT_DIR} \
                       --fastpitch ${FASTPITCH_CH} \
@@ -23,5 +22,5 @@ for BS in $BS_SEQ ; do
                       --batch-size ${BS} \
                       --repeats ${REPEATS} \
                       --torchscript \
-                      ${AMP_FLAG} ${SET_AFFINITY_FLAG}
+                      ${AMP_FLAG}
 done
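For context, the benchmark script above builds optional CLI flags (like `AMP_FLAG`) by assigning a variable only when its toggle is `"true"`, then expanding it unquoted so the flag appears on the command line only when requested. A minimal sketch of that pattern (variable names beyond `AMP`/`AMP_FLAG` are hypothetical):

```shell
#!/usr/bin/env bash
# Sketch of the optional-flag pattern used in inference_benchmark.sh:
# the flag variable stays empty unless its toggle is "true", so an
# unquoted ${AMP_FLAG} expansion adds the flag only when requested.
AMP="${AMP:-false}"
[ "$AMP" == "true" ] && AMP_FLAG="--amp" || AMP_FLAG=""

# ${AMP_FLAG} expands to nothing when AMP is false, keeping the command valid.
echo inference.py --cuda ${AMP_FLAG}
```

The same mechanism is why the removed `SET_AFFINITY_FLAG` could simply be dropped from the final command without any other cleanup.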

+ 2 - 1
PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh

@@ -1,5 +1,7 @@
 #!/bin/bash
 
+export OMP_NUM_THREADS=1
+
 # Adjust env variables to maintain the global batch size
 #
 #    NGPU x BS x GRAD_ACC = 256.
@@ -19,7 +21,6 @@ echo -e "\nSetup: ${NGPU}x${BS}x${GRAD_ACC} - global batch size ${GBS}\n"
 mkdir -p "$OUTPUT_DIR"
 python -m torch.distributed.launch --nproc_per_node ${NGPU} train.py \
     --cuda \
-    --cudnn-enabled \
     -o "$OUTPUT_DIR/" \
     --log-file "$OUTPUT_DIR/nvlog.json" \
     --dataset-path LJSpeech-1.1 \
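The comment in `scripts/train.sh` states the invariant `NGPU x BS x GRAD_ACC = 256`: when fewer GPUs or a smaller per-GPU batch are used, gradient accumulation makes up the difference so the effective global batch size is unchanged. A small sketch of that arithmetic (the helper name is hypothetical, not from the repo):

```python
# Sketch: pick the gradient-accumulation factor that preserves the
# global batch size of 256, per the NGPU x BS x GRAD_ACC = 256 comment
# in scripts/train.sh. Helper name is hypothetical.
def grad_accumulation_steps(global_bs: int, ngpu: int, per_gpu_bs: int) -> int:
    """Return the accumulation factor that keeps NGPU * BS * GRAD_ACC fixed."""
    assert global_bs % (ngpu * per_gpu_bs) == 0, "global batch must divide evenly"
    return global_bs // (ngpu * per_gpu_bs)

print(grad_accumulation_steps(256, 8, 32))  # 8 GPUs x batch 32 -> 1 (no accumulation)
print(grad_accumulation_steps(256, 1, 32))  # 1 GPU  x batch 32 -> 8
```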

+ 89 - 115
PyTorch/SpeechSynthesis/FastPitch/train.py

@@ -45,7 +45,7 @@ from torch.nn.parameter import Parameter
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-import dllogger as DLLogger
+import common.tb_dllogger as logger
 from apex import amp
 from apex.optimizers import FusedAdam, FusedLAMB
 
@@ -53,7 +53,6 @@ import common
 import data_functions
 import loss_functions
 import models
-from common.log_helper import init_dllogger, TBLogger, unique_dllogger_fpath
 
 
 def parse_args(parser):
@@ -82,10 +81,8 @@ def parse_args(parser):
                           help='Enable AMP')
     training.add_argument('--cuda', action='store_true',
                           help='Run on GPU using CUDA')
-    training.add_argument('--cudnn-enabled', action='store_true',
-                          help='Enable cudnn')
     training.add_argument('--cudnn-benchmark', action='store_true',
-                          help='Run cudnn benchmark')
+                          help='Enable cudnn benchmark mode')
     training.add_argument('--ema-decay', type=float, default=0,
                           help='Discounting factor for training weights EMA')
     training.add_argument('--gradient-accumulation-steps', type=int, default=1,
@@ -131,8 +128,7 @@ def parse_args(parser):
 def reduce_tensor(tensor, num_gpus):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= num_gpus
-    return rt
+    return rt.true_divide(num_gpus)
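The `reduce_tensor` change matters because the all-reduced tensor may hold integer counts (e.g. frame counts): in recent PyTorch, in-place `/=` on an integer tensor by an integer is rejected, and integer division would otherwise floor the average, whereas `true_divide` always produces a floating-point result. A torch-free sketch of the distinction using plain Python numbers:

```python
# Sketch of why the diff swaps in-place division for true division when
# averaging values summed across GPUs. Plain-Python stand-in: in PyTorch,
# Tensor.true_divide always yields a floating-point result, while in-place
# /= on an integer tensor cannot change its dtype.
def average_across_ranks(summed: int, num_gpus: int) -> float:
    # True division keeps the fractional part of the average.
    return summed / num_gpus

print(average_across_ranks(7, 2))  # -> 3.5
print(7 // 2)                      # floor division would lose it -> 3
```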
 
 
 def init_distributed(args, world_size, rank):
@@ -174,6 +170,7 @@ def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
                     config, amp_run, filepath):
     if local_rank != 0:
         return
+
     print(f"Saving model and optimizer state at epoch {epoch} to {filepath}")
     ema_dict = None if ema_model is None else ema_model.state_dict()
     checkpoint = {'epoch': epoch,
@@ -208,11 +205,14 @@ def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
         ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
 
-def validate(model, criterion, valset, batch_size, world_size, collate_fn,
-             distributed_run, rank, batch_to_gpu, use_gt_durations=False):
+def validate(model, epoch, total_iter, criterion, valset, batch_size,
+             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=False,
+             ema=False):
     """Handles all the validation scoring and printing"""
     was_training = model.training
     model.eval()
+
+    tik = time.perf_counter()
     with torch.no_grad():
         val_sampler = DistributedSampler(valset) if distributed_run else None
         val_loader = DataLoader(valset, num_workers=8, shuffle=False,
@@ -225,6 +225,7 @@ def validate(model, criterion, valset, batch_size, world_size, collate_fn,
             x, y, num_frames = batch_to_gpu(batch)
             y_pred = model(x, use_gt_durations=use_gt_durations)
             loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
+
             if distributed_run:
                 for k,v in meta.items():
                     val_meta[k] += reduce_tensor(v, 1)
@@ -233,12 +234,24 @@ def validate(model, criterion, valset, batch_size, world_size, collate_fn,
                 for k,v in meta.items():
                     val_meta[k] += v
                 val_num_frames = num_frames.item()
+
         val_meta = {k: v / len(valset) for k,v in val_meta.items()}
-        val_loss = val_meta['loss']
+
+    val_meta['took'] = time.perf_counter() - tik
+
+    logger.log((epoch,) if epoch is not None else (),
+               tb_total_steps=total_iter,
+               subset='val_ema' if ema else 'val',
+               data=OrderedDict([
+                   ('loss', val_meta['loss'].item()),
+                   ('mel_loss', val_meta['mel_loss'].item()),
+                   ('frames/s', num_frames.item() / val_meta['took']),
+                   ('took', val_meta['took'])]),
+    )
 
     if was_training:
         model.train()
-    return val_loss.item(), val_meta, val_num_frames
+    return val_meta
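Note the guard the refactored `validate()` keeps: it records whether the model was in training mode, switches to eval for scoring, then restores the original mode so the caller's state is untouched. A minimal sketch of that save/restore pattern (the dummy class stands in for `torch.nn.Module`; the real loop runs under `torch.no_grad()`):

```python
# Sketch of the eval-mode save/restore pattern used by validate() in
# train.py. DummyModel is a stand-in for torch.nn.Module.
class DummyModel:
    def __init__(self):
        self.training = True
    def train(self):
        self.training = True
    def eval(self):
        self.training = False

def run_validation(model):
    was_training = model.training
    model.eval()
    try:
        pass  # scoring loop would run here, under torch.no_grad()
    finally:
        if was_training:
            model.train()  # restore the caller's mode
    return model.training

m = DummyModel()
print(run_validation(m))  # -> True (training mode restored)
m.eval()
print(run_validation(m))  # -> False (eval mode left as-is)
```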
 
 
 def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
@@ -270,39 +283,33 @@ def main():
     parser = parse_args(parser)
     args, _ = parser.parse_known_args()
 
-    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        local_rank = int(os.environ['LOCAL_RANK'])
-        world_size = int(os.environ['WORLD_SIZE'])
-    else:
-        local_rank = args.rank
-        world_size = args.world_size
-    distributed_run = world_size > 1
+    distributed_run = args.world_size > 1
 
-    torch.manual_seed(args.seed + local_rank)
-    np.random.seed(args.seed + local_rank)
+    torch.manual_seed(args.seed + args.local_rank)
+    np.random.seed(args.seed + args.local_rank)
 
-    if local_rank == 0:
+    if args.local_rank == 0:
         if not os.path.exists(args.output):
             os.makedirs(args.output)
 
-        log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
-        log_fpath = unique_dllogger_fpath(log_fpath)
-        init_dllogger(log_fpath)
-    else:
-        init_dllogger(dummy=True)
+    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
+    tb_subsets = ['train', 'val']
+    if args.ema_decay > 0.0:
+        tb_subsets.append('val_ema')
 
-    [DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
+    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
+                tb_subsets=tb_subsets)
+    logger.parameters(vars(args), tb_subset='train')
 
     parser = models.parse_model_args('FastPitch', parser)
     args, unk_args = parser.parse_known_args()
     if len(unk_args) > 0:
         raise ValueError(f'Invalid options {unk_args}')
 
-    torch.backends.cudnn.enabled = args.cudnn_enabled
     torch.backends.cudnn.benchmark = args.cudnn_benchmark
 
     if distributed_run:
-        init_distributed(args, world_size, local_rank)
+        init_distributed(args, args.world_size, args.local_rank)
 
     device = torch.device('cuda' if args.cuda else 'cpu')
     model_config = models.get_model_config('FastPitch', args)
@@ -351,9 +358,9 @@ def main():
         ch_fpath = None
 
     if ch_fpath is not None:
-        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
+        load_checkpoint(args.local_rank, model, ema_model, optimizer, start_epoch,
                         start_iter, model_config, args.amp, ch_fpath,
-                        world_size)
+                        args.world_size)
 
     start_epoch = start_epoch[0]
     total_iter = start_iter[0]
@@ -381,15 +388,9 @@ def main():
 
     model.train()
 
-    train_tblogger = TBLogger(local_rank, args.output, 'train')
-    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
-    if args.ema_decay > 0:
-        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')
-
-    val_loss = 0.0
     torch.cuda.synchronize()
     for epoch in range(start_epoch, args.epochs + 1):
-        epoch_start_time = time.time()
+        epoch_start_time = time.perf_counter()
 
         epoch_loss = 0.0
         epoch_mel_loss = 0.0
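Throughout train.py this diff replaces `time.time()` with `time.perf_counter()` for interval timing. `perf_counter` is monotonic and has the highest available resolution, so elapsed-time measurements can never go negative; `time.time()` tracks wall-clock time and can jump if the system clock is adjusted mid-run. A short sketch of the timing idiom:

```python
import time

# Sketch: interval timing with time.perf_counter(), as the diff adopts.
# perf_counter is a monotonic, high-resolution clock, which makes it the
# right choice for measuring durations such as iteration or epoch time.
tik = time.perf_counter()
total = sum(range(100_000))        # stand-in for one training iteration
took = time.perf_counter() - tik

assert took >= 0.0                 # monotonic: elapsed time is never negative
print(f"iteration took {took:.6f}s")
```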
@@ -407,24 +408,16 @@ def main():
         epoch_iter = 0
         num_iters = len(train_loader) // args.gradient_accumulation_steps
         for batch in train_loader:
+
             if accumulated_steps == 0:
                 if epoch_iter == num_iters:
                     break
                 total_iter += 1
                 epoch_iter += 1
-                iter_start_time = time.time()
-                start = time.perf_counter()
+                iter_start_time = time.perf_counter()
 
-                old_lr = optimizer.param_groups[0]['lr']
                 adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                      args.warmup_steps)
-                new_lr = optimizer.param_groups[0]['lr']
-
-                if new_lr != old_lr:
-                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
-                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
-                else:
-                    dllog_lrate_change = None
 
                 model.zero_grad()
 
@@ -443,9 +436,9 @@ def main():
                 loss.backward()
 
             if distributed_run:
-                reduced_loss = reduce_tensor(loss.data, world_size).item()
+                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                 reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
-                meta = {k: reduce_tensor(v, world_size) for k,v in meta.items()}
+                meta = {k: reduce_tensor(v, args.world_size) for k,v in meta.items()}
             else:
                 reduced_loss = loss.item()
                 reduced_num_frames = num_frames.item()
@@ -459,7 +452,7 @@ def main():
 
             if accumulated_steps % args.gradient_accumulation_steps == 0:
 
-                train_tblogger.log_grads(total_iter, model)
+                logger.log_grads_tb(total_iter, model)
                 if args.amp:
                     torch.nn.utils.clip_grad_norm_(
                         amp.master_params(optimizer), args.grad_clip_thresh)
@@ -470,21 +463,23 @@ def main():
                 optimizer.step()
                 apply_ema_decay(model, ema_model, args.ema_decay)
 
-                iter_stop_time = time.time()
-                iter_time = iter_stop_time - iter_start_time
-                frames_per_sec = iter_num_frames / iter_time
-                epoch_frames_per_sec += frames_per_sec
+                iter_time = time.perf_counter() - iter_start_time
+                iter_mel_loss = iter_meta['mel_loss'].item()
+                epoch_frames_per_sec += iter_num_frames / iter_time
                 epoch_loss += iter_loss
                 epoch_num_frames += iter_num_frames
-                iter_mel_loss = iter_meta['mel_loss'].item()
                 epoch_mel_loss += iter_mel_loss
 
-                DLLogger.log((epoch, epoch_iter, num_iters), OrderedDict([
-                    ('train_loss', iter_loss), ('train_mel_loss', iter_mel_loss),
-                    ('train_frames/s', frames_per_sec), ('took', iter_time),
-                    ('lrate_change', dllog_lrate_change)
-                ]))
-                train_tblogger.log_meta(total_iter, iter_meta)
+                logger.log((epoch, epoch_iter, num_iters),
+                           tb_total_steps=total_iter,
+                           subset='train',
+                           data=OrderedDict([
+                               ('loss', iter_loss),
+                               ('mel_loss', iter_mel_loss),
+                               ('frames/s', iter_num_frames / iter_time),
+                               ('took', iter_time),
+                               ('lrate', optimizer.param_groups[0]['lr'])]),
+                )
 
                 accumulated_steps = 0
                 iter_loss = 0
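The hunk above calls `apply_ema_decay(model, ema_model, args.ema_decay)` after each optimizer step. Its body is not shown in this diff, but the standard exponential-moving-average update it names is `ema <- decay * ema + (1 - decay) * param`, sketched here on plain floats for clarity (an assumption about the implementation, not a copy of it):

```python
# Sketch of the standard EMA weight update presumed behind apply_ema_decay():
#     ema <- decay * ema + (1 - decay) * param
# Plain floats stand in for parameter tensors.
def ema_update(ema, params, decay):
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]

ema = [0.0, 0.0]
params = [1.0, 2.0]
ema = ema_update(ema, params, decay=0.9)
print(ema)  # -> approximately [0.1, 0.2]
```

With `--ema-decay 0` (the default), the update is skipped entirely, which is why the `val_ema` logging subset is only registered when `args.ema_decay > 0.0`.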
@@ -492,69 +487,48 @@ def main():
                 iter_meta = {}
 
         # Finished epoch
-        epoch_stop_time = time.time()
-        epoch_time = epoch_stop_time - epoch_start_time
-
-        DLLogger.log((epoch,), data=OrderedDict([
-            ('avg_train_loss', epoch_loss / epoch_iter),
-            ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
-            ('avg_train_frames/s', epoch_num_frames / epoch_time),
-            ('took', epoch_time)
-        ]))
-
-        tik = time.time()
-        val_loss, meta, num_frames = validate(
-            model, criterion, valset, args.batch_size, world_size, collate_fn,
-            distributed_run, local_rank, batch_to_gpu, use_gt_durations=True)
-        tok = time.time()
-
-        DLLogger.log((epoch,), data=OrderedDict([
-            ('val_loss', val_loss),
-            ('val_mel_loss', meta['mel_loss'].item()),
-            ('val_frames/s', num_frames / (tok - tik)),
-            ('took', tok - tik),
-        ]))
-        val_tblogger.log_meta(total_iter, meta)
+        epoch_time = time.perf_counter() - epoch_start_time
+
+        logger.log((epoch,),
+                   tb_total_steps=None,
+                   subset='train_avg',
+                   data=OrderedDict([
+                       ('loss', epoch_loss / epoch_iter),
+                       ('mel_loss', epoch_mel_loss / epoch_iter),
+                       ('frames/s', epoch_num_frames / epoch_time),
+                       ('took', epoch_time)]),
+        )
+
+        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
+                 collate_fn, distributed_run, batch_to_gpu,
+                 use_gt_durations=True)
 
         if args.ema_decay > 0:
-            tik_e = time.time()
-            val_loss_e, meta_e, num_frames_e = validate(
-                ema_model, criterion, valset, args.batch_size, world_size,
-                collate_fn, distributed_run, local_rank, batch_to_gpu,
-                use_gt_durations=True)
-            tok_e = time.time()
-
-            DLLogger.log((epoch,), data=OrderedDict([
-                ('val_ema_loss', val_loss_e),
-                ('val_ema_mel_loss', meta_e['mel_loss'].item()),
-                ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
-                ('took', tok_e - tik_e),
-            ]))
-            val_ema_tblogger.log_meta(total_iter, meta)
+            validate(ema_model, epoch, total_iter, criterion, valset,
+                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
+                     use_gt_durations=True, ema=True)
 
         if (epoch > 0 and args.epochs_per_checkpoint > 0 and
-            (epoch % args.epochs_per_checkpoint == 0) and local_rank == 0):
+            (epoch % args.epochs_per_checkpoint == 0) and args.local_rank == 0):
 
             checkpoint_path = os.path.join(
                 args.output, f"FastPitch_checkpoint_{epoch}.pt")
-            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
+            save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                             total_iter, model_config, args.amp, checkpoint_path)
-        if local_rank == 0:
-            DLLogger.flush()
+        logger.flush()
 
     # Finished training
-    DLLogger.log((), data=OrderedDict([
-        ('avg_train_loss', epoch_loss / epoch_iter),
-        ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
-        ('avg_train_frames/s', epoch_num_frames / epoch_time),
-    ]))
-    DLLogger.log((), data=OrderedDict([
-        ('val_loss', val_loss),
-        ('val_mel_loss', meta['mel_loss'].item()),
-        ('val_frames/s', num_frames / (tok - tik)),
-    ]))
-    if local_rank == 0:
-        DLLogger.flush()
+    logger.log((),
+               tb_total_steps=None,
+               subset='train_avg',
+               data=OrderedDict([
+                   ('loss', epoch_loss / epoch_iter),
+                   ('mel_loss', epoch_mel_loss / epoch_iter),
+                   ('frames/s', epoch_num_frames / epoch_time),
+                   ('took', epoch_time)]),
+    )
+    validate(model, None, total_iter, criterion, valset, args.batch_size,
+             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)
 
 
 if __name__ == '__main__':