瀏覽代碼

[QuartzNet/PyT] Support NeMo checkpoints

Adrian Lancucki 3 年之前
父節點
當前提交
b472e61ba7

+ 18 - 0
PyTorch/SpeechRecognition/QuartzNet/README.md

@@ -13,6 +13,7 @@ This repository provides a script and recipe to train the QuartzNet model to ach
         * [Enabling mixed precision](#enabling-mixed-precision)
         * [Enabling TF32](#enabling-tf32)
     * [Glossary](#glossary)
+    * [Language support and NeMo compatibility](#language-support-and-nemo-compatibility)
 - [Setup](#setup)
     * [Requirements](#requirements)
 - [Quick Start Guide](#quick-start-guide)
@@ -144,6 +145,23 @@ Assigns a probability distribution over a sequence of words. Given a sequence of
 **Pre-training**
 Training a model on vast amounts of data on the same (or different) task to build general understandings.
 
+### Language support and NeMo compatibility
+
+This repository allows to train and run models in laguages other than English.
+
+During inference, QuartzNet models trained with [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) can also be used, for instance one of pre-trained models
+for Catalan, French, German, Italian, Mandarin Chinese, Polish, Russian or Spanish available on [NGC](https://ngc.nvidia.com/).
+To download automatically, run:
+```bash
+bash scripts/download_quartznet.sh [ca|fr|de|it|zh|pl|ru|es]
+```
+
+Pre-trained models can be explicitly converted from the `.nemo` checkpoint format to `.pt` and vice versa.
+For more details, run:
+```bash
+python nemo_dle_model_converter.py --help
+```
+
 ## Setup
 
 The following section lists the requirements that you need to meet in order to start training the QuartzNet model.

+ 1 - 1
PyTorch/SpeechRecognition/QuartzNet/common/audio.py

@@ -75,7 +75,7 @@ class AudioSegment(object):
         self._samples = samples
         self._sample_rate = sample_rate
         if self._samples.ndim >= 2:
-            self._samples = np.mean(self._samples, 1)
+            self._samples = np.mean(self._samples, 0)
 
     def __eq__(self, other):
         """Return whether two objects are equal."""

+ 22 - 7
PyTorch/SpeechRecognition/QuartzNet/inference.py

@@ -35,8 +35,9 @@ from common.dataset import (AudioDataset, FilelistDataset, get_data_loader,
                             SingleAudioDataset)
 from common.features import BaseFeatures, FilterbankFeatures
 from common.helpers import print_once, process_evaluation_epoch
-from quartznet.model import GreedyCTCDecoder, QuartzNet
 from common.tb_dllogger import stdout_metric_format, unique_log_fpath
+from nemo_dle_model_converter import load_nemo_ckpt
+from quartznet.model import GreedyCTCDecoder, QuartzNet
 
 
 def get_parser():
@@ -189,7 +190,25 @@ def main():
         distrib.init_process_group(backend='nccl', init_method='env://')
         print_once(f'Inference with {distrib.get_world_size()} GPUs')
 
-    cfg = config.load(args.model_config)
+    if args.ckpt is not None:
+        print(f'Loading the model from {args.ckpt} ...')
+        print(f'{args.model_config} will be overriden.')
+        if args.ckpt.lower().endswith('.nemo'):
+            ckpt, cfg = load_nemo_ckpt(args.ckpt)
+        else:
+            cfg = config.load(args.model_config)
+            ckpt = torch.load(args.ckpt, map_location='cpu')
+
+        sd_key = 'ema_state_dict' if args.ema else 'state_dict'
+        if args.ema and 'ema_state_dict' not in ckpt:
+            print(f'WARNING: EMA weights are unavailable in {args.ckpt}.')
+            sd_key = 'state_dict'
+        state_dict = ckpt[sd_key]
+
+    else:
+        cfg = config.load(args.model_config)
+        state_dict = None
+
     config.apply_config_overrides(cfg, args)
 
     symbols = helpers.add_ctc_blank(cfg['labels'])
@@ -267,11 +286,7 @@ def main():
     model = QuartzNet(encoder_kw=config.encoder(cfg),
                       decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
 
-    if args.ckpt is not None:
-        print(f'Loading the model from {args.ckpt} ...')
-        checkpoint = torch.load(args.ckpt, map_location="cpu")
-        key = 'ema_state_dict' if args.ema else 'state_dict'
-        state_dict = checkpoint[key]
+    if state_dict is not None:
         model.load_state_dict(state_dict, strict=True)
 
     model.to(device)

+ 223 - 0
PyTorch/SpeechRecognition/QuartzNet/nemo_dle_model_converter.py

@@ -0,0 +1,223 @@
+import argparse
+import io
+import sys
+from copy import deepcopy
+from functools import reduce
+from pathlib import Path
+from subprocess import CalledProcessError, check_output
+
+import torch
+import yaml
+
+import quartznet.config
+from common import helpers
+from common.features import FilterbankFeatures
+from quartznet.config import load as load_yaml
+from quartznet.model import QuartzNet, MaskedConv1d
+
+
+# Corresponding DLE <-> NeMo config keys
+cfg_key_map = {
+    ("input_val", "audio_dataset", "sample_rate"): ("preprocessor", "sample_rate"),
+    ("input_val", "filterbank_features", "dither"): ("preprocessor", "dither"),
+    ("input_val", "filterbank_features", "frame_splicing"): ("preprocessor", "frame_splicing"),
+    ("input_val", "filterbank_features", "n_fft"): ("preprocessor", "n_fft"),
+    ("input_val", "filterbank_features", "n_filt"): ("preprocessor", "features"),
+    ("input_val", "filterbank_features", "normalize"): ("preprocessor", "normalize"),
+    ("input_val", "filterbank_features", "sample_rate"): ("preprocessor", "sample_rate"),
+    ("input_val", "filterbank_features", "window"): ("preprocessor", "window"),
+    ("input_val", "filterbank_features", "window_size"): ("preprocessor", "window_size"),
+    ("input_val", "filterbank_features", "window_stride"): ("preprocessor", "window_stride"),
+    ("labels",): ("decoder", "vocabulary"),
+    ("quartznet", "decoder", "in_feats"): ("decoder", "feat_in"),
+    ("quartznet", "encoder", "activation"): ("encoder", "activation"),
+    ("quartznet", "encoder", "blocks"): ("encoder", "jasper"),
+    ("quartznet", "encoder", "frame_splicing"): ("preprocessor", "frame_splicing"),
+    ("quartznet", "encoder", "in_feats"): ("encoder", "feat_in"),
+    ("quartznet", "encoder", "use_conv_masks"): ("encoder", "conv_mask"),
+}
+
+
+def load_nemo_ckpt(fpath):
+    """Make a DeepLearningExamples state_dict and config from a .nemo file."""
+    try:
+        cmd = ['tar', 'Oxzf', fpath, './model_config.yaml']
+        nemo_cfg = yaml.safe_load(io.BytesIO(check_output(cmd)))
+
+        cmd = ['tar', 'Oxzf', fpath, './model_weights.ckpt']
+        ckpt = torch.load(io.BytesIO(check_output(cmd)), map_location="cpu")
+
+    except (FileNotFoundError, CalledProcessError):
+        print('WARNING: Could not uncompress with tar. '
+              'Falling back to the tarfile module (might take a few minutes).')
+        import tarfile
+        with tarfile.open(fpath, "r:gz") as tar:
+            f = tar.extractfile(tar.getmember("./model_config.yaml"))
+            nemo_cfg = yaml.safe_load(f)
+
+            f = tar.extractfile(tar.getmember("./model_weights.ckpt"))
+            ckpt = torch.load(f, map_location="cpu")
+
+    remap = lambda k: (k.replace("encoder.encoder", "encoder.layers")
+                       .replace("decoder.decoder_layers", "decoder.layers")
+                       .replace("conv.weight", "weight"))
+    dle_ckpt = {'state_dict': {remap(k): v for k, v in ckpt.items()
+                               if "preproc" not in k}}
+    dle_cfg = config_from_nemo(nemo_cfg)
+    return dle_ckpt, dle_cfg
+
+
+def save_nemo_ckpt(dle_ckpt, dle_cfg, dest_path):
+    """Save a DeepLearningExamples model as a .nemo file."""
+    cfg = deepcopy(dle_cfg)
+
+    dle_ckpt = torch.load(dle_ckpt, map_location="cpu")["ema_state_dict"]
+
+    # Build a DLE model instance and fill with weights
+    symbols = helpers.add_ctc_blank(cfg['labels'])
+    enc_kw = quartznet.config.encoder(cfg)
+    dec_kw = quartznet.config.decoder(cfg, n_classes=len(symbols))
+    model = QuartzNet(enc_kw, dec_kw)
+    model.load_state_dict(dle_ckpt, strict=True)
+
+    # Reaname core modules, e.g., encoder.layers -> encoder.encoder
+    model.encoder._modules['encoder'] = model.encoder._modules.pop('layers')
+    model.decoder._modules['decoder_layers'] = model.decoder._modules.pop('layers')
+
+    # MaskedConv1d is made via composition in NeMo, and via inheritance in DLE
+    # Params for MaskedConv1d in NeMo have an additional '.conv.' infix
+    def rename_convs(module):
+        for name in list(module._modules.keys()):
+            submod = module._modules[name]
+
+            if isinstance(submod, MaskedConv1d):
+                module._modules[f'{name}.conv'] = module._modules.pop(name)
+            else:
+                rename_convs(submod)
+
+    rename_convs(model.encoder.encoder)
+
+    # Use FilterbankFeatures to calculate fbanks and store with model weights
+    feature_processor = FilterbankFeatures(
+        **dle_cfg['input_val']['filterbank_features'])
+
+    nemo_ckpt = model.state_dict()
+    nemo_ckpt["preprocessor.featurizer.fb"] = feature_processor.fb
+    nemo_ckpt["preprocessor.featurizer.window"] = feature_processor.window
+
+    nemo_cfg = config_to_nemo(dle_cfg)
+
+    # Prepare the directory for zipping
+    ckpt_files = dest_path / "ckpt_files"
+    ckpt_files.mkdir(exist_ok=True, parents=False)
+    with open(ckpt_files / "model_config.yaml", "w") as f:
+        yaml.dump(nemo_cfg, f)
+    torch.save(nemo_ckpt, ckpt_files / "model_weights.ckpt")
+
+    with tarfile.open(dest_path / "quartznet.nemo", "w:gz") as tar:
+        tar.add(ckpt_files, arcname="./")
+
+
+def save_dle_ckpt(ckpt, cfg, dest_dir):
+    torch.save(ckpt, dest_dir / "model.pt")
+    with open(dest_dir / "model_config.yaml", "w") as f:
+        yaml.dump(cfg, f)
+
+
+def set_nested_item(tgt, src, tgt_keys, src_keys):
+    """Assigns nested dict keys, e.g., d1[a][b][c] = d2[e][f][g][h]."""
+    tgt_nested = reduce(lambda d, k: d[k], tgt_keys[:-1], tgt)
+    tgt_nested[tgt_keys[-1]] = reduce(lambda d, k: d[k], src_keys, src)
+
+
+def config_from_nemo(nemo_cfg):
+    """Convert a DeepLearningExamples config to a NeMo format."""
+    dle_cfg = {
+        'name': 'QuartzNet',
+        'input_val': {
+            'audio_dataset': {
+                'normalize_transcripts': True,
+            },
+            'filterbank_features': {
+                'pad_align': 16,
+            },
+        },
+        'quartznet': {
+            'decoder': {},
+            'encoder': {},
+        },
+    }
+
+    for dle_keys, nemo_keys in cfg_key_map.items():
+        try:
+            set_nested_item(dle_cfg, nemo_cfg, dle_keys, nemo_keys)
+        except KeyError:
+            print(f'WARNING: Could not load config {nemo_keys} as {dle_keys}.')
+
+    # mapping kernel_size is not expressable with cfg_map
+    for block in dle_cfg["quartznet"]["encoder"]["blocks"]:
+        block["kernel_size"] = block.pop("kernel")
+
+    return dle_cfg
+
+
+def config_to_nemo(dle_cfg):
+    """Convert a DeepLearningExamples config to a NeMo format."""
+    nemo_cfg = {
+        "target": "nemo.collections.asr.models.ctc_models.EncDecCTCModel",
+        "dropout": 0.0,
+        "preprocessor": {
+            "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
+            "stft_conv": False,
+        },
+        "encoder": {
+            "_target_": "nemo.collections.asr.modules.ConvASREncoder",
+            "jasper": {}
+        },
+        "decoder": {
+          "_target_": "nemo.collections.asr.modules.ConvASRDecoder",
+        },
+    }
+
+    for dle_keys, nemo_keys in cfg_key_map.items():
+        try:
+            set_nested_item(nemo_cfg, dle_cfg, nemo_keys, dle_keys)
+        except KeyError:
+            print(f"WARNING: Could not load config {dle_keys} as {nemo_keys}.")
+
+    nemo_cfg["sample_rate"] = nemo_cfg["preprocessor"]["sample_rate"]
+    nemo_cfg["repeat"] = nemo_cfg["encoder"]["jasper"][1]["repeat"]
+    nemo_cfg["separable"] = nemo_cfg["encoder"]["jasper"][1]["separable"]
+    nemo_cfg["labels"] = nemo_cfg["decoder"]["vocabulary"]
+    nemo_cfg["decoder"]["num_classes"] = len(nemo_cfg["decoder"]["vocabulary"])
+
+    # mapping kernel_size is not expressable with cfg_map
+    for block in nemo_cfg["encoder"]["jasper"]:
+        if "kernel_size" in block:
+            block["kernel"] = block.pop("kernel_size")
+
+    return nemo_cfg
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="QuartzNet DLE <-> NeMo model converter.")
+    parser.add_argument("source_model", type=Path,
+                        help="A DLE or NeMo QuartzNet model to be converted (.pt or .nemo, respectively)")
+    parser.add_argument("dest_dir", type=Path, help="Destination directory")
+    parser.add_argument("--dle_config_yaml", type=Path,
+                        help="A DLE config .yaml file, required only to convert DLE -> NeMo")
+    args = parser.parse_args()
+
+    ext = args.source_model.suffix.lower()
+    if ext == ".nemo":
+        ckpt, cfg = load_nemo_ckpt(args.source_model)
+        save_dle_ckpt(ckpt, cfg, args.dest_dir)
+
+    elif ext == ".pt":
+        dle_cfg = load_yaml(args.dle_config_yaml)
+        save_nemo_ckpt(args.source_model, dle_cfg, args.dest_dir)
+
+    else:
+        raise ValueError(f"Unknown extension {ext}.")
+
+    print('Converted succesfully.')

+ 21 - 6
PyTorch/SpeechRecognition/QuartzNet/scripts/download_quartznet.sh

@@ -2,22 +2,37 @@
 
 set -e
 
-: ${MODEL_DIR:="pretrained_models/quartznet"}
-MODEL_ZIP="quartznet_pyt_ckpt_amp_21.03.0.zip"
-MODEL="nvidia_quartznet_210504.pt"
-MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/quartznet_pyt_ckpt_amp/versions/21.03.0/zip"
+: ${LANGUAGE:=${1:-en}}
+: ${MODEL_DIR:="pretrained_models/quartznet_${LANGUAGE}"}
+
+case $LANGUAGE in
+  en)
+    MODEL="nvidia_quartznet_210504.pt"
+    MODEL_ZIP="quartznet_pyt_ckpt_amp_21.03.0.zip"
+    MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/quartznet_pyt_ckpt_amp/versions/21.03.0/zip"
+    ;;
+  ca|de|es|fr|it|pl|ru|zh)
+    MODEL="stt_${LANGUAGE}_quartznet15x5.nemo"
+    MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_${LANGUAGE}_quartznet15x5/versions/1.0.0rc1/zip"
+    MODEL_ZIP="stt_${LANGUAGE}_quartznet15x5_1.0.0rc1.zip"
+    ;;
+  *)
+    echo "Unsupported language $LANGUAGE"
+    exit 1
+    ;;
+esac
 
 mkdir -p "$MODEL_DIR"
 
 if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
   echo "Downloading ${MODEL_ZIP} ..."
-  wget -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
+  wget -O ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
        || { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
 fi
 
 if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
   echo "Extracting ${MODEL} ..."
-  unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
+  unzip -o ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
         || { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }
 
   echo "OK"

+ 1 - 1
PyTorch/SpeechRecognition/QuartzNet/scripts/inference.sh

@@ -17,7 +17,7 @@
 : ${DATA_DIR:=${1:-"/datasets/LibriSpeech"}}
 : ${MODEL_CONFIG:=${2:-"configs/quartznet15x5_speedp-online-1.15_speca.yaml"}}
 : ${OUTPUT_DIR:=${3:-"/results"}}
-: ${CHECKPOINT:=${4:-"pretrained_models/quartznet/nvidia_quartznet_210504.pt"}}
+: ${CHECKPOINT:=${4:-"pretrained_models/quartznet_en/nvidia_quartznet_210504.pt"}}
 : ${DATASET:="test-other"}
 : ${LOG_FILE:=""}
 : ${CUDNN_BENCHMARK:=false}