# model.py (11 KB)
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import torch
  28. from torch.autograd import Variable
  29. import torch.nn.functional as F
  30. @torch.jit.script
  31. def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
  32. n_channels_int = n_channels[0]
  33. in_act = input_a + input_b
  34. t_act = torch.tanh(in_act[:, :n_channels_int, :])
  35. s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
  36. acts = t_act * s_act
  37. return acts
  38. class Invertible1x1Conv(torch.nn.Module):
  39. """
  40. The layer outputs both the convolution, and the log determinant
  41. of its weight matrix. If reverse=True it does convolution with
  42. inverse
  43. """
  44. def __init__(self, c):
  45. super(Invertible1x1Conv, self).__init__()
  46. self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
  47. bias=False)
  48. # Sample a random orthonormal matrix to initialize weights
  49. W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
  50. # Ensure determinant is 1.0 not -1.0
  51. if torch.det(W) < 0:
  52. W[:, 0] = -1 * W[:, 0]
  53. W = W.view(c, c, 1)
  54. self.conv.weight.data = W
  55. def forward(self, z, reverse=False):
  56. # shape
  57. batch_size, group_size, n_of_groups = z.size()
  58. W = self.conv.weight.squeeze()
  59. if reverse:
  60. if not hasattr(self, 'W_inverse'):
  61. # Reverse computation
  62. W_inverse = W.float().inverse()
  63. W_inverse = Variable(W_inverse[..., None])
  64. if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
  65. W_inverse = W_inverse.half()
  66. self.W_inverse = W_inverse
  67. z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
  68. return z
  69. else:
  70. # Forward computation
  71. log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze()
  72. z = self.conv(z)
  73. return z, log_det_W
  74. class WN(torch.nn.Module):
  75. """
  76. This is the WaveNet like layer for the affine coupling. The primary
  77. difference from WaveNet is the convolutions need not be causal. There is
  78. also no dilation size reset. The dilation only doubles on each layer
  79. """
  80. def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
  81. kernel_size):
  82. super(WN, self).__init__()
  83. assert(kernel_size % 2 == 1)
  84. assert(n_channels % 2 == 0)
  85. self.n_layers = n_layers
  86. self.n_channels = n_channels
  87. self.in_layers = torch.nn.ModuleList()
  88. self.res_skip_layers = torch.nn.ModuleList()
  89. self.cond_layers = torch.nn.ModuleList()
  90. start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
  91. start = torch.nn.utils.weight_norm(start, name='weight')
  92. self.start = start
  93. # Initializing last layer to 0 makes the affine coupling layers
  94. # do nothing at first. This helps with training stability
  95. end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
  96. end.weight.data.zero_()
  97. end.bias.data.zero_()
  98. self.end = end
  99. for i in range(n_layers):
  100. dilation = 2 ** i
  101. padding = int((kernel_size * dilation - dilation) / 2)
  102. in_layer = torch.nn.Conv1d(n_channels, 2 * n_channels, kernel_size,
  103. dilation=dilation, padding=padding)
  104. in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
  105. self.in_layers.append(in_layer)
  106. cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
  107. cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
  108. self.cond_layers.append(cond_layer)
  109. # last one is not necessary
  110. if i < n_layers - 1:
  111. res_skip_channels = 2 * n_channels
  112. else:
  113. res_skip_channels = n_channels
  114. res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
  115. res_skip_layer = torch.nn.utils.weight_norm(
  116. res_skip_layer, name='weight')
  117. self.res_skip_layers.append(res_skip_layer)
  118. def forward(self, forward_input):
  119. audio, spect = forward_input
  120. audio = self.start(audio)
  121. for i in range(self.n_layers):
  122. acts = fused_add_tanh_sigmoid_multiply(
  123. self.in_layers[i](audio),
  124. self.cond_layers[i](spect),
  125. torch.IntTensor([self.n_channels]))
  126. res_skip_acts = self.res_skip_layers[i](acts)
  127. if i < self.n_layers - 1:
  128. audio = res_skip_acts[:, :self.n_channels, :] + audio
  129. skip_acts = res_skip_acts[:, self.n_channels:, :]
  130. else:
  131. skip_acts = res_skip_acts
  132. if i == 0:
  133. output = skip_acts
  134. else:
  135. output = skip_acts + output
  136. return self.end(output)
  137. class WaveGlow(torch.nn.Module):
  138. def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
  139. n_early_size, WN_config):
  140. super(WaveGlow, self).__init__()
  141. self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
  142. n_mel_channels,
  143. 1024, stride=256)
  144. assert(n_group % 2 == 0)
  145. self.n_flows = n_flows
  146. self.n_group = n_group
  147. self.n_early_every = n_early_every
  148. self.n_early_size = n_early_size
  149. self.WN = torch.nn.ModuleList()
  150. self.convinv = torch.nn.ModuleList()
  151. n_half = int(n_group / 2)
  152. # Set up layers with the right sizes based on how many dimensions
  153. # have been output already
  154. n_remaining_channels = n_group
  155. for k in range(n_flows):
  156. if k % self.n_early_every == 0 and k > 0:
  157. n_half = n_half - int(self.n_early_size / 2)
  158. n_remaining_channels = n_remaining_channels - self.n_early_size
  159. self.convinv.append(Invertible1x1Conv(n_remaining_channels))
  160. self.WN.append(WN(n_half, n_mel_channels * n_group, **WN_config))
  161. self.n_remaining_channels = n_remaining_channels
  162. def forward(self, forward_input):
  163. """
  164. forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
  165. forward_input[1] = audio: batch x time
  166. """
  167. spect, audio = forward_input
  168. # Upsample spectrogram to size of audio
  169. spect = self.upsample(spect)
  170. assert(spect.size(2) >= audio.size(1))
  171. if spect.size(2) > audio.size(1):
  172. spect = spect[:, :, :audio.size(1)]
  173. spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
  174. spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
  175. spect = spect.permute(0, 2, 1)
  176. audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
  177. output_audio = []
  178. log_s_list = []
  179. log_det_W_list = []
  180. for k in range(self.n_flows):
  181. if k % self.n_early_every == 0 and k > 0:
  182. output_audio.append(audio[:, :self.n_early_size, :])
  183. audio = audio[:, self.n_early_size:, :]
  184. audio, log_det_W = self.convinv[k](audio)
  185. log_det_W_list.append(log_det_W)
  186. n_half = int(audio.size(1) / 2)
  187. audio_0 = audio[:, :n_half, :]
  188. audio_1 = audio[:, n_half:, :]
  189. output = self.WN[k]((audio_0, spect))
  190. log_s = output[:, n_half:, :]
  191. b = output[:, :n_half, :]
  192. audio_1 = torch.exp(log_s) * audio_1 + b
  193. log_s_list.append(log_s)
  194. audio = torch.cat([audio_0, audio_1], 1)
  195. output_audio.append(audio)
  196. return torch.cat(output_audio, 1), log_s_list, log_det_W_list
  197. def infer(self, spect, sigma=1.0):
  198. spect = self.upsample(spect)
  199. # trim conv artifacts. maybe pad spec to kernel multiple
  200. time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
  201. spect = spect[:, :, :-time_cutoff]
  202. spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
  203. spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
  204. spect = spect.permute(0, 2, 1)
  205. audio = torch.randn(spect.size(0),
  206. self.n_remaining_channels,
  207. spect.size(2), device=spect.device).to(spect.dtype)
  208. audio = torch.autograd.Variable(sigma * audio)
  209. for k in reversed(range(self.n_flows)):
  210. n_half = int(audio.size(1) / 2)
  211. audio_0 = audio[:, :n_half, :]
  212. audio_1 = audio[:, n_half:, :]
  213. output = self.WN[k]((audio_0, spect))
  214. s = output[:, n_half:, :]
  215. b = output[:, :n_half, :]
  216. audio_1 = (audio_1 - b) / torch.exp(s)
  217. audio = torch.cat([audio_0, audio_1], 1)
  218. audio = self.convinv[k](audio, reverse=True)
  219. if k % self.n_early_every == 0 and k > 0:
  220. z = torch.randn(spect.size(0), self.n_early_size, spect.size(
  221. 2), device=spect.device).to(spect.dtype)
  222. audio = torch.cat((sigma * z, audio), 1)
  223. audio = audio.permute(
  224. 0, 2, 1).contiguous().view(
  225. audio.size(0), -1).data
  226. return audio
  227. @staticmethod
  228. def remove_weightnorm(model):
  229. waveglow = model
  230. for WN in waveglow.WN:
  231. WN.start = torch.nn.utils.remove_weight_norm(WN.start)
  232. WN.in_layers = remove(WN.in_layers)
  233. WN.cond_layers = remove(WN.cond_layers)
  234. WN.res_skip_layers = remove(WN.res_skip_layers)
  235. return waveglow
  236. def remove(conv_list):
  237. new_conv_list = torch.nn.ModuleList()
  238. for old_conv in conv_list:
  239. old_conv = torch.nn.utils.remove_weight_norm(old_conv)
  240. new_conv_list.append(old_conv)
  241. return new_conv_list