|
|
@@ -250,7 +250,12 @@ def load_checkpoint(model, optimizer, epoch, config, amp_run, filepath, local_ra
|
|
|
epoch[0] = checkpoint['epoch']+1
|
|
|
device_id = local_rank % torch.cuda.device_count()
|
|
|
torch.cuda.set_rng_state(checkpoint['cuda_rng_state_all'][device_id])
|
|
|
- torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
|
|
|
+ if 'random_rng_states_all' in checkpoint:
|
|
|
+ torch.random.set_rng_state(checkpoint['random_rng_states_all'][device_id])
|
|
|
+ elif 'random_rng_state' in checkpoint:
|
|
|
+ torch.random.set_rng_state(checkpoint['random_rng_state'])
|
|
|
+ else:
|
|
|
+ raise Exception("Model checkpoint must have either 'random_rng_state' or 'random_rng_states_all' key.")
|
|
|
config = checkpoint['config']
|
|
|
model.load_state_dict(checkpoint['state_dict'])
|
|
|
optimizer.load_state_dict(checkpoint['optimizer'])
|