Jelajahi Sumber

Merge: [WaveGlow/PyT] Enable TorchScript

Krzysztof Kudrynski 3 tahun lalu
induk
melakukan
6a160116ef

+ 10 - 5
PyTorch/SpeechSynthesis/Tacotron2/inference.py

@@ -106,13 +106,15 @@ def unwrap_distributed(state_dict):
     return new_state_dict
 
 
-def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run, forward_is_infer=False):
+def load_and_setup_model(model_name, parser, checkpoint, fp16_run, cpu_run,
+                         forward_is_infer=False, jittable=False):
     model_parser = models.model_parser(model_name, parser, add_help=False)
     model_args, _ = model_parser.parse_known_args()
 
     model_config = models.get_model_config(model_name, model_args)
     model = models.get_model(model_name, model_config, cpu_run=cpu_run,
-                             forward_is_infer=forward_is_infer)
+                             forward_is_infer=forward_is_infer,
+                             jittable=jittable)
 
     if checkpoint is not None:
         if cpu_run:
@@ -207,11 +209,14 @@ def main():
     tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                      args.fp16, args.cpu, forward_is_infer=True)
     waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
-                                    args.fp16, args.cpu, forward_is_infer=True)
+                                    args.fp16, args.cpu, forward_is_infer=True,
+                                    jittable=True)
     denoiser = Denoiser(waveglow)
     if not args.cpu:
         denoiser.cuda()
 
+    waveglow.make_ts_scriptable()
+    jitted_waveglow = torch.jit.script(waveglow)
     jitted_tacotron2 = torch.jit.script(tacotron2)
 
     texts = []
@@ -231,7 +236,7 @@ def main():
         for i in range(3):
             with torch.no_grad():
                 mel, mel_lengths, _ = jitted_tacotron2(sequence, input_lengths)
-                _ = waveglow(mel)
+                _ = jitted_waveglow(mel)
 
     measurements = {}
 
@@ -241,7 +246,7 @@ def main():
         mel, mel_lengths, alignments = jitted_tacotron2(sequences_padded, input_lengths)
 
     with torch.no_grad(), MeasureTime(measurements, "waveglow_time", args.cpu):
-        audios = waveglow(mel, sigma=args.sigma_infer)
+        audios = jitted_waveglow(mel, sigma=args.sigma_infer)
         audios = audios.float()
     with torch.no_grad(), MeasureTime(measurements, "denoiser_time", args.cpu):
         audios = denoiser(audios, strength=args.denoising_strength).squeeze(1)

+ 6 - 7
PyTorch/SpeechSynthesis/Tacotron2/models.py

@@ -63,7 +63,8 @@ def init_bn(module):
 
 
 def get_model(model_name, model_config, cpu_run,
-              uniform_initialize_bn_weight=False, forward_is_infer=False):
+              uniform_initialize_bn_weight=False, forward_is_infer=False,
+              jittable=False):
     """ Code chooses a model based on name"""
     model = None
     if model_name == 'Tacotron2':
@@ -75,13 +76,11 @@ def get_model(model_name, model_config, cpu_run,
         else:
             model = Tacotron2(**model_config)
     elif model_name == 'WaveGlow':
+
+        model = WaveGlow(**model_config)
         if forward_is_infer:
-            class WaveGlow__forward_is_infer(WaveGlow):
-                def forward(self, spect, sigma=1.0):
-                    return self.infer(spect, sigma)
-            model = WaveGlow__forward_is_infer(**model_config)
-        else:
-            model = WaveGlow(**model_config)
+            model.forward = model.infer
+
     else:
         raise NotImplementedError(model_name)
 

+ 68 - 30
PyTorch/SpeechSynthesis/Tacotron2/waveglow/model.py

@@ -26,13 +26,14 @@
 # *****************************************************************************
 import torch
 torch._C._jit_set_autocast_mode(False)
-from torch.autograd import Variable
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.autograd import Variable
 
 
 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
-    n_channels_int = n_channels[0]
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels : int):
+    n_channels_int = n_channels
     in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
@@ -73,22 +74,14 @@ class Invertible1x1Conv(torch.nn.Module):
         z = self.conv(z)
         return z, log_det_W
 
