Explorar o código

[Tacotron2/PyT] stabilize inference performance results for Tacotron 2

Grzegorz Karch %!s(int64=3) %!d(string=hai) anos
pai
achega
0422fa0c88
Modificáronse 1 ficheiros con 11 adicións e 6 borrados
  1. 11 6
      PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py

+ 11 - 6
PyTorch/SpeechSynthesis/Tacotron2/inference_perf.py

@@ -33,6 +33,7 @@ import json
 import time
 import os
 import sys
+import random
 
 from inference import checkpoint_from_distributed, unwrap_distributed, load_and_setup_model, MeasureTime, prepare_input_sequence
 
@@ -63,7 +64,7 @@ def parse_args(parser):
 
 def gen_text(use_synthetic_data):
     batch_size = 1
-    text_len = 140
+    text_len = 170
 
     if use_synthetic_data:
         text_padded = torch.randint(low=0, high=148,
@@ -72,9 +73,9 @@ def gen_text(use_synthetic_data):
         input_lengths = torch.IntTensor([text_padded.size(1)]*
                                         batch_size).cuda().long()
     else:
-        texts = ['The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves.']
-        texts = texts[:][:text_len]
-        text_padded, input_lengths = prepare_input_sequence(texts)
+        text = 'The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves. '*2
+        text = [text[:text_len]]
+        text_padded, input_lengths = prepare_input_sequence(text)
 
     return (text_padded, input_lengths)
 
@@ -106,6 +107,10 @@ def main():
 
     log_file = os.path.join(args.output, args.log_file)
 
+    torch.manual_seed(1234)
+    random.seed(1234)
+    np.random.seed(1234)
+
     DLLogger.init(backends=[JSONStreamBackend(Verbosity.DEFAULT, log_file),
                             StdOutBackend(Verbosity.VERBOSE)])
     for k,v in vars(args).items():
@@ -129,8 +134,8 @@ def main():
     if args.model_name == "Tacotron2":
         model = torch.jit.script(model)
 
-    warmup_iters = 3
-    num_iters = 1+warmup_iters
+    warmup_iters = 6
+    num_iters = warmup_iters + 1
 
     for i in range(num_iters):