
fixing rng_state for backward compatibility

gkarch 5 years ago
parent
commit 9a6c5241d7
1 changed file with 6 additions and 1 deletion
  1. PyTorch/SpeechSynthesis/Tacotron2/train.py  +6 −1

PyTorch/SpeechSynthesis/Tacotron2/train.py  +6 −1

@@ -250,7 +250,12 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
     epoch[0] = checkpoint['epoch']+1
     device_id = local_rank % torch.cuda.device_count()
     torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
-    torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
+    if 'random_rng_states_all' in checkpoint:
+        torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
+    elif 'random_rng_state' in checkpoint:
+        torch.random.set_rng_state(checkpoint['random_rng_state'])
+    else:
+        raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
     config = checkpoint['config']
     model.load_state_dict(checkpoint['state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer'])
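
Context for the fix: older Tacotron2 checkpoints store a single CPU RNG state under 'random_rng_state', while newer ones store one state per rank under 'random_rng_states_all'; the patched load_checkpoint now accepts both, so pre-change checkpoints still resume. Below is a minimal round-trip sketch of that logic, assuming a single-process CUDA run. The helper names save_rng_states/load_rng_states are hypothetical; the checkpoint keys match the diff above.

import torch

def save_rng_states(filepath):
    # New checkpoint format: all per-device CUDA RNG states plus
    # one CPU RNG state per rank (single process here, so one entry).
    torch.save({
        'cuda_rng_state_all': torch.cuda.get_rng_state_all(),
        'random_rng_states_all': [torch.random.get_rng_state()],
    }, filepath)

def load_rng_states(filepath, local_rank=0):
    checkpoint = torch.load(filepath)
    device_id = local_rank % torch.cuda.device_count()
    # Restore the CUDA RNG state for this rank's device.
    torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
    if 'random_rng_states_all' in checkpoint:
        # New key: restore the CPU RNG state saved by this rank.
        torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
    elif 'random_rng_state' in checkpoint:
        # Legacy key from older checkpoints: a single shared CPU RNG state.
        torch.random.set_rng_state(checkpoint['random_rng_state'])
    else:
        raise Exception("Model checkpoint must have either 'random_rng_state' "
                        "or 'random_rng_states_all' key.")

With this guard in place, a checkpoint written before the per-rank key was introduced loads without modification; only a checkpoint missing both keys is rejected.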