-
     def infer(self, z):
-        # shape
-        batch_size, group_size, n_of_groups = z.size()
-
-        W = self.conv.weight.squeeze()
+        self._invert()
+        return F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
 
+    def _invert(self):
         if not hasattr(self, 'W_inverse'):
-            # Reverse computation
-            W_inverse = W.float().inverse()
-            W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
-                W_inverse = W_inverse.half()
-            self.W_inverse = W_inverse
-        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
-        return z
+            W = self.conv.weight.squeeze()
+            self.W_inverse = W.float().inverse().unsqueeze(-1).to(W.dtype)
 
 
 class WN(torch.nn.Module):
@@ -142,27 +135,25 @@ class WN(torch.nn.Module):
                 res_skip_layer, name='weight')
             self.res_skip_layers.append(res_skip_layer)
 
-    def forward(self, forward_input):
-        audio, spect = forward_input
+    def forward(self, audio, spect):
         audio = self.start(audio)
 
-        for i in range(self.n_layers):
+        output = 0
+        for i, (in_layer, cond_layer, res_skip_layer) in enumerate(
+                zip(self.in_layers, self.cond_layers, self.res_skip_layers)):
             acts = fused_add_tanh_sigmoid_multiply(
-                self.in_layers[i](audio),
-                self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]))
+                in_layer(audio),
+                cond_layer(spect),
+                self.n_channels)
 
-            res_skip_acts = self.res_skip_layers[i](acts)
+            res_skip_acts = res_skip_layer(acts)
             if i < self.n_layers - 1:
                 audio = res_skip_acts[:, :self.n_channels, :] + audio
                 skip_acts = res_skip_acts[:, self.n_channels:, :]
             else:
                 skip_acts = res_skip_acts
 
-            if i == 0:
-                output = skip_acts
-            else:
-                output = skip_acts + output
+            output += skip_acts
         return self.end(output)
 
 
@@ -229,7 +220,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             log_s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = torch.exp(log_s) * audio_1 + b
@@ -262,7 +253,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:, :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:, :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -308,7 +299,7 @@ class WaveGlow(torch.nn.Module):
             audio_0 = audio[:, :n_half, :]
             audio_1 = audio[:, n_half:(n_half+n_half), :]
 
-            output = self.WN[k]((audio_0, spect))
+            output = self.WN[k](audio_0, spect)
             s = output[:, n_half:(n_half+n_half), :]
             b = output[:, :n_half, :]
             audio_1 = (audio_1 - b) / torch.exp(s)
@@ -323,6 +314,53 @@ class WaveGlow(torch.nn.Module):
 
         return audio
 
+    def _infer_ts(self, spect, sigma : float=1.0):
+
+        spect = self.upsample(spect)
+        # trim conv artifacts. maybe pad spec to kernel multiple
+        time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
+        spect = spect[:, :, :-time_cutoff]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
+        spect = spect.permute(0, 2, 1)
+
+        audio = torch.randn(spect.size(0), self.n_remaining_channels,
+                            spect.size(2), device=spect.device,
+                            dtype=spect.dtype)
+        audio *= sigma
+
+        for kk, (wn, convinv) in enumerate(zip(self.WN_rev, self.convinv_rev)):
+            k = self.n_flows - kk - 1
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = wn(audio_0, spect)
+            s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = (audio_1 - b) / torch.exp(s)
+            audio = torch.cat([audio_0, audio_1], 1)
+
+            audio = convinv.infer(audio)
+
+            if k % self.n_early_every == 0 and k > 0:
+                z = torch.randn(spect.size(0), self.n_early_size,
+                                spect.size(2), device=spect.device,
+                                dtype=spect.dtype)
+                audio = torch.cat((sigma * z, audio), 1)
+
+        return audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
+
+    def make_ts_scriptable(self, forward_is_infer=True):
+        self.WN_rev = torch.nn.ModuleList(reversed(self.WN))
+        self.convinv_rev = torch.nn.ModuleList(reversed(self.convinv))
+        for conv in self.convinv_rev:
+            conv._invert()
+
+        self.infer = self._infer_ts
+        if forward_is_infer:
+            self.forward = self._infer_ts
 
     @staticmethod
     def remove_weightnorm(model):