@@ -26,13 +26,16 @@
 # *****************************************************************************
 import torch
 torch._C._jit_set_autocast_mode(False)
-from torch.autograd import Variable
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Variable
 
 
 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels : int):
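+    # TorchScript-friendly signature: n_channels is a typed Python int rather
+    # than the one-element IntTensor the old code had to unpack.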
+    n_channels_int = n_channels
     in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
@@ -73,22 +74,18 @@ class Invertible1x1Conv(torch.nn.Module):
         z = self.conv(z)
         return z, log_det_W
 
-
     def infer(self, z):
-        # shape
-        batch_size, group_size, n_of_groups = z.size()
-
-        W = self.conv.weight.squeeze()
+        self._invert()
+        return F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
 
+    def _invert(self):
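+        # Compute and cache W_inverse once; infer() and make_ts_scriptable()
+        # both route through here.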
         if not hasattr(self, 'W_inverse'):
-            # Reverse computation
-            W_inverse = W.float().inverse()
-            W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
-                W_inverse = W_inverse.half()
-            self.W_inverse = W_inverse
-        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
-        return z
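+            # Invert in float32 for accuracy, then cast back to the weight's
+            # dtype so half-precision inference keeps working.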
+            W = self.conv.weight.squeeze()
+            self.W_inverse = W.float().inverse().unsqueeze(-1).to(W.dtype)
 
 
 class WN(torch.nn.Module):
@@ -142,27 +135,27 @@ class WN(torch.nn.Module):
                 res_skip_layer, name='weight')
             self.res_skip_layers.append(res_skip_layer)
 
-    def forward(self, forward_input):
-        audio, spect = forward_input
+    def forward(self, audio, spect):
         audio = self.start(audio)
 
-        for i in range(self.n_layers):
+        output = 0
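+        # TorchScript cannot index a ModuleList with a loop variable, so the
+        # loop walks the three layer lists in lockstep instead.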
+        for i, (in_layer, cond_layer, res_skip_layer) in enumerate(
+                zip(self.in_layers, self.cond_layers, self.res_skip_layers)):
             acts = fused_add_tanh_sigmoid_multiply(
-                self.in_layers[i](audio),
-                self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]))
+                in_layer(audio),
+                cond_layer(spect),
+                self.n_channels)
 
-            res_skip_acts = self.res_skip_layers[i](acts)
+            res_skip_acts = res_skip_layer(acts)
             if i < self.n_layers - 1:
                 audio = res_skip_acts[:, :self.n_channels, :] + audio
                 skip_acts = res_skip_acts[:, self.n_channels:, :]
             else:
                 skip_acts = res_skip_acts
 
-            if i == 0:
-                output = skip_acts
-            else:
-                output = skip_acts + output
+            output += skip_acts
         return self.end(output)
 
 
@@ -229,7 +220,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             log_s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = torch.exp(log_s) * audio_1 + b
@@ -262,7 +253,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -308,7 +299,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:(n_half+n_half), :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:(n_half+n_half), :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -323,6 +314,60 @@ class WaveGlow(torch.nn.Module):
 
         return audio
 
+    def _infer_ts(self, spect, sigma : float=1.0):
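+        # TorchScript-friendly variant of infer(): it iterates the
+        # pre-reversed module lists instead of indexing self.WN / self.convinv.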
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = torch.randn(spect.size(0), self.n_remaining_channels,
+                            spect.size(2), device=spect.device,
+                            dtype=spect.dtype)
+        audio *= sigma
+
+        for kk, (wn, convinv) in enumerate(zip(self.WN_rev, self.convinv_rev)):
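+            # kk walks the reversed lists; k is the original flow index.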
+            k = self.n_flows - kk - 1
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = wn(audio_0, spect)
+            s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+
+            audio = convinv.infer(audio)
+
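+            # Re-create the noise for the channels that forward() emitted
+            # early at this flow.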
+            if k % self.n_early_every == 0 and k > 0:
+                z = torch.randn(spect.size(0), self.n_early_size,
+                                spect.size(2), device=spect.device,
+                                dtype=spect.dtype)
+                audio = torch.cat((sigma * z, audio), 1)
+
+        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
+
+    def make_ts_scriptable(self, forward_is_infer=True):
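+        # Pre-reverse the flow modules and precompute each W_inverse so the
+        # scripted graph only reads these attributes and never creates them.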
+        self.WN_rev = torch.nn.ModuleList(reversed(self.WN))
+        self.convinv_rev = torch.nn.ModuleList(reversed(self.convinv))
+        for conv in self.convinv_rev:
+            conv._invert()
+
+        self.infer = self._infer_ts
+        if forward_is_infer:
+            self.forward = self._infer_ts
 
     @staticmethod
     def remove_weightnorm(model):