# model.py
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
  27. from math import sqrt
  28. import torch
  29. from torch import nn
  30. from torch.nn import functional as F
  31. import sys
  32. from os.path import abspath, dirname
  33. # enabling modules discovery from global entrypoint
  34. sys.path.append(abspath(dirname(__file__)+'/../'))
  35. from tacotron2_common.layers import ConvNorm, LinearNorm
  36. from tacotron2_common.utils import to_gpu, get_mask_from_lengths
  37. class LocationLayer(nn.Module):
  38. def __init__(self, attention_n_filters, attention_kernel_size,
  39. attention_dim):
  40. super(LocationLayer, self).__init__()
  41. padding = int((attention_kernel_size - 1) / 2)
  42. self.location_conv = ConvNorm(2, attention_n_filters,
  43. kernel_size=attention_kernel_size,
  44. padding=padding, bias=False, stride=1,
  45. dilation=1)
  46. self.location_dense = LinearNorm(attention_n_filters, attention_dim,
  47. bias=False, w_init_gain='tanh')
  48. def forward(self, attention_weights_cat):
  49. processed_attention = self.location_conv(attention_weights_cat)
  50. processed_attention = processed_attention.transpose(1, 2)
  51. processed_attention = self.location_dense(processed_attention)
  52. return processed_attention
class Attention(nn.Module):
    """Location-sensitive attention.

    Scores encoder memory against a query from the attention RNN plus
    convolutional features of the previous/cumulative attention weights,
    then returns the attention context and normalized weights.
    """

    def __init__(self, attention_rnn_dim, embedding_dim,
                 attention_dim, attention_location_n_filters,
                 attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        # projects the tanh-combined features to a scalar energy per timestep
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        # energy assigned to padded positions so softmax gives them weight 0
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Compute unnormalized alignment energies.

        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        # query is expanded to (B, 1, attention_dim) so it broadcasts over time
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))
        # drop the trailing singleton dim: (B, T, 1) -> (B, T)
        energies = energies.squeeze(2)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Compute attention context and weights for one decoder step.

        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cummulative attention weights
        mask: binary mask for padded data

        RETURNS
        -------
        attention_context: weighted sum of memory (B, encoder_embedding_dim)
        attention_weights: softmax-normalized alignment (B, max_time)
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)
        # mask must be applied BEFORE the softmax so padded steps get 0 weight
        alignment = alignment.masked_fill(mask, self.score_mask_value)
        attention_weights = F.softmax(alignment, dim=1)
        # (B, 1, T) x (B, T, E) -> (B, 1, E) -> (B, E)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
  103. class Prenet(nn.Module):
  104. def __init__(self, in_dim, sizes):
  105. super(Prenet, self).__init__()
  106. in_sizes = [in_dim] + sizes[:-1]
  107. self.layers = nn.ModuleList(
  108. [LinearNorm(in_size, out_size, bias=False)
  109. for (in_size, out_size) in zip(in_sizes, sizes)])
  110. def forward(self, x):
  111. for linear in self.layers:
  112. x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
  113. return x
