View source code

[FastPitch/PyT] Add mixed English and Mandarin bilingual support

Myron Du 3 years ago
Parent
Commit
c2bb3fea79
22 changed files with 1045 additions and 73 deletions
  1. +174 -50
      PyTorch/SpeechSynthesis/FastPitch/README.md
  2. BIN
      PyTorch/SpeechSynthesis/FastPitch/audio/com_SF_ce1514_fastpitch_waveglow.wav
  3. +13 -1
      PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py
  4. +10 -0
      PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py
  5. +81 -0
      PyTorch/SpeechSynthesis/FastPitch/common/text/zh/chinese.py
  6. +74 -0
      PyTorch/SpeechSynthesis/FastPitch/common/text/zh/mandarin_text_processing.py
  7. +412 -0
      PyTorch/SpeechSynthesis/FastPitch/common/text/zh/pinyin_dict.txt
  8. +21 -0
      PyTorch/SpeechSynthesis/FastPitch/common/utils.py
  9. +10 -2
      PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py
  10. +2 -2
      PyTorch/SpeechSynthesis/FastPitch/inference.py
  11. +20 -0
      PyTorch/SpeechSynthesis/FastPitch/phrases/phrase_bilingual.txt
  12. +7 -1
      PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py
  13. +1 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/inference_benchmark.sh
  14. +2 -3
      PyTorch/SpeechSynthesis/FastPitch/scripts/inference_example.sh
  15. +5 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/README.md
  16. +18 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/inference.sh
  17. +57 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/prepare_dataset.sh
  18. +1 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/requirements.txt
  19. +94 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/split_sf.py
  20. +21 -0
      PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/train.sh
  21. +13 -12
      PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh
  22. +9 -2
      PyTorch/SpeechSynthesis/FastPitch/train.py

+ 174 - 50
PyTorch/SpeechSynthesis/FastPitch/README.md

@@ -25,6 +25,7 @@ This repository provides a script and recipe to train the FastPitch model to ach
         * [Multi-dataset](#multi-dataset)
     * [Training process](#training-process)
     * [Inference process](#inference-process)
+    * [Example: Training a model on Mandarin Chinese](#example-training-a-model-on-mandarin-chinese)
 - [Performance](#performance)
     * [Benchmarking](#benchmarking)
         * [Training performance benchmark](#training-performance-benchmark)
@@ -50,22 +51,22 @@ This repository provides a script and recipe to train the FastPitch model to ach
 [FastPitch](https://arxiv.org/abs/2006.06873) is one of two major components in a neural, text-to-speech (TTS) system:
 
 * a mel-spectrogram generator such as [FastPitch](https://arxiv.org/abs/2006.06873) or [Tacotron 2](https://arxiv.org/abs/1712.05884), and
-* a waveform synthesizer such as [WaveGlow](https://arxiv.org/abs/1811.00002) (see [NVIDIA example code](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2)).
+* a waveform synthesizer such as [WaveGlow](https://arxiv.org/abs/1811.00002) (refer to [NVIDIA example code](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2)).
 
-Such two-component TTS system is able to synthesize natural sounding speech from raw transcripts.
+Such a two-component TTS system is able to synthesize natural-sounding speech from raw transcripts.
 
 The FastPitch model generates mel-spectrograms and predicts a pitch contour from raw input text.
 In version 1.1, it does not need any pre-trained aligning model to bootstrap from.
-It allows to exert additional control over the synthesized utterances, such as:
+It allows exerting additional control over the synthesized utterances, such as:
 * modify the pitch contour to control the prosody,
-* increase or decrease the fundamental frequency in a naturally sounding way, that preserves the perceived identity of the speaker,
+* increase or decrease the fundamental frequency in a natural-sounding way that preserves the perceived identity of the speaker,
 * alter the rate of speech,
 * adjust the energy,
 * specify input as graphemes or phonemes,
 * switch speakers when the model has been trained with data from multiple speakers.
 Some of the capabilities of FastPitch are presented on the website with [samples](https://fastpitch.github.io/).
 
-Speech synthesized with FastPitch has state-of-the-art quality, and does not suffer from missing/repeating phrases like Tacotron 2 does.
+Speech synthesized with FastPitch has state-of-the-art quality, and does not suffer from missing/repeating phrases as Tacotron 2 does.
 This is reflected in Mean Opinion Scores ([details](https://arxiv.org/abs/2006.06873)).
 
 | Model          | Mean Opinion Score (MOS) |
@@ -93,7 +94,7 @@ The FastPitch model is similar to [FastSpeech2](https://arxiv.org/abs/2006.04558
 FastPitch is trained on a publicly
 available [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
 
-This model is trained with mixed precision using Tensor Cores on Volta, Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results from 2.0x to 2.7x faster than training without Tensor Cores, while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
+This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results from 2.0x to 2.7x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
 
 ### Model architecture
 
@@ -105,14 +106,14 @@ from raw text (Figure 1). The entire process is parallel, which means that all i
 </p>
 <p align="center">
  <em>Figure 1. Architecture of FastPitch (<a href="https://arxiv.org/abs/2006.06873">source</a>). The model is composed of a bidirectional Transformer backbone (also known as a Transformer encoder), a pitch predictor, and a duration predictor. After passing through the first *N* Transformer blocks, encoding, the signal is augmented with pitch information and discretely upsampled. Then it goes through another set of *N* Transformer blocks, with the goal of
-smoothing out the upsampled signal, and constructing a mel-spectrogram.
+smoothing out the upsampled signal and constructing a mel-spectrogram.
   </em>
 </p>
 
 ### Default configuration
 
 The FastPitch model supports multi-GPU and mixed precision training with dynamic loss
-scaling (see Apex code
+scaling (refer to Apex code
 [here](https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py)),
 as well as mixed precision inference.
 
@@ -123,9 +124,9 @@ The following features were implemented in this model:
 training,
 * gradient accumulation for reproducible results regardless of the number of GPUs.
 
-Pitch contours and mel-spectrograms can be generated on-line during training.
+Pitch contours and mel-spectrograms can be generated online during training.
 To speed-up training, those could be generated during the pre-processing step and read
-directly from the disk during training. For more information on data pre-processing refer to [Dataset guidelines
+directly from the disk during training. For more information on data pre-processing, refer to [Dataset guidelines
 ](#dataset-guidelines) and the [paper](https://arxiv.org/abs/2006.06873).
 
 ### Feature support matrix
@@ -144,21 +145,21 @@ implementation of mixed precision training. It allows us to use FP16 training
 with FP32 master weights by modifying just a few lines of code.
 
 DistributedDataParallel (DDP) - The model uses PyTorch Lightning implementation
-of distributed data parallelism at the module level which can run across
+of distributed data parallelism at the module level, which can run across
 multiple machines.
 
 ### Mixed precision training
 
-Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
+Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the Turing and Ampere architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using mixed precision training requires two steps:
 1.  Porting the model to use the FP16 data type where appropriate.
 2.  Adding loss scaling to preserve small gradient values.
 
 The ability to train deep learning networks with lower precision was introduced in the Pascal architecture and first supported in [CUDA 8](https://devblogs.nvidia.com/parallelforall/tag/fp16/) in the NVIDIA Deep Learning SDK.
 
 For information about:
--   How to train using mixed precision, see the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
--   Techniques used for mixed precision training, see the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
--   APEX tools for mixed precision training, see the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
+-   How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
+-   Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
+-   APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
 
 #### Enabling mixed precision
 
@@ -167,9 +168,9 @@ Mixed precision is using [native PyTorch implementation](https://pytorch.org/blo
 
 #### Enabling TF32
 
-TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs.
+TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on Volta GPUs.
 
-TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require high dynamic range for weights or activations.
+TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models which require a high dynamic range for weights or activations.
 
 For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
 
@@ -178,10 +179,10 @@ TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by defaul
 ### Glossary
 
 **Character duration**
-The time during which a character is being articulated. Could be measured in milliseconds, mel-spectrogram frames, etc. Some characters are not pronounced, and thus have 0 duration.
+The time during which a character is being articulated. It could be measured in milliseconds, mel-spectrogram frames, and so on. Some characters are not pronounced, and thus, have 0 duration.
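+
+For illustration, discrete upsampling by character durations (as in Figure 1) amounts to repeating each symbol's encoding for its duration in mel frames; the symbols and durations below are made-up values, not taken from the model:
+
+```python
+# Each encoded symbol is repeated 'duration' times; symbols with
+# duration 0 (e.g. silent letters) contribute no mel frames.
+symbols = ["h", "e", "l", "l", "o"]
+durations = [3, 4, 2, 0, 5]  # in mel-spectrogram frames (hypothetical)
+
+upsampled = [s for s, d in zip(symbols, durations) for _ in range(d)]
+num_frames = len(upsampled)  # equals sum(durations)
+```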
 
 **Fundamental frequency**
-The lowest vibration frequency of a periodic soundwave, for example, produced by a vibrating instrument. It is perceived as the loudest. In the context of speech, it refers to the frequency of vibration of vocal chords.  Abbreviated as *f0*.
+The lowest vibration frequency of a periodic soundwave, for example, one produced by a vibrating instrument; it is perceived as the loudest. In the context of speech, it refers to the frequency of vibration of vocal cords. It is abbreviated as *f0*.
 
 **Pitch**
 A perceived frequency of vibration of music or sound.
@@ -195,7 +196,7 @@ The following section lists the requirements that you need to meet in order to s
 
 ### Requirements
 
-This repository contains Dockerfile which extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
+This repository contains a Dockerfile that extends the PyTorch NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 -   [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
 -   [PyTorch 22.08-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
 or newer
@@ -205,16 +206,16 @@ or newer
     - [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
 
 
-For more information about how to get started with NGC containers, see the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
+For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
 -   [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
 -   [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
 -   [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
 
-For those unable to use the PyTorch NGC container, to set up the required environment or create your own container, see the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
+For those unable to use the PyTorch NGC container, to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
 
 ## Quick Start Guide
 
-To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using the default parameters of the FastPitch model on the LJSpeech 1.1 dataset. For the specifics concerning training and inference, see the [Advanced](#advanced) section. Pre-trained FastPitch models are available for download on [NGC](https://ngc.nvidia.com/catalog/models?query=FastPitch&quickFilter=models).
+To train your model using mixed or TF32 precision with Tensor Cores or using FP32, perform the following steps using the default parameters of the FastPitch model on the LJSpeech 1.1 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section. Pre-trained FastPitch models are available for download on [NGC](https://ngc.nvidia.com/catalog/models?query=FastPitch&quickFilter=models).
 
 1. Clone the repository.
    ```bash
@@ -224,7 +225,7 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
 
 2. Build and run the FastPitch PyTorch NGC container.
 
-   By default the container will use all available GPUs.
+   By default, the container will use all available GPUs.
    ```bash
    bash scripts/docker/build.sh
    bash scripts/docker/interactive.sh
@@ -232,20 +233,20 @@ To train your model using mixed or TF32 precision with Tensor Cores or using FP3
 
 3. Download and preprocess the dataset.
 
-   Use the scripts to automatically download and preprocess the training, validation and test datasets:
+   Use the scripts to automatically download and preprocess the training, validation, and test datasets:
    ```bash
    bash scripts/download_dataset.sh
    bash scripts/prepare_dataset.sh
    ```
 
-   The data is downloaded to the `./LJSpeech-1.1` directory (on the host).  The
+   The data is downloaded to the `./LJSpeech-1.1` directory (on the host). The
    `./LJSpeech-1.1` directory is mounted under the `/workspace/fastpitch/LJSpeech-1.1`
    location in the NGC container. The complete dataset has the following structure:
    ```bash
    ./LJSpeech-1.1
-   ├── mels             # (optional) Pre-calculated target mel-spectrograms; may be calculated on-line
+   ├── mels             # (optional) Pre-calculated target mel-spectrograms; can be calculated online
    ├── metadata.csv     # Mapping of waveforms to utterances
-   ├── pitch            # Fundamental frequency countours for input utterances; may be calculated on-line
+   ├── pitch            # Fundamental frequency contours for input utterances; can be calculated online
    ├── README
    └── wavs             # Raw waveforms
    ```
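
   As a quick illustration (assuming the standard LJSpeech layout), each line of `metadata.csv` is a `|`-separated record mapping a waveform ID to its raw and normalized transcript; the example line below is abbreviated:

   ```python
   # Parse one LJSpeech-style metadata.csv record. The fields are the
   # wav file ID, the raw transcript, and the normalized transcript.
   line = "LJ001-0001|Printing, in the only sense|Printing, in the only sense"
   wav_id, text, normalized_text = line.split("|")
   wav_path = f"wavs/{wav_id}.wav"
   ```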
@@ -309,10 +310,10 @@ given model
 * `<model_name>/loss_function.py` - loss function for the model
 
 In the root directory `./` of this repository, the `./train.py` script is used for
-training while inference can be executed with the `./inference.py` script. The
-script `./models.py` is used to construct a model of requested type and properties.
+training, while inference can be executed with the `./inference.py` script. The
+script `./models.py` is used to construct a model of the requested type and properties.
 
-The repository is structured similarly to the [NVIDIA Tacotron2 Deep Learning example](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2), so that they could be combined in more advanced use cases.
+The repository is structured similarly to the [NVIDIA Tacotron2 Deep Learning example](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2) so that they could be combined in more advanced use cases.
 
 ### Parameters
 
@@ -330,8 +331,8 @@ together with their default values that are used to train FastPitch.
 
 ### Command-line options
 
-To see the full list of available options and their descriptions, use the `-h`
-or `--help` command line option, for example:
+To review the full list of available options and their descriptions, use the `-h`
+or `--help` command-line option, for example:
 ```bash
 python train.py --help
 ```
@@ -351,7 +352,7 @@ The `./scripts/download_dataset.sh` script will automatically download and extra
 
 #### Dataset guidelines
 
-The LJSpeech dataset has 13,100 clips that amount to about 24 hours of speech of a single, female speaker. Since the original dataset does not define a train/dev/test split of the data, we provide a split in the form of three file lists:
+The LJSpeech dataset has 13,100 clips that amount to about 24 hours of speech of a single female speaker. Since the original dataset does not define a train/dev/test split of the data, we provide a split in the form of three file lists:
 ```bash
 ./filelists
 ├── ljs_audio_pitch_text_train_v3.txt
@@ -359,10 +360,10 @@ The LJSpeech dataset has 13,100 clips that amount to about 24 hours of speech of
 └── ljs_audio_pitch_text_val.txt
 ```
 
-FastPitch predicts character durations just like [FastSpeech](https://arxiv.org/abs/1905.09263) does.
+FastPitch predicts character durations just as [FastSpeech](https://arxiv.org/abs/1905.09263) does.
 FastPitch 1.1 aligns input symbols to output mel-spectrogram frames automatically and does not rely
 on any external aligning model. FastPitch training can now be started on raw waveforms
-without any pre-processing: pitch values and mel-spectrograms will be calculated on-line.
+without any pre-processing: pitch values and mel-spectrograms will be calculated online.
 
 For every mel-spectrogram frame, its fundamental frequency in Hz is estimated with
 the Probabilistic YIN algorithm.
@@ -371,8 +372,8 @@ the Probabilistic YIN algorithm.
   <img src="./img/pitch.png" alt="Pitch contour estimate" />
 </p>
 <p align="center">
-  <em>Figure 2. Pitch estimates for mel-spectrogram frames of phrase "in being comparatively"
-(in blue) averaged over characters (in red). Silent letters have duration 0 and are omitted.</em>
+  <em>Figure 2. Pitch estimates for mel-spectrogram frames of the phrase "in being comparatively"
+(in blue) averaged over characters (in red). Silent letters have a duration of 0 and are omitted.</em>
 </p>
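
For intuition, a much-simplified autocorrelation-based f0 estimate over a single frame can be sketched as follows. The repository itself uses the more robust probabilistic YIN algorithm; the frame size and search range here are illustrative assumptions:

```python
import math

SAMPLE_RATE = 22050

def estimate_f0(frame, fmin=65.0, fmax=525.0):
    # Pick the lag that maximizes the autocorrelation within a plausible
    # pitch range; f0 is then the sample rate divided by that lag.
    lo, hi = int(SAMPLE_RATE / fmax), int(SAMPLE_RATE / fmin)
    best_lag = max(
        range(lo, hi + 1),
        key=lambda lag: sum(frame[n] * frame[n + lag]
                            for n in range(len(frame) - lag)),
    )
    return SAMPLE_RATE / best_lag

# A pure 220 Hz tone should yield an estimate close to 220 Hz.
frame = [math.sin(2 * math.pi * 220 * n / SAMPLE_RATE) for n in range(1024)]
f0 = estimate_f0(frame)
```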
 
 #### Multi-dataset
@@ -385,7 +386,7 @@ Follow these steps to use datasets different from the default LJSpeech dataset.
    └── wavs
    ```
 
-2. Prepare filelists with transcripts and paths to .wav files. They define training/validation split of the data (test is currently unused):
+2. Prepare filelists with transcripts and paths to .wav files. They define the training/validation split of the data (the test split is currently unused):
    ```bash
    ./filelists
    ├── my-dataset_audio_text_train.txt
@@ -424,7 +425,7 @@ In order to use the prepared dataset, pass the following to the `train.py` scrip
 
 ### Training process
 
-FastPitch is trained to generate mel-spectrograms from raw text input. It uses short time Fourier transform (STFT)
+FastPitch is trained to generate mel-spectrograms from raw text input. It uses short-time Fourier transform (STFT)
 to generate target mel-spectrograms from audio waveforms to be the training targets.
 
 The training loss is averaged over an entire training epoch, whereas the
@@ -478,9 +479,132 @@ Pitch can be adjusted by transforming those pitch cues. A few simple examples ar
 
 The flags can be combined. Modify these functions directly in the `inference.py` script to gain more control over the final result.
 
-You can find all the available options by calling `python inference.py --help`.
+You can find all the available options by calling `python inference.py --help`.
 More examples are presented on the website with [samples](https://fastpitch.github.io/).
 
+### Example: Training a model on Mandarin Chinese
+
+FastPitch can easily be trained or fine-tuned on datasets in various languages.
+We present an example of training on the Mandarin Chinese dataset capable of pronouncing
+phrases in English (for example, brand names).
+For an overview of the deployment of this model in Chunghwa Telecom,
+refer to the [blog post](https://blogs.nvidia.com.tw/2022/06/20/cht-bilingual-speech-synthesis-enables-more-realistic-interactions/) (in Chinese).
+
+
+1. Set up the repository and run a Docker container
+
+    Follow steps 1 and 2 of the [Quick Start Guide](#quick-start-guide).
+
+2. Download the data
+
+   The dataset for this section has been provided by Chunghwa Telecom Laboratories
+   and is available for [download on NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/sf_bilingual_speech_zh_en)
+   under the CC BY-NC 4.0 license.
+
+   After signing in to NGC, the dataset can be downloaded as `files.zip` or `SF_bilingual.zip`, depending on the download method (manual or via the command line).
+   Afterward, it has to be pre-processed to extract pitch for training and prepare train/dev/test filelists:
+   ```bash
+   pip install -r scripts/mandarin_chinese/requirements.txt
+   bash scripts/mandarin_chinese/prepare_dataset.sh path/to/files.zip
+   ```
+
+   The procedure should take about half an hour. If it completes successfully,
+   `./data/SF_bilingual prepared successfully.` will be written to the standard output.
+
+   After pre-processing, the dataset will be located at `./data/SF_bilingual`,
+   and training/inference filelists at `./filelists/sf_*`.
+
+3. Add support for textual inputs in the target language.
+
+   The model is trained end-to-end, and supporting a new language requires
+   specifying the input `symbol set`, `text normalization` routines,
+   and (optionally) grapheme-to-phoneme (G2P) conversion for phoneme-based synthesis.
+   Our main modifications touch the following files:
+
+   ```bash
+   ./common/text
+   ├── symbols.py
+   ├── text_processing.py
+   └── zh
+       ├── chinese.py
+       ├── mandarin_text_processing.py
+       └── pinyin_dict.txt
+   ```
+   We make small changes to `symbols.py` and `text_processing.py` and keep
+   the crucial code in the `zh` directory.
+
+   We design our Mandarin Chinese symbol set as an extension of the English
+   symbol set, appending to `symbols` lists of `_mandarin_phonemes` and `_chinese_punctuation`:
+
+   ```python
+   # common/text/symbols.py
+
+   def get_symbols(symbol_set='english_basic'):
+
+       # ...
+
+       elif symbol_set == 'english_mandarin_basic':
+           from .zh.chinese import chinese_punctuations, valid_symbols as mandarin_valid_symbols
+
+           # Prepend "#" to mandarin phonemes to ensure uniqueness (some are the same as uppercase letters):
+           _mandarin_phonemes = ['#' + s for s in mandarin_valid_symbols]
+
+           _pad = '_'
+           _punctuation = '!\'(),.:;? '
+           _chinese_punctuation = ["#" + p for p in chinese_punctuations]
+           _special = '-'
+           _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+           symbols = list(_pad + _special + _punctuation + _letters) + _arpabet + _mandarin_phonemes + _chinese_punctuation
+   ```
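+
+   The reason for the `"#"` prefix can be checked directly. The phoneme names below are illustrative stand-ins, assuming (as the comment above states) that some Mandarin phonemes collide with English uppercase letters:
+
+   ```python
+   # Hypothetical subset of Mandarin phonemes; 'B' and 'D' would collide
+   # with the English letter symbols without a distinguishing prefix.
+   mandarin_valid_symbols = ["B", "D", "ING1", "ANG4"]
+   _letters = list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
+
+   unprefixed = _letters + mandarin_valid_symbols
+   prefixed = _letters + ["#" + s for s in mandarin_valid_symbols]
+
+   has_collision = len(set(unprefixed)) < len(unprefixed)  # 'B', 'D' repeat
+   is_unique = len(set(prefixed)) == len(prefixed)         # unique with prefix
+   ```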
+
+   Text normalization and G2P are performed by a `TextProcessing` instance. We implement Mandarin text processing
+   inside a `MandarinTextProcessing` class. For G2P, an off-the-shelf [pypinyin](https://github.com/mozillazg/python-pinyin) phonemizer and [the CMU Dictionary](http://www.speech.cs.cmu.edu/cgi-bin/cmudict) are used.
+   `MandarinTextProcessing` is applied to the data only if the `english_mandarin_basic` symbol set is in use:
+
+   ```python
+   # common/text/text_processing.py
+
+   def get_text_processing(symbol_set, text_cleaners, p_arpabet):
+       if symbol_set in ['english_basic', 'english_basic_lowercase', 'english_expanded']:
+           return TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+       elif symbol_set == 'english_mandarin_basic':
+           from common.text.zh.mandarin_text_processing import MandarinTextProcessing
+           return MandarinTextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+   ```
+
+   Note that text normalization is dependent on the target language, domain, and assumptions
+   on how normalized the input already is.
+
+4. Train the model
+
+   The `SF dataset` is rather small (4.5 h compared to 24 h in `LJSpeech-1.1`).
+   There are numerous English phrases in the transcriptions, such as technical terms
+   and proper nouns. Thus, it is beneficial to initialize model weights with
+   a pre-trained English model from NGC, using the flag `--init-from-checkpoint`.
+
+   Note that by initializing with another model, possibly trained on a different symbol set,
+   we also initialize grapheme/phoneme embedding tables. For this reason, we design
+   the `english_mandarin_basic` symbol set as an extension of `english_basic`,
+   so that the same English phonemes would retain their embeddings.
+
+   To train, run
+   ```bash
+   NUM_GPUS=<available_gpus> GRAD_ACCUMULATION=<number> bash scripts/mandarin_chinese/train.sh
+   ```
+   Adjust the variables to satisfy `$NUM_GPUS x $GRAD_ACCUMULATION = 256`.
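+
+   The constraint keeps the effective global batch size fixed; a small sketch of the arithmetic (the value 256 comes from the constraint above, the function name is illustrative):
+
+   ```python
+   # NUM_GPUS x GRAD_ACCUMULATION must equal 256, so the accumulation
+   # steps are derived from the number of available GPUs.
+   GLOBAL_BATCH_FACTOR = 256
+
+   def grad_accumulation(num_gpus):
+       assert GLOBAL_BATCH_FACTOR % num_gpus == 0, "num_gpus must divide 256"
+       return GLOBAL_BATCH_FACTOR // num_gpus
+   ```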
+
+   The model will be trained for 1000 epochs. Note that we have disabled mixed-precision
+   training, as we found it unstable at times on this dataset.
+
+5. Synthesize
+
+   After training, samples can be synthesized ([audio sample](./audio/com_SF_ce1514_fastpitch_waveglow.wav)):
+   ```bash
+   bash scripts/mandarin_chinese/inference.sh
+   ```
+   Paths to specific checkpoints can be supplied as env variables or changed
+   directly in the `.sh` files.
+
 ## Performance
 
 ### Benchmarking
@@ -508,7 +632,7 @@ To benchmark the training performance on a specific batch size, run:
         AMP=false NUM_GPUS=8 BS=16 GRAD_ACCUMULATION=2 EPOCHS=10 bash scripts/train.sh
     ```
 
-Each of these scripts runs for 10 epochs and for each epoch measures the
+Each of these scripts runs for 10 epochs, and for each epoch, measures the
 average number of items per second. The performance results can be read from
 the `nvlog.json` files produced by the commands.
 
@@ -529,7 +653,7 @@ To benchmark the inference performance on a specific batch size, run:
 The output log files will contain performance numbers for the FastPitch model
 (number of output mel-spectrogram frames per second, reported as `generator_frames/s w
 `)
-and for WaveGlow (number of output samples per second, reported as ` waveglow_samples/s
+and for WaveGlow (number of output samples per second, reported as ` waveglow_samples/s
 `).
 The `inference.py` script will run a few warm-up iterations before running the benchmark. Inference will be averaged over 100 runs, as set by the `REPEATS` env variable.
 
@@ -546,8 +670,8 @@ Our results were obtained by running the `./platform/DGXA100_FastPitch_{AMP,TF32
 
 | Loss (Model/Epoch)   |    50 |   250 |   500 |   750 |  1000 |  1250 |  1500 |
 |:---------------------|------:|------:|------:|------:|------:|------:|------:|
-| FastPitch AMP        | 3.35 |  2.89 |  2.79 |  2.71 |   2.68 |   2.64 |   2.61 |
-| FastPitch TF32       | 3.37 |  2.88 |  2.78 |  2.71 |   2.68 |   2.63 |   2.61 |
+| FastPitch AMP        | 3.35 |  2.89 |  2.79 |  2.71 |   2.68 |  2.64 |  2.61 |
+| FastPitch TF32       | 3.37 |  2.88 |  2.78 |  2.71 |   2.68 |  2.63 |  2.61 |
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
 
@@ -558,8 +682,8 @@ All of the results were produced using the `train.py` script as described in the
 
 | Loss (Model/Epoch)   |    50 |   250 |   500 |   750 |  1000 |  1250 |  1500 |
 |:---------------------|------:|------:|------:|------:|------:|------:|------:|
-| FastPitch AMP        | 3.38 |  2.88 |  2.79 |  2.71 |   2.68 |   2.64 |   2.61 |
-| FastPitch FP32       | 3.38 |  2.89 |  2.80 |  2.71 |   2.68 |   2.65 |   2.62 |
+| FastPitch AMP        | 3.38 |  2.88 |  2.79 |  2.71 |   2.68 |  2.64 |  2.61 |
+| FastPitch FP32       | 3.38 |  2.89 |  2.80 |  2.71 |   2.68 |  2.65 |  2.62 |
 
 
 <div style="text-align:center" align="center">
@@ -621,7 +745,7 @@ Note that most of the quality is achieved after the initial 1000 epochs.
 The following tables show inference statistics for the FastPitch and WaveGlow
 text-to-speech system, gathered from 100 inference runs. Latency is measured from the start of FastPitch inference to
 the end of WaveGlow inference. Throughput is measured
-as the number of generated audio samples per second at 22KHz. RTF is the real-time factor which denotes the number of seconds of speech generated in a second of wall-clock time, per input utterance.
+as the number of generated audio samples per second at 22KHz. RTF is the real-time factor that denotes the number of seconds of speech generated in a second of wall-clock time per input utterance.
 The used WaveGlow model is a 256-channel model.
 
 Note that performance numbers are related to the length of input. The numbers reported below were taken with a moderate length of 128 characters. Longer utterances yield higher RTF, as the generator is fully parallel.
@@ -734,7 +858,7 @@ FastPitch + WaveGlow (TorchScript, denoising)
 
 ## Release notes
 
-We're constantly refining and improving our performance on AI and HPC workloads even on the same hardware with frequent updates to our software stack. For our latest performance data please refer to these pages for AI and HPC benchmarks.
+We're constantly refining and improving our performance on AI and HPC workloads even on the same hardware, with frequent updates to our software stack. For our latest performance data, refer to these pages for AI and HPC benchmarks.
 
 ### Changelog
 
@@ -769,4 +893,4 @@ May 2020
 
 ### Known issues
 
-There are no known issues with this model with this model.
+There are no known issues with this model.

BIN
PyTorch/SpeechSynthesis/FastPitch/audio/com_SF_ce1514_fastpitch_waveglow.wav


+ 13 - 1
PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py

@@ -31,6 +31,18 @@ def get_symbols(symbol_set='english_basic'):
         _accented = 'áçéêëñöøćž'
         _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
         symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
+    elif symbol_set == 'english_mandarin_basic':
+        from .zh.chinese import chinese_punctuations, valid_symbols as mandarin_valid_symbols
+
+        # Prepend "#" to mandarin phonemes to ensure uniqueness (some are the same as uppercase letters):
+        _mandarin_phonemes = ['#' + s for s in mandarin_valid_symbols]
+
+        _pad = '_'
+        _punctuation = '!\'(),.:;? '
+        _chinese_punctuation = ["#" + p for p in chinese_punctuations]
+        _special = '-'
+        _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+        symbols = list(_pad + _special + _punctuation + _letters) + _arpabet + _mandarin_phonemes + _chinese_punctuation
     else:
         raise Exception("{} symbol set does not exist".format(symbol_set))
 
@@ -38,7 +50,7 @@ def get_symbols(symbol_set='english_basic'):
 
 
 def get_pad_idx(symbol_set='english_basic'):
-    if symbol_set in {'english_basic', 'english_basic_lowercase'}:
+    if symbol_set in {'english_basic', 'english_basic_lowercase', 'english_mandarin_basic'}:
         return 0
     else:
         raise Exception("{} symbol set not used yet".format(symbol_set))
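The `#` prefix above matters because several Mandarin phoneme names collide with single English letters already in the symbol table. A minimal sketch of the collision and its fix, using illustrative subsets rather than the full tables from `symbols.py` and `chinese.py`:

```python
# Illustrative subsets; the full tables live in common/text/symbols.py
# and common/text/zh/chinese.py.
letters = list('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
mandarin_phonemes = ['A', 'AI', 'AN', 'B', 'CH', 'SH', 'ZH']  # 'A', 'B' collide

unprefixed = letters + mandarin_phonemes
prefixed = letters + ['#' + s for s in mandarin_phonemes]

# Without the prefix, duplicate symbols would share one embedding ID.
assert len(set(unprefixed)) < len(unprefixed)
assert len(set(prefixed)) == len(prefixed)
```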

+ 10 - 0
PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py

@@ -162,3 +162,13 @@ class TextProcessing(object):
             return text_encoded, text_clean, text_arpabet
 
         return text_encoded
+
+
+def get_text_processing(symbol_set, text_cleaners, p_arpabet):
+    if symbol_set in ['english_basic', 'english_basic_lowercase', 'english_expanded']:
+        return TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+    elif symbol_set == 'english_mandarin_basic':
+        from common.text.zh.mandarin_text_processing import MandarinTextProcessing
+        return MandarinTextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+    else:
+        raise ValueError(f"Unknown symbol set: {symbol_set}")

+ 81 - 0
PyTorch/SpeechSynthesis/FastPitch/common/text/zh/chinese.py

@@ -0,0 +1,81 @@
+# *****************************************************************************
+#  Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
+#
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#      * Redistributions of source code must retain the above copyright
+#        notice, this list of conditions and the following disclaimer.
+#      * Redistributions in binary form must reproduce the above copyright
+#        notice, this list of conditions and the following disclaimer in the
+#        documentation and/or other materials provided with the distribution.
+#      * Neither the name of the NVIDIA CORPORATION nor the
+#        names of its contributors may be used to endorse or promote products
+#        derived from this software without specific prior written permission.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# *****************************************************************************
+
+import re
+
+from pypinyin import lazy_pinyin, Style
+
+
+valid_symbols = ['^', 'A', 'AI', 'AN', 'ANG', 'AO', 'B', 'C', 'CH', 'D', 
+                 'E', 'EI', 'EN', 'ENG', 'ER', 'F', 'G', 'H', 'I', 'IE', 
+                 'IN', 'ING', 'IU', 'J', 'K', 'L', 'M', 'N', 'O', 'ONG', 
+                 'OU', 'P', 'Q', 'R', 'S', 'SH', 'T', 'U', 'UI', 'UN', 
+                 'V', 'VE', 'VN', 'W', 'X', 'Y', 'Z', 'ZH']
+tones = ['1', '2', '3', '4', '5']
+chinese_punctuations = ",。?!;:、‘’“”()【】「」《》"
+valid_symbols += tones
+
+
+def load_pinyin_dict(path="common/text/zh/pinyin_dict.txt"):
+    with open(path) as f:
+        return {l.split()[0]: l.split()[1:] for l in f}
+
+pinyin_dict = load_pinyin_dict()
+
+
+def is_chinese(text):
+    return u'\u4e00' <= text[0] <= u'\u9fff' or text[0] in chinese_punctuations
+
+
+def split_text(text):
+    regex = r'([\u4e00-\u9fff' + chinese_punctuations + ']+)'
+    return re.split(regex, text)
+
+
+def chinese_text_to_symbols(text):
+    symbols = []
+    phonemes_and_tones = ""
+    
+    # Convert text to a Mandarin pinyin sequence.
+    # Polyphonic characters are not disambiguated, as this has little effect on training.
+    pinyin_seq = lazy_pinyin(text, style=Style.TONE3)
+    
+    for item in pinyin_seq:
+        if item in chinese_punctuations:
+            symbols += [item]
+            phonemes_and_tones += ' ' + item
+            continue
+        if not item[-1].isdigit():
+            item += '5'
+        item, tone = item[:-1], item[-1]
+        phonemes = pinyin_dict[item.upper()]
+        symbols += phonemes
+        symbols += [tone]
+        
+        phonemes_and_tones += '{' + ' '.join(phonemes + [tone]) + '}'
+    
+    return symbols, phonemes_and_tones
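`chinese_text_to_symbols` turns each TONE3-style pinyin syllable into an initial/final phoneme sequence plus a tone digit, with the neutral tone mapped to `5`. A standalone sketch of that per-syllable step, using a two-entry stand-in for `pinyin_dict` (the real table is `pinyin_dict.txt`, and the pinyin itself comes from `pypinyin`):

```python
# Stand-in for the full table loaded by load_pinyin_dict().
pinyin_dict = {'NI': ['N', 'I'], 'HAO': ['H', 'AO']}

def syllable_to_symbols(item):
    # A TONE3 syllable such as 'ni3'; a missing digit means neutral tone.
    if not item[-1].isdigit():
        item += '5'
    item, tone = item[:-1], item[-1]
    return pinyin_dict[item.upper()] + [tone]

symbols = []
for syl in ['ni3', 'hao3', 'hao']:
    symbols += syllable_to_symbols(syl)
# symbols == ['N', 'I', '3', 'H', 'AO', '3', 'H', 'AO', '5']
```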

+ 74 - 0
PyTorch/SpeechSynthesis/FastPitch/common/text/zh/mandarin_text_processing.py

@@ -0,0 +1,74 @@
+import re
+import numpy as np
+from .chinese import split_text, is_chinese, chinese_text_to_symbols
+from ..text_processing import TextProcessing
+
+
+class MandarinTextProcessing(TextProcessing):
+    def __init__(self, symbol_set, cleaner_names, p_arpabet=0.0,
+                 handle_arpabet='word', handle_arpabet_ambiguous='ignore',
+                 expand_currency=True):
+        
+        super().__init__(symbol_set, cleaner_names, p_arpabet, handle_arpabet, 
+                       handle_arpabet_ambiguous, expand_currency)
+
+
+    def sequence_to_text(self, sequence):
+        result = ''
+        
+        tmp = ''
+        for symbol_id in sequence:
+            if symbol_id in self.id_to_symbol:
+                s = self.id_to_symbol[symbol_id]
+                # Enclose ARPAbet and mandarin phonemes back in curly braces:
+                if len(s) > 1 and s[0] == '@':
+                    s = '{%s}' % s[1:]
+                    result += s
+                elif len(s) > 1 and s[0] == '#' and s[1].isdigit(): # mandarin tone
+                    tmp += s[1] + '} '
+                    result += tmp
+                    tmp = ''
+                elif len(s) > 1 and s[0] == '#' and (s[1].isalpha() or s[1] == '^'): # mandarin phoneme
+                    if tmp == '':
+                        tmp += ' {' + s[1:] + ' '
+                    else:
+                        tmp += s[1:] + ' '
+                elif len(s) > 1 and s[0] == '#': # chinese punctuation
+                    s = s[1]
+                    result += s
+                else:
+                    result += s
+                    
+        return result.replace('}{', ' ').replace('  ', ' ')
+
+    
+    def chinese_symbols_to_sequence(self, symbols):
+        return self.symbols_to_sequence(['#' + s for s in symbols])
+
+
+    def encode_text(self, text, return_all=False):
+        # split the text into English and Chinese segments
+        segments = [segment for segment in split_text(text) if segment != ""]
+        
+        text_encoded = []
+        text_clean = ""
+        text_arpabet = ""
+        
+        for segment in segments:
+            if is_chinese(segment[0]): # process the Chinese segment
+                chinese_symbols, segment_arpabet = chinese_text_to_symbols(segment)
+                segment_encoded = self.chinese_symbols_to_sequence(chinese_symbols)
+                segment_clean = segment
+            else: # process the English segment
+                segment_encoded, segment_clean, segment_arpabet = \
+                    super().encode_text(segment, return_all=True)
+            
+            text_encoded += segment_encoded
+            text_clean += segment_clean
+            text_arpabet += segment_arpabet
+
+        if return_all:
+            return text_encoded, text_clean, text_arpabet
+
+        return text_encoded
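`encode_text` relies on `split_text` to alternate between English and Chinese runs. Because the regex wraps the CJK character class in a capturing group, `re.split` keeps the Chinese segments in the result instead of discarding them. A self-contained sketch (punctuation set abbreviated here; the module uses the full `chinese_punctuations` string):

```python
import re

chinese_punctuations = ",。?!"  # abbreviated; see chinese.py for the full set
regex = r'([\u4e00-\u9fff' + chinese_punctuations + ']+)'

segments = [s for s in re.split(regex, 'nokia有跟facebook簽約。') if s != '']
# segments == ['nokia', '有跟', 'facebook', '簽約。']
```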

+ 412 - 0
PyTorch/SpeechSynthesis/FastPitch/common/text/zh/pinyin_dict.txt

@@ -0,0 +1,412 @@
+NIN	N IN
+FA	F A
+BAI	B AI
+YIN	Y IN
+DE	D E
+SHEN	SH EN
+TAN	T AN
+PAO	P AO
+WENG	W ENG
+LAN	L AN
+CHUAN	CH U AN
+SEI	S EI
+DANG	D ANG
+XUE	X VE
+YUAN	Y V AN
+HU	H U
+CUAN	C U AN
+BO	B O
+SHAI	SH AI
+CHUI	CH UI
+SHOU	SH OU
+QIU	Q IU
+SONG	S ONG
+KAI	K AI
+LING	L ING
+SUO	S U O
+ZHUAI	ZH U AI
+ZHEN	ZH EN
+GENG	G ENG
+YAN	Y AN
+CU	C U
+ZHUA	ZH U A
+MA	M A
+SOU	S OU
+GOU	G OU
+PU	P U
+GUA	G U A
+RONG	R ONG
+JIAN	J I AN
+FOU	F OU
+FO	F O
+ZHUAN	ZH U AN
+DIU	D IU
+TIAN	T I AN
+QUN	Q VN
+NE	N E
+LIN	L IN
+QIE	Q IE
+LANG	L ANG
+CAO	C AO
+PANG	P ANG
+GAN	G AN
+KUI	K UI
+ROU	R OU
+NING	N ING
+NOU	N OU
+CUI	C UI
+NA	N A
+MING	M ING
+JUAN	J V AN
+NIAN	N I AN
+JIONG	J I ONG
+LE	L E
+GEN	G EN
+CHUO	CH U O
+SANG	S ANG
+MANG	M ANG
+GANG	G ANG
+SHENG	SH ENG
+KENG	K ENG
+ANG	^ ANG
+ZHONG	ZH ONG
+PEI	P EI
+LO	L O
+BEN	B EN
+SAN	S AN
+WAI	W AI
+BA	B A
+ZEI	Z EI
+BANG	B ANG
+MENG	M ENG
+HA	H A
+SHAO	SH AO
+RENG	R ENG
+XUAN	X V AN
+GUAI	G U AI
+QUAN	Q V AN
+DIE	D IE
+CEN	C EN
+QIONG	Q I ONG
+QIAO	Q I AO
+NAN	N AN
+CANG	C ANG
+NANG	N ANG
+LA	L A
+KU	K U
+KAO	K AO
+XI	X I
+MO	M O
+CHAN	CH AN
+DUO	D U O
+DIAO	D I AO
+HUN	H UN
+LOU	L OU
+HANG	H ANG
+CENG	C ENG
+ZHI	ZH I
+RUAN	R U AN
+QIANG	Q I ANG
+MIU	M IU
+WO	W O
+GEI	G EI
+EI	^ EI
+CHAI	CH AI
+ZHUI	ZH UI
+CHU	CH U
+YONG	Y ONG
+SHUO	SH U O
+DING	D ING
+CHE	CH E
+YO	Y O
+PENG	P ENG
+RANG	R ANG
+BU	B U
+NIU	N IU
+KE	K E
+MI	M I
+GUAN	G U AN
+RE	R E
+NI	N I
+TI	T I
+DIA	D I A
+NUO	N U O
+WANG	W ANG
+QIAN	Q I AN
+LUO	L U O
+YA	Y A
+CI	C I
+GUN	G UN
+GAO	G AO
+DOU	D OU
+DAI	D AI
+BAO	B AO
+BIN	B IN
+NAI	N AI
+SE	S E
+PA	P A
+ZAO	Z AO
+AO	^ AO
+NIE	N IE
+BENG	B ENG
+ZHU	ZH U
+JU	J V
+XIU	X IU
+XIAN	X I AN
+RUI	R UI
+SAI	S AI
+SHUANG	SH U ANG
+SHUAI	SH U AI
+HEN	H EN
+OU	^ OU
+HUA	H U A
+LONG	L ONG
+ZI	Z I
+SHE	SH E
+JUN	J VN
+YE	Y E
+TUI	T UI
+GUANG	G U ANG
+MAN	M AN
+LAI	L AI
+ZHUN	ZH UN
+CHUANG	CH U ANG
+ZUI	Z UI
+SU	S U
+TE	T E
+TAO	T AO
+CONG	C ONG
+TONG	T ONG
+HENG	H ENG
+ZUO	Z U O
+LU	L U
+BAN	B AN
+PIAO	P I AO
+XIANG	X I ANG
+LIANG	L I ANG
+ZU	Z U
+NIANG	N I ANG
+LIU	L IU
+BIE	B IE
+CHA	CH A
+YANG	Y ANG
+LVE	L VE
+LENG	L ENG
+KOU	K OU
+AN	^ AN
+CHUN	CH UN
+ZAI	Z AI
+DONG	D ONG
+SHI	SH I
+CHAO	CH AO
+ZHAI	ZH AI
+RI	R I
+HUAI	H U AI
+TOU	T OU
+SENG	S ENG
+GUO	G U O
+NENG	N ENG
+ZUN	Z UN
+XIONG	X I ONG
+ZEN	Z EN
+TANG	T ANG
+BIAN	B I AN
+QU	Q V
+QI	Q I
+ZHAN	ZH AN
+JIAO	J I AO
+CHENG	CH ENG
+CHONG	CH ONG
+KEI	K EI
+MEI	M EI
+LV	L V
+SHUA	SH U A
+CA	C A
+DENG	D ENG
+TING	T ING
+YAO	Y AO
+TIAO	T I AO
+ME	M E
+CE	C E
+ZUAN	Z U AN
+SEN	S EN
+O	^ O
+ZENG	Z ENG
+RAO	R AO
+WEI	W EI
+KUAN	K U AN
+PING	P ING
+MAI	M AI
+HUAN	H U AN
+DEN	D EN
+BING	B ING
+QING	Q ING
+PIN	P IN
+GAI	G AI
+LI	L I
+ZHENG	ZH ENG
+ZAN	Z AN
+BEI	B EI
+SHU	SH U
+MU	M U
+KUO	K U O
+JIE	J IE
+CHUAI	CH U AI
+FAN	F AN
+PI	P I
+SHUI	SH UI
+YING	Y ING
+QIN	Q IN
+SHA	SH A
+KANG	K ANG
+CHEN	CH EN
+JIANG	J I ANG
+RAN	R AN
+LUAN	L U AN
+HEI	H EI
+XING	X ING
+WAN	W AN
+TA	T A
+XU	X V
+TENG	T ENG
+ZA	Z A
+KEN	K EN
+DAN	D AN
+TU	T U
+KUANG	K U ANG
+JING	J ING
+REN	R EN
+CHOU	CH OU
+KUA	K U A
+HE	H E
+DAO	D AO
+NEI	N EI
+KUAI	K U AI
+HAO	H AO
+MIAO	M I AO
+YI	Y I
+ZHAO	ZH AO
+TUO	T U O
+ZHEI	ZH EI
+FU	F U
+FEN	F EN
+JIA	J I A
+WA	W A
+CUO	C U O
+WU	W U
+MEN	M EN
+XUN	X VN
+MOU	M OU
+SHAN	SH AN
+PAI	P AI
+GONG	G ONG
+NONG	N ONG
+COU	C OU
+KONG	K ONG
+HUO	H U O
+HUANG	H U ANG
+JIU	J IU
+HONG	H ONG
+MIE	M IE
+HUI	H UI
+WEN	W EN
+ZHUO	ZH U O
+MIAN	M I AN
+BI	B I
+ZE	Z E
+YUN	Y VN
+GA	G A
+SUAN	S U AN
+SUN	S UN
+MAO	M AO
+XIA	X I A
+KA	K A
+NAO	N AO
+TIE	T IE
+GE	G E
+GUI	G UI
+LAO	L AO
+ZOU	Z OU
+SAO	S AO
+PO	P O
+JIN	J IN
+DUAN	D U AN
+DU	D U
+RUN	R UN
+YUE	Y VE
+DUN	D UN
+A	^ A
+PIE	P IE
+SHANG	SH ANG
+XIN	X IN
+CAN	C AN
+PAN	P AN
+LIE	L IE
+QIA	Q I A
+GU	G U
+ZHE	ZH E
+ZONG	Z ONG
+DIAN	D I AN
+LIA	L I A
+FENG	F ENG
+JUE	J VE
+LIAO	L I AO
+SA	S A
+TAI	T AI
+LEI	L EI
+SHUN	SH UN
+HAI	H AI
+NEN	N EN
+MIN	M IN
+PIAN	P I AN
+CHI	CH I
+CHANG	CH ANG
+NIAO	N I AO
+JI	J I
+TEI	T EI
+FANG	F ANG
+POU	P OU
+QUE	Q VE
+ZHOU	ZH OU
+NV	N V
+ER	^ ER
+YU	Y V
+XIE	X IE
+FAI	F AI
+EN	^ EN
+NVE	N VE
+KAN	K AN
+LUN	L UN
+ZHUANG	ZH U ANG
+HAN	H AN
+NG	N EN
+DI	D I
+SHEI	SH EI
+RUO	R U O
+KUN	K UN
+DUI	D UI
+TUAN	T U AN
+ZANG	Z ANG
+CUN	C UN
+YOU	Y OU
+SUI	S UI
+DEI	D EI
+RU	R U
+NU	N U
+ZHANG	ZH ANG
+BIAO	B I AO
+NUAN	N U AN
+SHUAN	SH U AN
+XIAO	X I AO
+TUN	T UN
+E	^ E
+SI	S I
+HOU	H OU
+FEI	F EI
+ZHA	ZH A
+CAI	C AI
+KIU	K IU
+DA	D A
+PEN	P EN
+LIAN	L I AN
+AI	^ AI
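Each line of this table is a whitespace-separated record: a pinyin syllable followed by its phonemes. `load_pinyin_dict` in `chinese.py` parses it with a one-line dict comprehension; the same parse applied to two sample rows:

```python
# Two rows in the pinyin_dict.txt format (tab-separated syllable + phonemes).
sample_rows = "XUE\tX VE\nZHUANG\tZH U ANG"

pinyin_dict = {l.split()[0]: l.split()[1:] for l in sample_rows.splitlines()}
# {'XUE': ['X', 'VE'], 'ZHUANG': ['ZH', 'U', 'ANG']}
```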

+ 21 - 0
PyTorch/SpeechSynthesis/FastPitch/common/utils.py

@@ -137,6 +137,27 @@ def get_padding(kernel_size, dilation=1):
     return int((kernel_size*dilation - dilation)/2)
 
 
+def load_pretrained_weights(model, ckpt_fpath):
+    model = getattr(model, "module", model)
+    weights = torch.load(ckpt_fpath, map_location="cpu")["state_dict"]
+    weights = {re.sub(r"^module\.", "", k): v for k, v in weights.items()}
+
+    ckpt_emb = weights["encoder.word_emb.weight"]
+    new_emb = model.state_dict()["encoder.word_emb.weight"]
+
+    ckpt_vocab_size = ckpt_emb.size(0)
+    new_vocab_size = new_emb.size(0)
+    if ckpt_vocab_size != new_vocab_size:
+        print("WARNING: Resuming from a checkpoint with a different size "
+              "of embedding table. For best results, extend the vocab "
+              "and ensure the common symbols' indices match.")
+        min_len = min(ckpt_vocab_size, new_vocab_size)
+        weights["encoder.word_emb.weight"] = ckpt_emb if ckpt_vocab_size > new_vocab_size else new_emb
+        weights["encoder.word_emb.weight"][:min_len] = ckpt_emb[:min_len]
+
+    model.load_state_dict(weights)
+
+
 class AttrDict(dict):
     def __init__(self, *args, **kwargs):
         super(AttrDict, self).__init__(*args, **kwargs)
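The vocab-resize logic in `load_pretrained_weights` copies the overlapping rows of the embedding table and leaves the remaining rows at their fresh initialization, which is what makes warm-starting a bilingual model from an English checkpoint work. The same idea, with NumPy arrays standing in for the torch tensors:

```python
import numpy as np

ckpt_emb = np.ones((100, 4))   # pretrained table: 100 symbols (English)
new_emb = np.zeros((148, 4))   # new table: 148 symbols (English + Mandarin)

min_len = min(ckpt_emb.shape[0], new_emb.shape[0])
merged = new_emb.copy()
merged[:min_len] = ckpt_emb[:min_len]  # shared symbol IDs keep pretrained rows

# First 100 rows come from the checkpoint, the rest stay freshly initialized.
assert merged[:min_len].sum() == min_len * 4 and merged[min_len:].sum() == 0
```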

+ 10 - 2
PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py

@@ -38,7 +38,7 @@ from scipy import ndimage
 from scipy.stats import betabinom
 
 import common.layers as layers
-from common.text.text_processing import TextProcessing
+from common.text.text_processing import get_text_processing
 from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
 
 
@@ -179,7 +179,7 @@ class TTSDataset(torch.utils.data.Dataset):
             'Only 0.0 and 1.0 p_arpabet is currently supported. '
             'Variable probability breaks caching of betabinomial matrices.')
 
-        self.tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+        self.tp = get_text_processing(symbol_set, text_cleaners, p_arpabet)
         self.n_speakers = n_speakers
         self.pitch_tmp_dir = pitch_online_dir
         self.f0_method = pitch_online_method
@@ -325,6 +325,14 @@ class TTSDataset(torch.utils.data.Dataset):
         return pitch_mel
 
 
+def ensure_disjoint(*tts_datasets):
+    paths = [set(list(zip(*d.audiopaths_and_text))[0]) for d in tts_datasets]
+    assert sum(len(p) for p in paths) == len(set().union(*paths)), (
+        "Your datasets (train, val) are not disjoint. "
+        "Review filelists and restart training."
+    )
+
+
 class TTSCollate:
     """Zero-pads model inputs and targets based on number of frames per step"""
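`ensure_disjoint` works because, for disjoint sets, the sum of the individual sizes equals the size of their union; any audio path shared between filelists makes the union strictly smaller. The check in isolation:

```python
train = {'wavs/a.wav', 'wavs/b.wav'}
val = {'wavs/c.wav'}
leaky_val = {'wavs/b.wav'}  # overlaps with train

def disjoint(*path_sets):
    return sum(len(p) for p in path_sets) == len(set().union(*path_sets))

assert disjoint(train, val)
assert not disjoint(train, leaky_val)
```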
 

+ 2 - 2
PyTorch/SpeechSynthesis/FastPitch/inference.py

@@ -35,7 +35,7 @@ from common import gpu_affinity
 from common.tb_dllogger import (init_inference_metadata, stdout_metric_format,
                                 unique_log_fpath)
 from common.text import cmudict
-from common.text.text_processing import TextProcessing
+from common.text.text_processing import get_text_processing
 from common.utils import l2_promote
 from fastpitch.pitch_transform import pitch_transform_custom
 from hifigan.data_function import MAX_WAV_VALUE, mel_spectrogram
@@ -161,7 +161,7 @@ def load_fields(fpath):
 def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
                            batch_size=128, dataset=None, load_mels=False,
                            load_pitch=False, p_arpabet=0.0):
-    tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+    tp = get_text_processing(symbol_set, text_cleaners, p_arpabet)
 
     fields['text'] = [torch.LongTensor(tp.encode_text(text))
                       for text in fields['text']]

+ 20 - 0
PyTorch/SpeechSynthesis/FastPitch/phrases/phrase_bilingual.txt

@@ -0,0 +1,20 @@
+nokia有跟facebook簽約。
+讓net backup同時強化重覆刪除和資料搜尋功能。
+classic仍有一定的價值。
+資料代管商ball的虛擬化工具。
+針對vmware虛擬化環境的基本功能。
+這跟微軟bing有何關連?
+由ben toyota所寫的the accidental billionaires。
+v d s技術提供一個如同伺服器般的獨立操作系統環境。
+專利設計通過美國f d a認證與臨床測試。
+你可直接把圖片丟進wave訊息中。
+由前英國陸軍軍官neil laughton領軍。
+這次android版也沿用了同樣的輸入法。
+facebook新註冊用戶。
+現在android跟iphone都支援這項功能。
+o r g的經理maxim weinstein。
+但本來就甚少舉辦活動的kingston金士頓。
+touchstone充電系統是還蠻酷的技術。
+雖然caspian市佔率不斷下滑。
+第一隻中文化的google android手機。
+因為google自家已經有android的同級競爭產品。

+ 7 - 1
PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py

@@ -77,6 +77,11 @@ def parse_args(parser):
     # Performance
     parser.add_argument('-b', '--batch-size', default=1, type=int)
     parser.add_argument('--n-workers', type=int, default=16)
+    
+    # Language
+    parser.add_argument('--symbol_set', default='english_basic',
+                        choices=['english_basic', 'english_mandarin_basic'],
+                        help='Symbols in the dataset')
     return parser
 
 
@@ -101,7 +106,7 @@ def main():
 
     if args.save_alignment_priors:
         Path(args.dataset_path, 'alignment_priors').mkdir(parents=False, exist_ok=True)
-
+        
     for filelist in args.wav_text_filelists:
 
         print(f'Processing {filelist}...')
@@ -111,6 +116,7 @@ def main():
             filelist,
             text_cleaners=['english_cleaners_v2'],
             n_mel_channels=args.n_mel_channels,
+            symbol_set=args.symbol_set,
             p_arpabet=0.0,
             n_speakers=args.n_speakers,
             load_mel_from_disk=False,

+ 1 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/inference_benchmark.sh

@@ -9,6 +9,7 @@ set -a
 : ${WARMUP:=64}
 : ${REPEATS:=500}
 : ${AMP:=false}
+: ${CUDNN_BENCHMARK:=true}
 
 for BATCH_SIZE in $BS_SEQUENCE ; do
     LOG_FILE="$OUTPUT_DIR"/perf-infer_amp-${AMP}_bs${BATCH_SIZE}.json

+ 2 - 3
PyTorch/SpeechSynthesis/FastPitch/scripts/inference_example.sh

@@ -12,6 +12,7 @@ export TORCH_CUDNN_V8_API_ENABLED=1
 : ${REPEATS:=1}
 : ${CPU:=false}
 : ${PHONE:=true}
+: ${CUDNN_BENCHMARK:=false}
 
 # Paths to pre-trained models downloadable from NVIDIA NGC (LJSpeech-1.1)
 FASTPITCH_LJ="pretrained_models/fastpitch/nvidia_fastpitch_210824.pt"
@@ -54,9 +55,7 @@ mkdir -p "$OUTPUT_DIR"
 
 echo -e "\nAMP=$AMP, batch_size=$BATCH_SIZE\n"
 
-ARGS=""
 ARGS+=" --cuda"
-# ARGS+=" --cudnn-benchmark"  # Enable for benchmarking or long operation
 ARGS+=" --dataset-path $DATASET_DIR"
 ARGS+=" -i $FILELIST"
 ARGS+=" -o $OUTPUT_DIR"
@@ -67,12 +66,12 @@ ARGS+=" --warmup-steps $WARMUP"
 ARGS+=" --repeats $REPEATS"
 ARGS+=" --speaker $SPEAKER"
 [ "$CPU" = false ]        && ARGS+=" --cuda"
-[ "$CPU" = false ]        && ARGS+=" --cudnn-benchmark"
 [ "$AMP" = true ]         && ARGS+=" --amp"
 [ "$TORCHSCRIPT" = true ] && ARGS+=" --torchscript"
 [ -n "$HIFIGAN" ]         && ARGS+=" --hifigan $HIFIGAN"
 [ -n "$WAVEGLOW" ]        && ARGS+=" --waveglow $WAVEGLOW"
 [ -n "$FASTPITCH" ]       && ARGS+=" --fastpitch $FASTPITCH"
 [ "$PHONE" = true ]       && ARGS+=" --p-arpabet 1.0"
+[[ "$CUDNN_BENCHMARK" = true && "$CPU" = false ]] && ARGS+=" --cudnn-benchmark"
 
 python inference.py $ARGS "$@"

+ 5 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/README.md

@@ -0,0 +1,5 @@
+Scripts in this directory are meant for training a Mandarin Chinese model
+on a publicly available [SF Bilingual Speech in Chinese and English](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/sf_bilingual_speech_zh_en)
+dataset.
+
+A step-by-step guide is provided in the general [README.md](../../README.md#example-training-a-model-on-mandarin-chinese).

+ 18 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/inference.sh

@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+set -a
+
+bash scripts/download_models.sh waveglow
+
+PYTHONIOENCODING=utf-8
+
+: ${BATCH_SIZE:=20}
+: ${FILELIST:="filelists/sf_test.tsv"}
+: ${FASTPITCH:="output_sf/FastPitch_checkpoint_1000.pt"}
+: ${OUTPUT_DIR:="output_sf/audio_sf_test_fastpitch1000ep_waveglow_denoise0.01"}
+
+# Disable HiFi-GAN and enable WaveGlow
+HIFIGAN=""
+WAVEGLOW="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"
+
+bash scripts/inference_example.sh "$@"

+ 57 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/prepare_dataset.sh

@@ -0,0 +1,57 @@
+#!/usr/bin/env bash
+set -e
+
+URL="https://catalog.ngc.nvidia.com/orgs/nvidia/resources/sf_bilingual_speech_zh_en"
+
+if [[ $1 == "" ]]; then
+    echo -e "\n**************************************************************************************"
+    echo -e "\nThe dataset needs to be downloaded manually from NGC by a signed in user:"
+    echo -e "\n\t$URL\n"
+    echo -e "Save as files.zip and run the script:"
+    echo -e "\n\tbash $0 path/to/files.zip\n"
+    echo -e "**************************************************************************************\n"
+    exit 0
+fi
+
+mkdir -p data
+
+echo "Extracting the data..."
+# The dataset downloaded from NGC might be double-zipped as:
+#     SF_bilingual -> SF_bilingual.zip -> files.zip
+if [ $(basename $1) == "files.zip" ]; then
+    unzip $1 -d data/
+    unzip data/SF_bilingual.zip -d data/
+elif [ $(basename $1) == "SF_bilingual.zip" ]; then
+    unzip $1 -d data/
+else
+    echo "Unknown input file. Supply either files.zip or SF_bilingual.zip as the first argument:"
+    echo -e "\t$0 [files.zip|SF_bilingual.zip]"
+    exit 1
+fi
+echo "Extracting the data... OK"
+
+# Make filelists
+echo "Generating filelists..."
+python scripts/mandarin_chinese/split_sf.py data/SF_bilingual/text_SF.txt filelists/
+echo "Generating filelists... OK"
+
+# Extract pitch (optionally extract mels)
+
+export PYTHONIOENCODING=utf-8
+
+: ${DATA_DIR:=data/SF_bilingual}
+: ${ARGS="--extract-mels"}
+
+echo "Extracting pitch..."
+python prepare_dataset.py \
+    --wav-text-filelists filelists/sf_audio_text.txt \
+    --n-workers 16 \
+    --batch-size 1 \
+    --dataset-path $DATA_DIR \
+    --extract-pitch \
+    --f0-method pyin \
+    --symbol_set english_mandarin_basic \
+    $ARGS
+
+echo "Extracting pitch... OK"
+echo "./data/SF_bilingual prepared successfully."

+ 1 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/requirements.txt

@@ -0,0 +1 @@
+pypinyin==0.47.1

+ 94 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/split_sf.py

@@ -0,0 +1,94 @@
+# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#           http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from pathlib import Path
+
+
+# Define val and test; the remaining ones will be train IDs
+val_ids = {
+    'com_SF_ce227', 'com_SF_ce832', 'com_SF_ce912', 'com_SF_ce979',
+    'com_SF_ce998', 'com_SF_ce1045', 'com_SF_ce1282', 'com_SF_ce1329',
+    'com_SF_ce1350', 'com_SF_ce1376', 'com_SF_ce1519', 'com_SF_ce1664',
+    'com_SF_ce1777', 'com_SF_ce1843', 'com_SF_ce2017', 'com_SF_ce2042',
+    'com_SF_ce2100', 'com_SF_ce2251', 'com_SF_ce2443', 'com_SF_ce2566',
+}
+
+test_ids = {
+    'com_SF_ce161', 'com_SF_ce577', 'com_SF_ce781', 'com_SF_ce814',
+    'com_SF_ce1042', 'com_SF_ce1089', 'com_SF_ce1123', 'com_SF_ce1425',
+    'com_SF_ce1514', 'com_SF_ce1577', 'com_SF_ce1780', 'com_SF_ce1857',
+    'com_SF_ce1940', 'com_SF_ce2051', 'com_SF_ce2181', 'com_SF_ce2258',
+    'com_SF_ce2406', 'com_SF_ce2512', 'com_SF_ce2564', 'com_SF_ce2657'
+}
+
+
+def generate(fpath, ids_text, pitch=True, text=True):
+
+    with open(fpath, 'w') as f:
+        for id_, txt in ids_text.items():
+            row = f"wavs/{id_}.wav"
+            row += "|" + f"pitch/{id_}.pt" if pitch else ""
+            row += "|" + txt if text else ""
+            f.write(row + "\n")
+
+
+def generate_inference_tsv(fpath, ids_text):
+
+    with open(fpath, 'w') as f:
+        f.write("output\ttext\n")
+        for id_, txt in ids_text.items():
+            f.write(f"{id_}.wav\t{txt}\n")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='SF bilingual dataset filelists generator')
+    parser.add_argument('transcripts', type=Path, default='./text_SF.txt',
+                        help='Path to SF bilingual transcripts (text_SF.txt)')
+    parser.add_argument('output_dir', default='data/filelists', type=Path,
+                        help='Directory in which to generate filelists')
+    args = parser.parse_args()
+
+    with open(args.transcripts) as f:
+        # A dict of ID:transcript pairs
+        transcripts = dict(line.replace("\ufeff", "").replace("－", "-").strip().split(' ', 1)
+                           for line in f)
+    transcripts = {id_.replace("com_DL", "com_SF"): text.lower()
+                   for id_, text in transcripts.items()}
+
+    val_ids_text = {id_: transcripts[id_] for id_ in val_ids}
+    test_ids_text = {id_: transcripts[id_] for id_ in test_ids}
+    train_ids_text = {id_: transcripts[id_] for id_ in transcripts
+                      if id_ not in test_ids and id_ not in val_ids}
+
+    prefix = Path(args.output_dir, "sf_audio_pitch_text_")
+    generate(str(prefix) + "val.txt", val_ids_text)
+    generate(str(prefix) + "test.txt", test_ids_text)
+    generate(str(prefix) + "train.txt", train_ids_text)
+
+    prefix = Path(args.output_dir, "sf_audio_")
+    generate(str(prefix) + "val.txt", val_ids_text, False, False)
+    generate(str(prefix) + "test.txt", test_ids_text, False, False)
+    generate(str(prefix) + "train.txt", train_ids_text, False, False)
+
+    # train + val + test for pre-processing
+    generate(Path(args.output_dir, "sf_audio_text.txt"),
+             {**val_ids_text, **test_ids_text, **train_ids_text}, False, True)
+
+    generate_inference_tsv(Path(args.output_dir, "sf_test.tsv"), test_ids_text)
+
+
+if __name__ == '__main__':
+    main()
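The generated filelists are pipe-separated, matching what `TTSDataset` expects: an audio path, optionally a pitch path, optionally the transcript. A sketch of the row format for a hypothetical utterance ID (the ID below is made up for illustration):

```python
def make_row(id_, txt, pitch=True, text=True):
    # Mirrors generate() above: wavs/<id>.wav[|pitch/<id>.pt][|<transcript>]
    row = f"wavs/{id_}.wav"
    row += f"|pitch/{id_}.pt" if pitch else ""
    row += f"|{txt}" if text else ""
    return row

assert make_row('com_SF_ce0001', '你好 world') == \
    'wavs/com_SF_ce0001.wav|pitch/com_SF_ce0001.pt|你好 world'
assert make_row('com_SF_ce0001', '你好 world', pitch=False, text=False) == \
    'wavs/com_SF_ce0001.wav'
```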

+ 21 - 0
PyTorch/SpeechSynthesis/FastPitch/scripts/mandarin_chinese/train.sh

@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+set -a
+
+PYTHONIOENCODING=utf-8
+
+# Mandarin & English bilingual
+ARGS+=" --symbol-set english_mandarin_basic"
+
+# Initialize weights with a pre-trained English model
+bash scripts/download_models.sh fastpitch
+ARGS+=" --init-from-checkpoint pretrained_models/fastpitch/nvidia_fastpitch_210824.pt"
+
+AMP=false  # FP32 training for better stability
+
+: ${DATASET_PATH:=data/SF_bilingual}
+: ${TRAIN_FILELIST:=filelists/sf_audio_pitch_text_train.txt}
+: ${VAL_FILELIST:=filelists/sf_audio_pitch_text_val.txt}
+: ${OUTPUT_DIR:=./output_sf}
+
+bash scripts/train.sh $ARGS "$@"

+ 13 - 12
PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh

@@ -42,7 +42,7 @@ GBS=$(($NUM_GPUS * $BATCH_SIZE * $GRAD_ACCUMULATION))
 echo -e "\nAMP=$AMP, ${NUM_GPUS}x${BATCH_SIZE}x${GRAD_ACCUMULATION}" \
         "(global batch size ${GBS})\n"
 
-ARGS=""
+# ARGS=""  # not reset: wrapper scripts (e.g., scripts/mandarin_chinese/train.sh) may pre-populate ARGS
 ARGS+=" --cuda"
 ARGS+=" -o $OUTPUT_DIR"
 ARGS+=" --log-file $LOG_FILE"
@@ -54,7 +54,7 @@ ARGS+=" --grad-accumulation $GRAD_ACCUMULATION"
 ARGS+=" --optimizer lamb"
 ARGS+=" --epochs $EPOCHS"
 ARGS+=" --epochs-per-checkpoint $EPOCHS_PER_CHECKPOINT"
-ARGS+=" --resume"
+
 ARGS+=" --warmup-steps $WARMUP_STEPS"
 ARGS+=" -lr $LEARNING_RATE"
 ARGS+=" --weight-decay 1e-6"
@@ -70,16 +70,17 @@ ARGS+=" --kl-loss-warmup-epochs $KL_LOSS_WARMUP"
 ARGS+=" --text-cleaners $TEXT_CLEANERS"
 ARGS+=" --n-speakers $NSPEAKERS"
 
-[ "$AMP" = "true" ]                && ARGS+=" --amp"
-[ "$PHONE" = "true" ]              && ARGS+=" --p-arpabet 1.0"
-[ "$ENERGY" = "true" ]             && ARGS+=" --energy-conditioning"
-[ "$SEED" != "" ]                  && ARGS+=" --seed $SEED"
-[ "$LOAD_MEL_FROM_DISK" = true ]   && ARGS+=" --load-mel-from-disk"
-[ "$LOAD_PITCH_FROM_DISK" = true ] && ARGS+=" --load-pitch-from-disk"
-[ "$PITCH_ONLINE_DIR" != "" ]      && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
-[ "$PITCH_ONLINE_METHOD" != "" ]   && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
-[ "$APPEND_SPACES" = true ]        && ARGS+=" --prepend-space-to-text"
-[ "$APPEND_SPACES" = true ]        && ARGS+=" --append-space-to-text"
+[ "$AMP" = "true" ]                    && ARGS+=" --amp"
+[ "$PHONE" = "true" ]                  && ARGS+=" --p-arpabet 1.0"
+[ "$ENERGY" = "true" ]                 && ARGS+=" --energy-conditioning"
+[ "$SEED" != "" ]                      && ARGS+=" --seed $SEED"
+[ "$LOAD_MEL_FROM_DISK" = true ]       && ARGS+=" --load-mel-from-disk"
+[ "$LOAD_PITCH_FROM_DISK" = true ]     && ARGS+=" --load-pitch-from-disk"
+[ "$PITCH_ONLINE_DIR" != "" ]          && ARGS+=" --pitch-online-dir $PITCH_ONLINE_DIR"  # e.g., /dev/shm/pitch
+[ "$PITCH_ONLINE_METHOD" != "" ]       && ARGS+=" --pitch-online-method $PITCH_ONLINE_METHOD"
+[ "$APPEND_SPACES" = true ]            && ARGS+=" --prepend-space-to-text"
+[ "$APPEND_SPACES" = true ]            && ARGS+=" --append-space-to-text"
+[[ "$ARGS" != *"--checkpoint-path"* ]] && ARGS+=" --resume"
 
 if [ "$SAMPLING_RATE" == "44100" ]; then
   ARGS+=" --sampling-rate 44100"

+ 9 - 2
PyTorch/SpeechSynthesis/FastPitch/train.py

@@ -47,9 +47,10 @@ from common.tb_dllogger import log
 from common.repeated_dataloader import (RepeatedDataLoader,
                                         RepeatedDistributedSampler)
 from common.text import cmudict
-from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
+from common.utils import (BenchmarkStats, Checkpointer,
+                          load_pretrained_weights, prepare_tmp)
 from fastpitch.attn_loss_function import AttentionBinarizationLoss
-from fastpitch.data_function import batch_to_gpu, TTSCollate, TTSDataset
+from fastpitch.data_function import batch_to_gpu, ensure_disjoint, TTSCollate, TTSDataset
 from fastpitch.loss_function import FastPitchLoss
 
 
@@ -95,6 +96,8 @@ def parse_args(parser):
                         help='Number of epochs for calculating final stats')
     train.add_argument('--validation-freq', type=int, default=1,
                        help='Validate every N epochs to use less compute')
+    train.add_argument('--init-from-checkpoint', type=str, default=None,
+                       help='Initialize model weights with a pre-trained ckpt')
 
     opt = parser.add_argument_group('optimization setup')
     opt.add_argument('--optimizer', type=str, default='lamb',
@@ -326,6 +329,9 @@ def main():
     model_config = models.get_model_config('FastPitch', args)
     model = models.get_model('FastPitch', model_config, device)
 
+    if args.init_from_checkpoint is not None:
+        load_pretrained_weights(model, args.init_from_checkpoint)
+
     attention_kl_loss = AttentionBinarizationLoss()
 
     # Store pitch mean/std as params to translate from Hz during inference
@@ -374,6 +380,7 @@ def main():
 
     trainset = TTSDataset(audiopaths_and_text=args.training_files, **vars(args))
     valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
+    ensure_disjoint(trainset, valset)
 
     if distributed_run:
         train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,