Преглед изворног кода

[nnUNet/PyT] Fix case with checkpoint path set to None

Michal Futrega пре 3 године
родитељ
комит
20bda775a7

+ 2 - 3
PyTorch/Segmentation/nnUNet/main.py

@@ -15,13 +15,12 @@
 import os
 
 import torch
+from data_loading.data_module import DataModule
+from nnunet.nn_unet import NNUnet
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, RichProgressBar
 from pytorch_lightning.plugins.io import AsyncCheckpointIO
 from pytorch_lightning.strategies import DDPStrategy
-
-from data_loading.data_module import DataModule
-from nnunet.nn_unet import NNUnet
 from utils.args import get_main_args
 from utils.logger import LoggingCallback
 from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path

+ 4 - 5
PyTorch/Segmentation/nnUNet/nnunet/nn_unet.py

@@ -22,16 +22,15 @@ from apex.optimizers import FusedAdam, FusedSGD
 from data_loading.data_module import get_data_path, get_test_fnames
 from monai.inferers import sliding_window_inference
 from monai.networks.nets import DynUNet
+from nnunet.brats22_model import UNet3D
+from nnunet.loss import Loss, LossBraTS
+from nnunet.metrics import Dice
 from pytorch_lightning.utilities import rank_zero_only
 from scipy.special import expit, softmax
 from skimage.transform import resize
 from utils.logger import DLLogger
 from utils.utils import get_config_file, print0
 
-from nnunet.brats22_model import UNet3D
-from nnunet.loss import Loss, LossBraTS
-from nnunet.metrics import Dice
-
 
 class NNUnet(pl.LightningModule):
     def __init__(self, args, triton=False, data_dir=None):
@@ -279,7 +278,7 @@ class NNUnet(pl.LightningModule):
 
     @rank_zero_only
     def on_fit_end(self):
-        if not self.args.benchmark and self.args.skip_first_n_eval == 0:
+        if not self.args.benchmark:
             metrics = {}
             metrics["dice_score"] = round(self.best_mean.item(), 2)
             metrics["train_loss"] = round(sum(self.train_loss) / len(self.train_loss), 4)

+ 1 - 1
PyTorch/Segmentation/nnUNet/utils/utils.py

@@ -58,7 +58,7 @@ def verify_ckpt_path(args):
             return resume_path_results
         print("[Warning] Checkpoint not found. Starting training from scratch.")
         return None
-    if not os.path.isfile(args.ckpt_path):
+    if args.ckpt_path is None or not os.path.isfile(args.ckpt_path):
         print(f"Provided checkpoint {args.ckpt_path} is not a file. Starting training from scratch.")
         return None
     return args.ckpt_path