class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolution with 512 channels and kernel size 5

    Refines the decoder's mel output with a residual stack of convolutions:
    all layers but the last use tanh; the last is linear.
    """

    def __init__(self, n_mel_channels, postnet_embedding_dim,
                 postnet_kernel_size, postnet_n_convolutions):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()
        # first conv: mel channels -> postnet embedding dim
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(n_mel_channels, postnet_embedding_dim,
                         kernel_size=postnet_kernel_size, stride=1,
                         padding=int((postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(postnet_embedding_dim))
        )
        # middle convs: embedding dim -> embedding dim
        for i in range(1, postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(postnet_embedding_dim,
                             postnet_embedding_dim,
                             kernel_size=postnet_kernel_size, stride=1,
                             padding=int((postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(postnet_embedding_dim))
            )
        # last conv: back to mel channels, linear init gain (no tanh applied)
        self.convolutions.append(
            nn.Sequential(
                ConvNorm(postnet_embedding_dim, n_mel_channels,
                         kernel_size=postnet_kernel_size, stride=1,
                         padding=int((postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(n_mel_channels))
        )
        # cached so forward can detect the last layer without len() per step
        self.n_convs = len(self.convolutions)

    def forward(self, x):
        # NOTE(review): manual counter instead of enumerate — presumably kept
        # for TorchScript compatibility; confirm before simplifying.
        i = 0
        for conv in self.convolutions:
            if i < self.n_convs - 1:
                x = F.dropout(torch.tanh(conv(x)), 0.5, training=self.training)
            else:
                # final layer: no tanh, matching its 'linear' init gain
                x = F.dropout(conv(x), 0.5, training=self.training)
            i += 1
        return x
class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM

    Turns embedded text symbols into a sequence of encoder states
    ("memory") that the decoder attends over.
    """

    def __init__(self, encoder_n_convolutions,
                 encoder_embedding_dim, encoder_kernel_size):
        super(Encoder, self).__init__()
        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(encoder_embedding_dim,
                         encoder_embedding_dim,
                         kernel_size=encoder_kernel_size, stride=1,
                         padding=int((encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)
        # bidirectional, so each direction gets half the embedding dim
        self.lstm = nn.LSTM(encoder_embedding_dim,
                            int(encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    @torch.jit.ignore
    def forward(self, x, input_lengths):
        """Training path (excluded from TorchScript by @torch.jit.ignore).

        x: embedded inputs (B, embedding_dim, T); input_lengths: (B,).
        Returns encoder outputs (B, T, embedding_dim).
        """
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
        # (B, C, T) -> (B, T, C) for the batch_first LSTM
        x = x.transpose(1, 2)
        # pytorch tensor are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)
        return outputs

    @torch.jit.export
    def infer(self, x, input_lengths):
        """Inference path, compiled by TorchScript (@torch.jit.export).

        Mirrors forward() but avoids numpy and flatten_parameters(),
        which are not scriptable; lengths stay a CPU tensor as required
        by pack_padded_sequence.
        """
        device = x.device
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x.to(device))), 0.5, self.training)
        x = x.transpose(1, 2)
        input_lengths = input_lengths.cpu()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)
        return outputs
