Update models.py - fix fp16 inference
@@ -74,8 +74,7 @@ def get_model(model_name, model_config, to_fp16, to_cuda, training=True):
raise NotImplementedError(model_name)
if to_fp16:
model = batchnorm_to_float(model.half())
- if training:
- model = lstmcell_to_float(model)
+ model = lstmcell_to_float(model)
if model_name == "WaveGlow":
for k in model.convinv:
k.float()