Просмотр исходного кода

[ConvNets/EffNetV2/TF2] model resume fixed

Nima Tajbakhsh 3 лет назад
Родитель
Сommit
cbbcc0ff70
1 измененных файлов с 2 добавлено и 1 удалено
  1. 2 1
      TensorFlow2/Classification/ConvNets/runtime/runner.py

+ 2 - 1
TensorFlow2/Classification/ConvNets/runtime/runner.py

@@ -282,7 +282,8 @@ class Runner(object):
 
 
             target_img_size = self.params.train_img_size
             target_img_size = self.params.train_img_size
             epochs_per_stage = train_epochs // n_stages 
             epochs_per_stage = train_epochs // n_stages 
-            for stage in range(resumed_epoch // epochs_per_stage, n_stages):
+            resumed_stage = min(resumed_epoch // epochs_per_stage, n_stages-1)
+            for stage in range(resumed_stage, n_stages):
                 epoch_st = stage * epochs_per_stage 
                 epoch_st = stage * epochs_per_stage 
                 epoch_end = (epoch_st + epochs_per_stage) if stage < n_stages-1 else train_epochs
                 epoch_end = (epoch_st + epochs_per_stage) if stage < n_stages-1 else train_epochs
                 epoch_curr = epoch_st if epoch_st >= resumed_epoch else resumed_epoch
                 epoch_curr = epoch_st if epoch_st >= resumed_epoch else resumed_epoch