class Decoder(nn.Module):
    """Autoregressive Tacotron 2 decoder.

    Consumes encoder outputs ("memory") and generates mel-spectrogram
    frames one step at a time, along with stop-token ("gate") energies
    and attention alignments. forward() is the teacher-forced training
    path; infer() is the TorchScript-exported free-running path.
    """

    def __init__(self, n_mel_channels, n_frames_per_step,
                 encoder_embedding_dim, attention_dim,
                 attention_location_n_filters,
                 attention_location_kernel_size,
                 attention_rnn_dim, decoder_rnn_dim,
                 prenet_dim, max_decoder_steps, gate_threshold,
                 p_attention_dropout, p_decoder_dropout,
                 early_stopping):
        super(Decoder, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.encoder_embedding_dim = encoder_embedding_dim
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dim = prenet_dim
        self.max_decoder_steps = max_decoder_steps
        self.gate_threshold = gate_threshold
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout
        self.early_stopping = early_stopping

        self.prenet = Prenet(
            n_mel_channels * n_frames_per_step,
            [prenet_dim, prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            prenet_dim + encoder_embedding_dim,
            attention_rnn_dim)

        self.attention_layer = Attention(
            attention_rnn_dim, encoder_embedding_dim,
            attention_dim, attention_location_n_filters,
            attention_location_kernel_size)

        # NOTE(review): nn.LSTMCell's third positional argument is `bias`;
        # the literal 1 is a truthy bias=True (same as the default).
        self.decoder_rnn = nn.LSTMCell(
            attention_rnn_dim + encoder_embedding_dim,
            decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            decoder_rnn_dim + encoder_embedding_dim,
            n_mel_channels * n_frames_per_step)

        self.gate_layer = LinearNorm(
            decoder_rnn_dim + encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: decoder outputs
        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        dtype = memory.dtype
        device = memory.device
        decoder_input = torch.zeros(
            B, self.n_mel_channels*self.n_frames_per_step,
            dtype=dtype, device=device)
        return decoder_input

    def initialize_decoder_states(self, memory):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        RETURNS
        -------
        8-tuple of zeroed decoder/attention state tensors plus the
        memory projected through the attention memory layer.
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)
        dtype = memory.dtype
        device = memory.device

        attention_hidden = torch.zeros(
            B, self.attention_rnn_dim, dtype=dtype, device=device)
        attention_cell = torch.zeros(
            B, self.attention_rnn_dim, dtype=dtype, device=device)

        decoder_hidden = torch.zeros(
            B, self.decoder_rnn_dim, dtype=dtype, device=device)
        decoder_cell = torch.zeros(
            B, self.decoder_rnn_dim, dtype=dtype, device=device)

        attention_weights = torch.zeros(
            B, MAX_TIME, dtype=dtype, device=device)
        attention_weights_cum = torch.zeros(
            B, MAX_TIME, dtype=dtype, device=device)
        attention_context = torch.zeros(
            B, self.encoder_embedding_dim, dtype=dtype, device=device)

        # pre-compute the memory projection once; reused by every decode step
        processed_memory = self.attention_layer.memory_layer(memory)

        return (attention_hidden, attention_cell, decoder_hidden,
                decoder_cell, attention_weights, attention_weights_cum,
                attention_context, processed_memory)

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # group frames so each decoder step consumes n_frames_per_step frames
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: time-major mel frames
        gate_outputs: gate output energies
        alignments: time-major attention weights
        RETURNS
        -------
        mel_outputs: (B, n_mel_channels, T_out)
        gate_outputs: gate output energies, batch-major
        alignments: batch-major
        """
        # (T_out, B) -> (B, T_out)
        alignments = alignments.transpose(0, 1).contiguous()
        # (T_out, B) -> (B, T_out)
        gate_outputs = gate_outputs.transpose(0, 1).contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = mel_outputs.transpose(0, 1).contiguous()
        # decouple frames per step
        shape = (mel_outputs.shape[0], -1, self.n_mel_channels)
        mel_outputs = mel_outputs.view(*shape)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)
        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input, attention_hidden, attention_cell,
               decoder_hidden, decoder_cell, attention_weights,
               attention_weights_cum, attention_context, memory,
               processed_memory, mask):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output
        RETURNS
        -------
        mel_output, gate_output, and all updated state tensors so the
        caller can thread state through the autoregressive loop
        (this keeps the method TorchScript-friendly: no mutable
        module-level state).
        """
        # attention RNN consumes prenet output + previous attention context
        cell_input = torch.cat((decoder_input, attention_context), -1)
        attention_hidden, attention_cell = self.attention_rnn(
            cell_input, (attention_hidden, attention_cell))
        attention_hidden = F.dropout(
            attention_hidden, self.p_attention_dropout, self.training)

        # stack previous + cumulative weights as the two location channels
        attention_weights_cat = torch.cat(
            (attention_weights.unsqueeze(1),
             attention_weights_cum.unsqueeze(1)), dim=1)
        attention_context, attention_weights = self.attention_layer(
            attention_hidden, memory, processed_memory,
            attention_weights_cat, mask)

        attention_weights_cum += attention_weights

        # decoder RNN consumes attention RNN output + fresh context
        decoder_input = torch.cat(
            (attention_hidden, attention_context), -1)
        decoder_hidden, decoder_cell = self.decoder_rnn(
            decoder_input, (decoder_hidden, decoder_cell))
        decoder_hidden = F.dropout(
            decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (decoder_hidden, attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return (decoder_output, gate_prediction, attention_hidden,
                attention_cell, decoder_hidden, decoder_cell, attention_weights,
                attention_weights_cum, attention_context)

    @torch.jit.ignore
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        # prepend the all-zero <go> frame to the teacher-forcing inputs
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_outputs, gate_outputs, alignments = [], [], []
        # teacher forcing: step t consumes the ground-truth frame t
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            (mel_output,
             gate_output,
             attention_hidden,
             attention_cell,
             decoder_hidden,
             decoder_cell,
             attention_weights,
             attention_weights_cum,
             attention_context) = self.decode(decoder_input,
                                              attention_hidden,
                                              attention_cell,
                                              decoder_hidden,
                                              decoder_cell,
                                              attention_weights,
                                              attention_weights_cum,
                                              attention_context,
                                              memory,
                                              processed_memory,
                                              mask)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            torch.stack(mel_outputs),
            torch.stack(gate_outputs),
            torch.stack(alignments))

        return mel_outputs, gate_outputs, alignments

    @torch.jit.export
    def infer(self, memory, memory_lengths):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        mel_lengths: per-utterance generated length in decoder steps
        """
        decoder_input = self.get_go_frame(memory)

        mask = get_mask_from_lengths(memory_lengths)
        (attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context,
         processed_memory) = self.initialize_decoder_states(memory)

        mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32, device=memory.device)
        not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)

        # placeholder tensors so the variables are defined for TorchScript;
        # replaced on the first iteration
        mel_outputs, gate_outputs, alignments = (
            torch.zeros(1), torch.zeros(1), torch.zeros(1))
        first_iter = True
        while True:
            decoder_input = self.prenet(decoder_input)
            (mel_output,
             gate_output,
             attention_hidden,
             attention_cell,
             decoder_hidden,
             decoder_cell,
             attention_weights,
             attention_weights_cum,
             attention_context) = self.decode(decoder_input,
                                              attention_hidden,
                                              attention_cell,
                                              decoder_hidden,
                                              decoder_cell,
                                              attention_weights,
                                              attention_weights_cum,
                                              attention_context,
                                              memory,
                                              processed_memory,
                                              mask)
            if first_iter:
                mel_outputs = mel_output.unsqueeze(0)
                gate_outputs = gate_output
                alignments = attention_weights
                first_iter = False
            else:
                # concatenation per step is O(T^2) overall but keeps the
                # loop scriptable (no Python lists of tensors)
                mel_outputs = torch.cat(
                    (mel_outputs, mel_output.unsqueeze(0)), dim=0)
                gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
                alignments = torch.cat((alignments, attention_weights), dim=0)

            # an utterance is finished once sigmoid(gate) > gate_threshold
            dec = torch.le(torch.sigmoid(gate_output),
                           self.gate_threshold).to(torch.int32).squeeze(1)

            not_finished = not_finished*dec
            # count steps only for utterances still generating
            mel_lengths += not_finished
            if self.early_stopping and torch.sum(not_finished) == 0:
                break
            if len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # feed the model's own prediction back in (no teacher forcing)
            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments, mel_lengths
class Tacotron2(nn.Module):
    """Full Tacotron 2 network: symbol embedding -> Encoder -> Decoder
    -> Postnet, producing mel spectrograms from text."""

    def __init__(self, mask_padding, n_mel_channels,
                 n_symbols, symbols_embedding_dim, encoder_kernel_size,
                 encoder_n_convolutions, encoder_embedding_dim,
                 attention_rnn_dim, attention_dim, attention_location_n_filters,
                 attention_location_kernel_size, n_frames_per_step,
                 decoder_rnn_dim, prenet_dim, max_decoder_steps, gate_threshold,
                 p_attention_dropout, p_decoder_dropout,
                 postnet_embedding_dim, postnet_kernel_size,
                 postnet_n_convolutions, decoder_no_early_stopping):
        super(Tacotron2, self).__init__()
        self.mask_padding = mask_padding
        self.n_mel_channels = n_mel_channels
        self.n_frames_per_step = n_frames_per_step
        self.embedding = nn.Embedding(n_symbols, symbols_embedding_dim)
        # Xavier-style uniform initialization for the symbol embedding
        std = sqrt(2.0 / (n_symbols + symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(encoder_n_convolutions,
                               encoder_embedding_dim,
                               encoder_kernel_size)
        # CLI flag is "no early stopping", decoder wants "early stopping"
        self.decoder = Decoder(n_mel_channels, n_frames_per_step,
                               encoder_embedding_dim, attention_dim,
                               attention_location_n_filters,
                               attention_location_kernel_size,
                               attention_rnn_dim, decoder_rnn_dim,
                               prenet_dim, max_decoder_steps,
                               gate_threshold, p_attention_dropout,
                               p_decoder_dropout,
                               not decoder_no_early_stopping)
        self.postnet = Postnet(n_mel_channels, postnet_embedding_dim,
                               postnet_kernel_size,
                               postnet_n_convolutions)

    def parse_batch(self, batch):
        """Move a data-loader batch to GPU and split it into
        (model inputs, training targets)."""
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths):
        # type: (List[Tensor], Tensor) -> List[Tensor]
        # Zero out mel predictions past each utterance's true length and
        # push gate energies strongly positive (-> sigmoid ~ 1, "stop")
        # in the padded region. Mutates `outputs` in place.
        if self.mask_padding and output_lengths is not None:
            mask = get_mask_from_lengths(output_lengths)
            # broadcast the (B, T) mask across all mel channels
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].masked_fill_(mask, 0.0)
            outputs[1].masked_fill_(mask, 0.0)
            outputs[2].masked_fill_(mask[:, 0, :], 1e3)  # gate energies
        return outputs

    def forward(self, inputs):
        """Teacher-forced training pass.

        inputs: tuple from parse_batch's first element.
        Returns [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]
        with padding masked by parse_output.
        """
        inputs, input_lengths, targets, max_len, output_lengths = inputs
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)

        # postnet predicts a residual correction on top of the decoder mels
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def infer(self, inputs, input_lengths):
        """Free-running inference: text ids -> (mels, lengths, alignments)."""
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.infer(embedded_inputs, input_lengths)
        mel_outputs, gate_outputs, alignments, mel_lengths = self.decoder.infer(
            encoder_outputs, input_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        BS = mel_outputs_postnet.size(0)
        # NOTE(review): reshapes the flat alignments into per-utterance
        # windows of size BS along dim 1, then reorders axes — presumably to
        # recover a (T, B, ...) layout for export; confirm intended shape.
        alignments = alignments.unfold(1, BS, BS).transpose(0,2)

        return mel_outputs_postnet, mel_lengths, alignments