# data_function.py
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import random
import numpy as np
import torch
import torch.utils.data

import common.layers as layers
from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
from tacotron2.text import text_to_sequence
  34. class TextMelLoader(torch.utils.data.Dataset):
  35. """
  36. 1) loads audio,text pairs
  37. 2) normalizes text and converts them to sequences of one-hot vectors
  38. 3) computes mel-spectrograms from audio files.
  39. """
  40. def __init__(self, dataset_path, audiopaths_and_text, args):
  41. self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
  42. self.text_cleaners = args.text_cleaners
  43. self.max_wav_value = args.max_wav_value
  44. self.sampling_rate = args.sampling_rate
  45. self.load_mel_from_disk = args.load_mel_from_disk
  46. self.stft = layers.TacotronSTFT(
  47. args.filter_length, args.hop_length, args.win_length,
  48. args.n_mel_channels, args.sampling_rate, args.mel_fmin,
  49. args.mel_fmax)
  50. random.seed(1234)
  51. random.shuffle(self.audiopaths_and_text)
  52. def get_mel_text_pair(self, audiopath_and_text):
  53. # separate filename and text
  54. audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
  55. len_text = len(text)
  56. text = self.get_text(text)
  57. mel = self.get_mel(audiopath)
  58. return (text, mel, len_text)
  59. def get_mel(self, filename):
  60. if not self.load_mel_from_disk:
  61. audio, sampling_rate = load_wav_to_torch(filename)
  62. if sampling_rate != self.stft.sampling_rate:
  63. raise ValueError("{} {} SR doesn't match target {} SR".format(
  64. sampling_rate, self.stft.sampling_rate))
  65. audio_norm = audio / self.max_wav_value
  66. audio_norm = audio_norm.unsqueeze(0)
  67. audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  68. melspec = self.stft.mel_spectrogram(audio_norm)
  69. melspec = torch.squeeze(melspec, 0)
  70. else:
  71. melspec = torch.load(filename)
  72. assert melspec.size(0) == self.stft.n_mel_channels, (
  73. 'Mel dimension mismatch: given {}, expected {}'.format(
  74. melspec.size(0), self.stft.n_mel_channels))
  75. return melspec
  76. def get_text(self, text):
  77. text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
  78. return text_norm
  79. def __getitem__(self, index):
  80. return self.get_mel_text_pair(self.audiopaths_and_text[index])
  81. def __len__(self):
  82. return len(self.audiopaths_and_text)
  83. class TextMelCollate():
  84. """ Zero-pads model inputs and targets based on number of frames per setep
  85. """
  86. def __init__(self, n_frames_per_step):
  87. self.n_frames_per_step = n_frames_per_step
  88. def __call__(self, batch):
  89. """Collate's training batch from normalized text and mel-spectrogram
  90. PARAMS
  91. ------
  92. batch: [text_normalized, mel_normalized]
  93. """
  94. # Right zero-pad all one-hot text sequences to max input length
  95. input_lengths, ids_sorted_decreasing = torch.sort(
  96. torch.LongTensor([len(x[0]) for x in batch]),
  97. dim=0, descending=True)
  98. max_input_len = input_lengths[0]
  99. text_padded = torch.LongTensor(len(batch), max_input_len)
  100. text_padded.zero_()
  101. for i in range(len(ids_sorted_decreasing)):
  102. text = batch[ids_sorted_decreasing[i]][0]
  103. text_padded[i, :text.size(0)] = text
  104. # Right zero-pad mel-spec
  105. num_mels = batch[0][1].size(0)
  106. max_target_len = max([x[1].size(1) for x in batch])
  107. if max_target_len % self.n_frames_per_step != 0:
  108. max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
  109. assert max_target_len % self.n_frames_per_step == 0
  110. # include mel padded and gate padded
  111. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  112. mel_padded.zero_()
  113. gate_padded = torch.FloatTensor(len(batch), max_target_len)
  114. gate_padded.zero_()
  115. output_lengths = torch.LongTensor(len(batch))
  116. for i in range(len(ids_sorted_decreasing)):
  117. mel = batch[ids_sorted_decreasing[i]][1]
  118. mel_padded[i, :, :mel.size(1)] = mel
  119. gate_padded[i, mel.size(1)-1:] = 1
  120. output_lengths[i] = mel.size(1)
  121. # count number of items - characters in text
  122. len_x = [x[2] for x in batch]
  123. len_x = torch.Tensor(len_x)
  124. return text_padded, input_lengths, mel_padded, gate_padded, \
  125. output_lengths, len_x
  126. def batch_to_gpu(batch):
  127. text_padded, input_lengths, mel_padded, gate_padded, \
  128. output_lengths, len_x = batch
  129. text_padded = to_gpu(text_padded).long()
  130. input_lengths = to_gpu(input_lengths).long()
  131. max_len = torch.max(input_lengths.data).item()
  132. mel_padded = to_gpu(mel_padded).float()
  133. gate_padded = to_gpu(gate_padded).float()
  134. output_lengths = to_gpu(output_lengths).long()
  135. x = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
  136. y = (mel_padded, gate_padded)
  137. len_x = torch.sum(output_lengths)
  138. return (x, y, len_x)