|
|
@@ -22,16 +22,15 @@ from apex.optimizers import FusedAdam, FusedSGD
|
|
|
from data_loading.data_module import get_data_path, get_test_fnames
|
|
|
from monai.inferers import sliding_window_inference
|
|
|
from monai.networks.nets import DynUNet
|
|
|
+from nnunet.brats22_model import UNet3D
|
|
|
+from nnunet.loss import Loss, LossBraTS
|
|
|
+from nnunet.metrics import Dice
|
|
|
from pytorch_lightning.utilities import rank_zero_only
|
|
|
from scipy.special import expit, softmax
|
|
|
from skimage.transform import resize
|
|
|
from utils.logger import DLLogger
|
|
|
from utils.utils import get_config_file, print0
|
|
|
|
|
|
-from nnunet.brats22_model import UNet3D
|
|
|
-from nnunet.loss import Loss, LossBraTS
|
|
|
-from nnunet.metrics import Dice
|
|
|
-
|
|
|
|
|
|
class NNUnet(pl.LightningModule):
|
|
|
def __init__(self, args, triton=False, data_dir=None):
|
|
|
@@ -279,7 +278,7 @@ class NNUnet(pl.LightningModule):
|
|
|
|
|
|
@rank_zero_only
|
|
|
def on_fit_end(self):
|
|
|
- if not self.args.benchmark and self.args.skip_first_n_eval == 0:
|
|
|
+ if not self.args.benchmark:
|
|
|
metrics = {}
|
|
|
metrics["dice_score"] = round(self.best_mean.item(), 2)
|
|
|
metrics["train_loss"] = round(sum(self.train_loss) / len(self.train_loss), 4)
|