| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- # *****************************************************************************
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- # * Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # * Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # * Neither the name of the NVIDIA CORPORATION nor the
- # names of its contributors may be used to endorse or promote products
- # derived from this software without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- #
- # *****************************************************************************
- import random
- import numpy as np
- import torch
- import torch.utils.data
- import common.layers as layers
- from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
- from common.text import text_to_sequence
class TextMelLoader(torch.utils.data.Dataset):
    """Dataset yielding (text ids, mel-spectrogram, raw text length) items.

    1) loads audio/text pairs listed by ``load_filepaths_and_text``,
    2) normalizes text and converts it to a sequence of symbol ids,
    3) computes mel-spectrograms from audio files, or loads precomputed
       mel-spectrograms with ``torch.load`` when ``load_mel_from_disk`` is True.
    """

    def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
        self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
        self.text_cleaners = args.text_cleaners
        self.load_mel_from_disk = load_mel_from_disk
        if not load_mel_from_disk:
            # Audio-processing settings are only needed when mels are computed
            # on the fly; they are not set at all in the load-from-disk mode.
            self.max_wav_value = args.max_wav_value
            self.sampling_rate = args.sampling_rate
            self.stft = layers.TacotronSTFT(
                args.filter_length, args.hop_length, args.win_length,
                args.n_mel_channels, args.sampling_rate, args.mel_fmin,
                args.mel_fmax)

    def get_mel(self, filename):
        """Return the mel-spectrogram for ``filename``.

        In load-from-disk mode the file is read with ``torch.load``; otherwise
        the file is loaded as audio and converted with ``self.stft``.

        Raises:
            ValueError: if the audio sampling rate differs from the STFT's.
        """
        if self.load_mel_from_disk:
            # NOTE(review): torch.load unpickles its input; only point this at
            # trusted, locally produced files.
            return torch.load(filename)
        audio, sampling_rate = load_wav_to_torch(filename)
        return self._audio_to_mel(audio, sampling_rate, filename)

    def _audio_to_mel(self, audio, sampling_rate, filename):
        """Validate the sampling rate, normalize ``audio`` and compute its mel-spectrogram."""
        if sampling_rate != self.stft.sampling_rate:
            raise ValueError("{}: {} SR doesn't match target {} SR".format(
                filename, sampling_rate, self.stft.sampling_rate))
        # Scale raw samples by max_wav_value and add a leading batch axis of 1
        # for the STFT, then drop that axis again.
        audio_norm = (audio / self.max_wav_value).unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        return torch.squeeze(melspec, 0)

    def get_text(self, text):
        """Convert ``text`` to an IntTensor of symbol ids using the configured cleaners."""
        return torch.IntTensor(text_to_sequence(text, self.text_cleaners))

    def __getitem__(self, index):
        """Return ``(text_ids, mel, len_text)`` for item ``index``.

        ``len_text`` is the character count of the raw text string, measured
        before cleaning/conversion to ids.
        """
        audiopath, text = self.audiopaths_and_text[index]
        len_text = len(text)
        return (self.get_text(text), self.get_mel(audiopath), len_text)

    def __len__(self):
        return len(self.audiopaths_and_text)
class TextMelCollate():
    """Zero-pads model inputs and targets based on number of frames per step.

    Items are sorted by decreasing text length, and every per-item output
    (padded text, lengths, mels, gates, ``len_x`` and any extra fields) is
    returned in that same sorted order.
    """

    def __init__(self, n_frames_per_step):
        # The padded mel length is rounded up to a multiple of this value.
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate a training batch from normalized text and mel-spectrograms.

        PARAMS
        ------
        batch: sequence of tuples ``(text_ids, mel, len_text, *extra)`` such as
            those returned by ``TextMelLoader.__getitem__``. ``text_ids`` is
            1-D; ``mel`` is 2-D with frames on axis 1.

        RETURNS
        -------
        ``(text_padded, input_lengths, mel_padded, gate_padded,
        output_lengths, len_x)`` followed by one list per extra field, all in
        order of decreasing text length.
        """
        # Sort by text length (descending) and reuse that order for every field.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]),
            dim=0, descending=True)
        order = ids_sorted_decreasing.tolist()
        n_items = len(batch)

        # Right zero-pad all text sequences to max input length.
        max_input_len = int(input_lengths[0])
        text_padded = torch.zeros(n_items, max_input_len, dtype=torch.long)
        for row, idx in enumerate(order):
            text = batch[idx][0]
            text_padded[row, :text.size(0)] = text

        # Right zero-pad mel-specs; round the padded length up to a multiple
        # of n_frames_per_step.
        num_mels = batch[0][1].size(0)
        max_target_len = max(x[1].size(1) for x in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
        assert max_target_len % self.n_frames_per_step == 0

        mel_padded = torch.zeros(n_items, num_mels, max_target_len, dtype=torch.float)
        # Gate target: 0 before the last real frame, 1 from the last real frame
        # through the end of the padding.
        gate_padded = torch.zeros(n_items, max_target_len, dtype=torch.float)
        output_lengths = torch.zeros(n_items, dtype=torch.long)
        for row, idx in enumerate(order):
            mel = batch[idx][1]
            n_frames = mel.size(1)
            mel_padded[row, :, :n_frames] = mel
            gate_padded[row, n_frames - 1:] = 1
            output_lengths[row] = n_frames

        # Per-item character count (third tuple element), sorted like the
        # other per-item fields so index i refers to the same utterance.
        len_x = torch.Tensor([batch[idx][2] for idx in order])

        # Return any extra fields as lists in the same sorted order.
        num_fields = len(batch[0])
        extra_fields = tuple([batch[idx][f] for idx in order]
                             for f in range(3, num_fields))
        return (text_padded, input_lengths, mel_padded, gate_padded,
                output_lengths, len_x) + extra_fields
def batch_to_gpu(batch):
    """Move a collated batch to the device and split it into inputs/targets.

    PARAMS
    ------
    batch: tuple as returned by ``TextMelCollate.__call__``. Only the first
        five fields are used; the collated ``len_x`` (index 5) and any
        trailing extra fields are ignored.

    RETURNS
    -------
    x: ``(text_padded, input_lengths, mel_padded, max_len, output_lengths)``
        where ``max_len`` is the largest input length as a Python number.
    y: ``(mel_padded, gate_padded)``
    len_x: ``torch.sum(output_lengths)``, i.e. the total number of unpadded
        mel frames when ``batch`` comes from ``TextMelCollate``.
    """
    # Slice rather than unpack all fields so batches carrying extra fields work.
    text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch[:5]
    # NOTE(review): to_gpu comes from common.utils (not shown); presumably it
    # moves the tensor to CUDA when available.
    text_padded = to_gpu(text_padded).long()
    input_lengths = to_gpu(input_lengths).long()
    max_len = torch.max(input_lengths).item()
    mel_padded = to_gpu(mel_padded).float()
    gate_padded = to_gpu(gate_padded).float()
    output_lengths = to_gpu(output_lengths).long()
    x = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
    y = (mel_padded, gate_padded)
    len_x = torch.sum(output_lengths)
    return (x, y, len_x)
|