Просмотр исходного кода

Merge pull request #67 from GrzegorzKarchNV/master

Update models.py - fix fp16 inference
nvpstr 6 лет назад
Родитель
Сommit
2619f172c7
1 измененных файлов с 1 добавлено и 2 удалено
  1. 1 2
      PyTorch/SpeechSynthesis/Tacotron2/models.py

+ 1 - 2
PyTorch/SpeechSynthesis/Tacotron2/models.py

@@ -74,8 +74,7 @@ def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
         raise NotImplementedError(model_name)
     if to_fp16:
         model = batchnorm_to_float(model.half())
-        if training:
-            model = lstmcell_to_float(model)
+        model = lstmcell_to_float(model)
         if model_name == "WaveGlow":
             for k in model.convinv:
                 k.float()