data_function.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import random
  28. import numpy as np
  29. import torch
  30. import torch.utils.data
  31. import common.layers as layers
  32. from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
  33. from common.text import text_to_sequence
  34. class TextMelLoader(torch.utils.data.Dataset):
  35. """
  36. 1) loads audio,text pairs
  37. 2) normalizes text and converts them to sequences of one-hot vectors
  38. 3) computes mel-spectrograms from audio files.
  39. """
  40. def __init__(self, dataset_path, audiopaths_and_text, args, load_mel_from_disk=True):
  41. self.audiopaths_and_text = load_filepaths_and_text(dataset_path, audiopaths_and_text)
  42. self.text_cleaners = args.text_cleaners
  43. self.load_mel_from_disk = load_mel_from_disk
  44. if not load_mel_from_disk:
  45. self.max_wav_value = args.max_wav_value
  46. self.sampling_rate = args.sampling_rate
  47. self.stft = layers.TacotronSTFT(
  48. args.filter_length, args.hop_length, args.win_length,
  49. args.n_mel_channels, args.sampling_rate, args.mel_fmin,
  50. args.mel_fmax)
  51. def get_mel(self, filename):
  52. if not self.load_mel_from_disk:
  53. audio, sampling_rate = load_wav_to_torch(filename)
  54. if sampling_rate != self.stft.sampling_rate:
  55. raise ValueError("{} {} SR doesn't match target {} SR".format(
  56. sampling_rate, self.stft.sampling_rate))
  57. audio_norm = audio / self.max_wav_value
  58. audio_norm = audio_norm.unsqueeze(0)
  59. audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  60. melspec = self.stft.mel_spectrogram(audio_norm)
  61. melspec = torch.squeeze(melspec, 0)
  62. else:
  63. melspec = torch.load(filename)
  64. # assert melspec.size(0) == self.stft.n_mel_channels, (
  65. # 'Mel dimension mismatch: given {}, expected {}'.format(
  66. # melspec.size(0), self.stft.n_mel_channels))
  67. return melspec
  68. def get_text(self, text):
  69. text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
  70. return text_norm
  71. def __getitem__(self, index):
  72. # separate filename and text
  73. audiopath, text = self.audiopaths_and_text[index]
  74. len_text = len(text)
  75. text = self.get_text(text)
  76. mel = self.get_mel(audiopath)
  77. return (text, mel, len_text)
  78. def __len__(self):
  79. return len(self.audiopaths_and_text)
  80. class TextMelCollate():
  81. """ Zero-pads model inputs and targets based on number of frames per step
  82. """
  83. def __init__(self, n_frames_per_step):
  84. self.n_frames_per_step = n_frames_per_step
  85. def __call__(self, batch):
  86. """Collate's training batch from normalized text and mel-spectrogram
  87. PARAMS
  88. ------
  89. batch: [text_normalized, mel_normalized]
  90. """
  91. # Right zero-pad all one-hot text sequences to max input length
  92. input_lengths, ids_sorted_decreasing = torch.sort(
  93. torch.LongTensor([len(x[0]) for x in batch]),
  94. dim=0, descending=True)
  95. max_input_len = input_lengths[0]
  96. text_padded = torch.LongTensor(len(batch), max_input_len)
  97. text_padded.zero_()
  98. for i in range(len(ids_sorted_decreasing)):
  99. text = batch[ids_sorted_decreasing[i]][0]
  100. text_padded[i, :text.size(0)] = text
  101. # Right zero-pad mel-spec
  102. num_mels = batch[0][1].size(0)
  103. max_target_len = max([x[1].size(1) for x in batch])
  104. if max_target_len % self.n_frames_per_step != 0:
  105. max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
  106. assert max_target_len % self.n_frames_per_step == 0
  107. # include mel padded and gate padded
  108. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  109. mel_padded.zero_()
  110. gate_padded = torch.FloatTensor(len(batch), max_target_len)
  111. gate_padded.zero_()
  112. output_lengths = torch.LongTensor(len(batch))
  113. for i in range(len(ids_sorted_decreasing)):
  114. mel = batch[ids_sorted_decreasing[i]][1]
  115. mel_padded[i, :, :mel.size(1)] = mel
  116. gate_padded[i, mel.size(1)-1:] = 1
  117. output_lengths[i] = mel.size(1)
  118. # count number of items - characters in text
  119. len_x = [x[2] for x in batch]
  120. len_x = torch.Tensor(len_x)
  121. # Return any extra fields as sorted lists
  122. num_fields = len(batch[0])
  123. extra_fields = tuple([batch[i][f] for i in ids_sorted_decreasing]
  124. for f in range(3, num_fields))
  125. return (text_padded, input_lengths, mel_padded, gate_padded, \
  126. output_lengths, len_x) + extra_fields
  127. def batch_to_gpu(batch):
  128. text_padded, input_lengths, mel_padded, gate_padded, \
  129. output_lengths, len_x = batch
  130. text_padded = to_gpu(text_padded).long()
  131. input_lengths = to_gpu(input_lengths).long()
  132. max_len = torch.max(input_lengths.data).item()
  133. mel_padded = to_gpu(mel_padded).float()
  134. gate_padded = to_gpu(gate_padded).float()
  135. output_lengths = to_gpu(output_lengths).long()
  136. x = (text_padded, input_lengths, mel_padded, max_len, output_lengths)
  137. y = (mel_padded, gate_padded)
  138. len_x = torch.sum(output_lengths)
  139. return (x, y, len_x)