data_function.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. # *****************************************************************************
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import functools
  28. import json
  29. import re
  30. from pathlib import Path
  31. import librosa
  32. import numpy as np
  33. import torch
  34. import torch.nn.functional as F
  35. from scipy import ndimage
  36. from scipy.stats import betabinom
  37. import common.layers as layers
  38. from common.text.text_processing import get_text_processing
  39. from common.utils import load_wav_to_torch, load_filepaths_and_text, to_gpu
  40. class BetaBinomialInterpolator:
  41. """Interpolates alignment prior matrices to save computation.
  42. Calculating beta-binomial priors is costly. Instead cache popular sizes
  43. and use img interpolation to get priors faster.
  44. """
  45. def __init__(self, round_mel_len_to=100, round_text_len_to=20):
  46. self.round_mel_len_to = round_mel_len_to
  47. self.round_text_len_to = round_text_len_to
  48. self.bank = functools.lru_cache(beta_binomial_prior_distribution)
  49. def round(self, val, to):
  50. return max(1, int(np.round((val + 1) / to))) * to
  51. def __call__(self, w, h):
  52. bw = self.round(w, to=self.round_mel_len_to)
  53. bh = self.round(h, to=self.round_text_len_to)
  54. ret = ndimage.zoom(self.bank(bw, bh).T, zoom=(w / bw, h / bh), order=1)
  55. assert ret.shape[0] == w, ret.shape
  56. assert ret.shape[1] == h, ret.shape
  57. return ret
  58. def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
  59. P = phoneme_count
  60. M = mel_count
  61. x = np.arange(0, P)
  62. mel_text_probs = []
  63. for i in range(1, M+1):
  64. a, b = scaling * i, scaling * (M + 1 - i)
  65. rv = betabinom(P, a, b)
  66. mel_i_prob = rv.pmf(x)
  67. mel_text_probs.append(mel_i_prob)
  68. return torch.tensor(np.array(mel_text_probs))
  69. def estimate_pitch(wav, mel_len, method='pyin', normalize_mean=None,
  70. normalize_std=None, n_formants=1):
  71. if type(normalize_mean) is float or type(normalize_mean) is list:
  72. normalize_mean = torch.tensor(normalize_mean)
  73. if type(normalize_std) is float or type(normalize_std) is list:
  74. normalize_std = torch.tensor(normalize_std)
  75. if method == 'pyin':
  76. snd, sr = librosa.load(wav)
  77. pitch_mel, voiced_flag, voiced_probs = librosa.pyin(
  78. snd, fmin=librosa.note_to_hz('C2'),
  79. fmax=librosa.note_to_hz('C7'), frame_length=1024)
  80. assert np.abs(mel_len - pitch_mel.shape[0]) <= 1.0
  81. pitch_mel = np.where(np.isnan(pitch_mel), 0.0, pitch_mel)
  82. pitch_mel = torch.from_numpy(pitch_mel).unsqueeze(0)
  83. pitch_mel = F.pad(pitch_mel, (0, mel_len - pitch_mel.size(1)))
  84. if n_formants > 1:
  85. raise NotImplementedError
  86. else:
  87. raise ValueError
  88. pitch_mel = pitch_mel.float()
  89. if normalize_mean is not None:
  90. assert normalize_std is not None
  91. pitch_mel = normalize_pitch(pitch_mel, normalize_mean, normalize_std)
  92. return pitch_mel
  93. def normalize_pitch(pitch, mean, std):
  94. zeros = (pitch == 0.0)
  95. pitch -= mean[:, None]
  96. pitch /= std[:, None]
  97. pitch[zeros] = 0.0
  98. return pitch
