Просмотр исходного кода

Merge: [nnUNet/PyT] fix checkpointing

Krzysztof Kudrynski 4 лет назад
Родитель
Сommit
c5a2c85efc
1 измененных файлов с 2 добавлено и 1 удалено
  1. 2 1
      PyTorch/Segmentation/nnUNet/main.py

+ 2 - 1
PyTorch/Segmentation/nnUNet/main.py

@@ -35,7 +35,7 @@ if __name__ == "__main__":
         print("Profiling enabled")
 
     if args.affinity != "disabled":
-        affinity = set_affinity(int(os.getenv("LOCAL_RANK", "0")), args.gpus, mode=args.affinity)
+        set_affinity(int(os.getenv("LOCAL_RANK", "0")), args.gpus, mode=args.affinity)
 
     # Limit number of CPU threads
     os.environ["OMP_NUM_THREADS"] = "1"
@@ -100,6 +100,7 @@ if __name__ == "__main__":
         default_root_dir=args.results,
         resume_from_checkpoint=ckpt_path,
         accelerator="ddp" if args.gpus > 1 else None,
+        checkpoint_callback=args.save_ckpt,
         limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
         limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
         limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,