class TTSDataset(torch.utils.data.Dataset):
    """
    1) loads audio,text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.

    Each item is a tuple (text, mel, text_len, pitch, energy, speaker,
    attn_prior, audiopath); see __getitem__. Mels, pitch contours and
    betabinomial alignment priors can each be loaded from disk, read
    from an online cache directory, or computed on the fly.
    """
    def __init__(self,
                 dataset_path,
                 audiopaths_and_text,
                 text_cleaners,
                 n_mel_channels,
                 symbol_set='english_basic',
                 p_arpabet=1.0,
                 n_speakers=1,
                 load_mel_from_disk=True,
                 load_pitch_from_disk=True,
                 pitch_mean=214.72203,  # LJSpeech defaults
                 pitch_std=65.72038,
                 max_wav_value=None,
                 sampling_rate=None,
                 filter_length=None,
                 hop_length=None,
                 win_length=None,
                 mel_fmin=None,
                 mel_fmax=None,
                 prepend_space_to_text=False,
                 append_space_to_text=False,
                 pitch_online_dir=None,
                 betabinomial_online_dir=None,
                 use_betabinomial_interpolator=True,
                 pitch_online_method='pyin',
                 **ignored):
        # Expect a list of filenames
        if type(audiopaths_and_text) is str:
            audiopaths_and_text = [audiopaths_and_text]

        self.dataset_path = dataset_path
        self.audiopaths_and_text = load_filepaths_and_text(
            dataset_path, audiopaths_and_text,
            has_speakers=(n_speakers > 1))
        self.load_mel_from_disk = load_mel_from_disk
        if not load_mel_from_disk:
            # STFT machinery is only needed when mels are computed online.
            self.max_wav_value = max_wav_value
            self.sampling_rate = sampling_rate
            self.stft = layers.TacotronSTFT(
                filter_length, hop_length, win_length,
                n_mel_channels, sampling_rate, mel_fmin, mel_fmax)
        self.load_pitch_from_disk = load_pitch_from_disk

        self.prepend_space_to_text = prepend_space_to_text
        self.append_space_to_text = append_space_to_text

        assert p_arpabet == 0.0 or p_arpabet == 1.0, (
            'Only 0.0 and 1.0 p_arpabet is currently supported. '
            'Variable probability breaks caching of betabinomial matrices.')

        self.tp = get_text_processing(symbol_set, text_cleaners, p_arpabet)
        self.n_speakers = n_speakers
        self.pitch_tmp_dir = pitch_online_dir
        self.f0_method = pitch_online_method
        self.betabinomial_tmp_dir = betabinomial_online_dir
        self.use_betabinomial_interpolator = use_betabinomial_interpolator

        if use_betabinomial_interpolator:
            self.betabinomial_interpolator = BetaBinomialInterpolator()

        # Filelist columns: <mel_or_wav> [| <pitch>] | <text> [| <speaker_id>]
        expected_columns = (2 + int(load_pitch_from_disk) + (n_speakers > 1))

        # Loading pitch from disk and caching it online are mutually exclusive.
        assert not (load_pitch_from_disk and self.pitch_tmp_dir is not None)

        if len(self.audiopaths_and_text[0]) < expected_columns:
            raise ValueError(f'Expected {expected_columns} columns in audiopaths file. '
                             'The format is <mel_or_wav>|[<pitch>|]<text>[|<speaker_id>]')

        if len(self.audiopaths_and_text[0]) > expected_columns:
            print('WARNING: Audiopaths file has more columns than expected')

        # Promote plain-float statistics to 1-element tensors so that
        # broadcasting in normalize_pitch works uniformly.
        to_tensor = lambda x: torch.Tensor([x]) if type(x) is float else x
        self.pitch_mean = to_tensor(pitch_mean)
        self.pitch_std = to_tensor(pitch_std)

    def __getitem__(self, index):
        """Return (text, mel, text_len, pitch, energy, speaker, attn_prior,
        audiopath) for one sample."""
        # Separate filename and text
        if self.n_speakers > 1:
            audiopath, *extra, text, speaker = self.audiopaths_and_text[index]
            speaker = int(speaker)
        else:
            audiopath, *extra, text = self.audiopaths_and_text[index]
            speaker = None

        mel = self.get_mel(audiopath)
        text = self.get_text(text)
        pitch = self.get_pitch(index, mel.size(-1))
        # Per-frame energy: L2 norm of the mel column over channels.
        energy = torch.norm(mel.float(), dim=0, p=2)
        attn_prior = self.get_prior(index, mel.shape[1], text.shape[0])
        # Pitch must align with mel frames one-to-one.
        assert pitch.size(-1) == mel.size(-1)

        # No higher formants?
        if len(pitch.size()) == 1:
            pitch = pitch[None, :]

        return (text, mel, len(text), pitch, energy, speaker, attn_prior,
                audiopath)

    def __len__(self):
        return len(self.audiopaths_and_text)

    def get_mel(self, filename):
        """Load a precomputed mel from disk or compute it from a wav file."""
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.stft.sampling_rate))
            # Scale integer-valued audio into [-1, 1] before the STFT.
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            audio_norm = torch.autograd.Variable(audio_norm,
                                                 requires_grad=False)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.load(filename)
            # assert melspec.size(0) == self.stft.n_mel_channels, (
            #     'Mel dimension mismatch: given {}, expected {}'.format(
            #         melspec.size(0), self.stft.n_mel_channels))

        return melspec

    def get_text(self, text):
        """Encode text to symbol ids, optionally wrapping it in spaces."""
        text = self.tp.encode_text(text)
        # The id of a space symbol, recovered by encoding "A A".
        space = [self.tp.encode_text("A A")[1]]

        if self.prepend_space_to_text:
            text = space + text

        if self.append_space_to_text:
            text = text + space

        return torch.LongTensor(text)

    def get_prior(self, index, mel_len, text_len):
        """Return a (mel_len, text_len) betabinomial alignment prior.

        Uses the fast interpolator when enabled; otherwise computes the
        exact prior, optionally caching it under betabinomial_tmp_dir.
        """
        if self.use_betabinomial_interpolator:
            return torch.from_numpy(self.betabinomial_interpolator(mel_len,
                                                                   text_len))

        if self.betabinomial_tmp_dir is not None:
            audiopath, *_ = self.audiopaths_and_text[index]
            fname = Path(audiopath).relative_to(self.dataset_path)
            fname = fname.with_suffix('.pt')
            cached_fpath = Path(self.betabinomial_tmp_dir, fname)

            if cached_fpath.is_file():
                return torch.load(cached_fpath)

        attn_prior = beta_binomial_prior_distribution(text_len, mel_len)

        if self.betabinomial_tmp_dir is not None:
            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
            torch.save(attn_prior, cached_fpath)

        return attn_prior

    def get_pitch(self, index, mel_len=None):
        """Return the pitch contour for sample `index`.

        Resolution order: precomputed file from the filelist, then the
        online cache directory, then on-the-fly estimation (which is
        written back to the cache when one is configured).
        """
        audiopath, *fields = self.audiopaths_and_text[index]

        if self.n_speakers > 1:
            spk = int(fields[-1])
        else:
            spk = 0
        # NOTE(review): spk is parsed but not used below - confirm whether
        # per-speaker pitch handling was intended here.

        if self.load_pitch_from_disk:
            pitchpath = fields[0]
            pitch = torch.load(pitchpath)
            if self.pitch_mean is not None:
                assert self.pitch_std is not None
                pitch = normalize_pitch(pitch, self.pitch_mean, self.pitch_std)
            return pitch

        if self.pitch_tmp_dir is not None:
            fname = Path(audiopath).relative_to(self.dataset_path)
            fname_method = fname.with_suffix('.pt')
            cached_fpath = Path(self.pitch_tmp_dir, fname_method)
            if cached_fpath.is_file():
                return torch.load(cached_fpath)

        # No luck so far - calculate
        wav = audiopath
        if not wav.endswith('.wav'):
            # Map a mel .pt path back to its source wav file.
            wav = re.sub('/mels/', '/wavs/', wav)
            wav = re.sub('.pt$', '.wav', wav)

        pitch_mel = estimate_pitch(wav, mel_len, self.f0_method,
                                   self.pitch_mean, self.pitch_std)

        if self.pitch_tmp_dir is not None and not cached_fpath.is_file():
            cached_fpath.parent.mkdir(parents=True, exist_ok=True)
            torch.save(pitch_mel, cached_fpath)

        return pitch_mel
  262. def ensure_disjoint(*tts_datasets):
  263. paths = [set(list(zip(*d.audiopaths_and_text))[0]) for d in tts_datasets]
  264. assert sum(len(p) for p in paths) == len(set().union(*paths)), (
  265. "Your datasets (train, val) are not disjoint. "
  266. "Review filelists and restart training."
  267. )
  268. class TTSCollate:
  269. """Zero-pads model inputs and targets based on number of frames per step"""
  270. def __call__(self, batch):
  271. """Collate training batch from normalized text and mel-spec"""
  272. # Right zero-pad all one-hot text sequences to max input length
  273. input_lengths, ids_sorted_decreasing = torch.sort(
  274. torch.LongTensor([len(x[0]) for x in batch]),
  275. dim=0, descending=True)
  276. max_input_len = input_lengths[0]
  277. text_padded = torch.LongTensor(len(batch), max_input_len)
  278. text_padded.zero_()
  279. for i in range(len(ids_sorted_decreasing)):
  280. text = batch[ids_sorted_decreasing[i]][0]
  281. text_padded[i, :text.size(0)] = text
  282. # Right zero-pad mel-spec
  283. num_mels = batch[0][1].size(0)
  284. max_target_len = max([x[1].size(1) for x in batch])
  285. # Include mel padded and gate padded
  286. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  287. mel_padded.zero_()
  288. output_lengths = torch.LongTensor(len(batch))
  289. for i in range(len(ids_sorted_decreasing)):
  290. mel = batch[ids_sorted_decreasing[i]][1]
  291. mel_padded[i, :, :mel.size(1)] = mel
  292. output_lengths[i] = mel.size(1)
  293. n_formants = batch[0][3].shape[0]
  294. pitch_padded = torch.zeros(mel_padded.size(0), n_formants,
  295. mel_padded.size(2), dtype=batch[0][3].dtype)
  296. energy_padded = torch.zeros_like(pitch_padded[:, 0, :])
  297. for i in range(len(ids_sorted_decreasing)):
  298. pitch = batch[ids_sorted_decreasing[i]][3]
  299. energy = batch[ids_sorted_decreasing[i]][4]
  300. pitch_padded[i, :, :pitch.shape[1]] = pitch
  301. energy_padded[i, :energy.shape[0]] = energy
  302. if batch[0][5] is not None:
  303. speaker = torch.zeros_like(input_lengths)
  304. for i in range(len(ids_sorted_decreasing)):
  305. speaker[i] = batch[ids_sorted_decreasing[i]][5]
  306. else:
  307. speaker = None
  308. attn_prior_padded = torch.zeros(len(batch), max_target_len,
  309. max_input_len)
  310. attn_prior_padded.zero_()
  311. for i in range(len(ids_sorted_decreasing)):
  312. prior = batch[ids_sorted_decreasing[i]][6]
  313. attn_prior_padded[i, :prior.size(0), :prior.size(1)] = prior
  314. # Count number of items - characters in text
  315. len_x = [x[2] for x in batch]
  316. len_x = torch.Tensor(len_x)
  317. audiopaths = [batch[i][7] for i in ids_sorted_decreasing]
  318. return (text_padded, input_lengths, mel_padded, output_lengths, len_x,
  319. pitch_padded, energy_padded, speaker, attn_prior_padded,
  320. audiopaths)
  321. def batch_to_gpu(batch):
  322. (text_padded, input_lengths, mel_padded, output_lengths, len_x,
  323. pitch_padded, energy_padded, speaker, attn_prior, audiopaths) = batch
  324. text_padded = to_gpu(text_padded).long()
  325. input_lengths = to_gpu(input_lengths).long()
  326. mel_padded = to_gpu(mel_padded).float()
  327. output_lengths = to_gpu(output_lengths).long()
  328. pitch_padded = to_gpu(pitch_padded).float()
  329. energy_padded = to_gpu(energy_padded).float()
  330. attn_prior = to_gpu(attn_prior).float()
  331. if speaker is not None:
  332. speaker = to_gpu(speaker).long()
  333. # Alignments act as both inputs and targets - pass shallow copies
  334. x = [text_padded, input_lengths, mel_padded, output_lengths,
  335. pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
  336. y = [mel_padded, input_lengths, output_lengths]
  337. len_x = torch.sum(output_lengths)
  338. return (x, y, len_x)