# nn_unet.py
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. import pytorch_lightning as pl
  17. import torch
  18. import torch.nn as nn
  19. import torch_optimizer as optim
  20. from dllogger import JSONStreamBackend, Logger, StdOutBackend, Verbosity
  21. from monai.inferers import sliding_window_inference
  22. from utils.utils import flip, get_config_file, is_main_process
  23. from models.metrics import Dice, Loss
  24. from models.unet import UNet
  25. class NNUnet(pl.LightningModule):
  26. def __init__(self, args):
  27. super(NNUnet, self).__init__()
  28. self.args = args
  29. self.save_hyperparameters()
  30. self.build_nnunet()
  31. self.loss = Loss()
  32. self.dice = Dice(self.n_class)
  33. self.best_sum = 0
  34. self.eval_dice = 0
  35. self.best_sum_epoch = 0
  36. self.best_dice = self.n_class * [0]
  37. self.best_epoch = self.n_class * [0]
  38. self.best_sum_dice = self.n_class * [0]
  39. self.learning_rate = args.learning_rate
  40. if self.args.exec_mode in ["train", "evaluate"]:
  41. self.dllogger = Logger(
  42. backends=[
  43. JSONStreamBackend(Verbosity.VERBOSE, os.path.join(args.results, "logs.json")),
  44. StdOutBackend(Verbosity.VERBOSE, step_format=lambda step: f"Epoch: {step} "),
  45. ]
  46. )
  47. self.tta_flips = (
  48. [[2], [3], [2, 3]] if self.args.dim == 2 else [[2], [3], [4], [2, 3], [2, 4], [3, 4], [2, 3, 4]]
  49. )
  50. def forward(self, img):
  51. if self.args.benchmark:
  52. return self.model(img)
  53. return self.tta_inference(img) if self.args.tta else self.do_inference(img)
  54. def training_step(self, batch, batch_idx):
  55. img, lbl = batch["image"], batch["label"]
  56. if self.args.dim == 2 and len(lbl.shape) == 3:
  57. lbl = lbl.unsqueeze(1)
  58. pred = self.model(img)
  59. loss = self.compute_loss(pred, lbl)
  60. return loss
  61. def validation_step(self, batch, batch_idx):
  62. img, lbl = batch["image"], batch["label"]
  63. if self.args.dim == 2 and len(lbl.shape) == 3:
  64. lbl = lbl.unsqueeze(1)
  65. pred = self.forward(img)
  66. loss = self.loss(pred, lbl)
  67. dice = self.dice(pred, lbl[:, 0])
  68. return {"val_loss": loss, "val_dice": dice}
  69. def test_step(self, batch, batch_idx):
  70. if self.args.exec_mode == "evaluate":
  71. return self.validation_step(batch, batch_idx)
  72. img = batch["image"]
  73. pred = self.forward(img)
  74. if self.args.save_preds:
  75. self.save_mask(pred, batch["fname"])
  76. def build_unet(self, in_channels, n_class, kernels, strides):
  77. return UNet(
  78. in_channels=in_channels,
  79. n_class=n_class,
  80. kernels=kernels,
  81. strides=strides,
  82. normalization_layer=self.args.norm,
  83. negative_slope=self.args.negative_slope,
  84. deep_supervision=self.args.deep_supervision,
  85. dimension=self.args.dim,
  86. )
  87. def get_unet_params(self):
  88. config = get_config_file(self.args)
  89. in_channels = config["in_channels"]
  90. patch_size = config["patch_size"]
  91. spacings = config["spacings"]
  92. n_class = config["n_class"]
  93. strides, kernels, sizes = [], [], patch_size[:]
  94. while True:
  95. spacing_ratio = [spacing / min(spacings) for spacing in spacings]
  96. stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
  97. kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
  98. if all(s == 1 for s in stride):
  99. break
  100. sizes = [i / j for i, j in zip(sizes, stride)]
  101. spacings = [i * j for i, j in zip(spacings, stride)]
  102. kernels.append(kernel)
  103. strides.append(stride)
  104. if len(strides) == 5:
  105. break
  106. strides.insert(0, len(spacings) * [1])
  107. kernels.append(len(spacings) * [3])
  108. return in_channels, n_class, kernels, strides, patch_size
  109. def build_nnunet(self):
  110. in_channels, n_class, kernels, strides, self.patch_size = self.get_unet_params()
  111. self.model = self.build_unet(in_channels, n_class, kernels, strides)
  112. self.n_class = n_class - 1
  113. if is_main_process():
  114. print(f"Filters: {self.model.filters}")
  115. print(f"Kernels: {kernels}")
  116. print(f"Strides: {strides}")
  117. def compute_loss(self, preds, label):
  118. if self.args.deep_supervision:
  119. loss = self.loss(preds[0], label)
  120. for i, pred in enumerate(preds[1:]):
  121. downsampled_label = nn.functional.interpolate(label, pred.shape[2:])
  122. loss += 0.5 ** (i + 1) * self.loss(pred, downsampled_label)
  123. c_norm = 1 / (2 - 2 ** (-len(preds)))
  124. return c_norm * loss
  125. return self.loss(preds, label)
  126. def do_inference(self, image):
  127. if self.args.dim == 2:
  128. if self.args.data2d_dim == 2:
  129. return self.model(image)
  130. if self.args.exec_mode == "predict" and not self.args.benchmark:
  131. return self.inference2d_test(image)
  132. return self.inference2d(image)
  133. return self.sliding_window_inference(image)
  134. def tta_inference(self, img):
  135. pred = self.do_inference(img)
  136. for flip_idx in self.tta_flips:
  137. pred += flip(self.do_inference(flip(img, flip_idx)), flip_idx)
  138. pred /= len(self.tta_flips) + 1
  139. return pred
  140. def inference2d(self, image):
  141. batch_modulo = image.shape[2] % self.args.val_batch_size
  142. if self.args.benchmark:
  143. image = image[:, :, batch_modulo:]
  144. elif batch_modulo != 0:
  145. batch_pad = self.args.val_batch_size - batch_modulo
  146. image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
  147. image = torch.transpose(image.squeeze(0), 0, 1)
  148. preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
  149. preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
  150. for start in range(0, image.shape[0] - self.args.val_batch_size + 1, self.args.val_batch_size):
  151. end = start + self.args.val_batch_size
  152. pred = self.model(image[start:end])
  153. preds[start:end] = pred.data
  154. if batch_modulo != 0 and not self.args.benchmark:
  155. preds = preds[batch_pad:]
  156. return torch.transpose(preds, 0, 1).unsqueeze(0)
  157. def inference2d_test(self, image):
  158. preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
  159. preds = torch.zeros(preds_shape, dtype=image.dtype, device=image.device)
  160. for depth in range(image.shape[2]):
  161. preds[:, :, depth] = self.sliding_window_inference(image[:, :, depth])
  162. return preds
  163. def sliding_window_inference(self, image):
  164. return sliding_window_inference(
  165. inputs=image,
  166. roi_size=self.patch_size,
  167. sw_batch_size=self.args.val_batch_size,
  168. predictor=self.model,
  169. overlap=self.args.overlap,
  170. mode=self.args.val_mode,
  171. )
  172. @staticmethod
  173. def metric_mean(name, outputs):
  174. return torch.stack([out[name] for out in outputs]).mean(dim=0)
  175. def validation_epoch_end(self, outputs):
  176. loss = self.metric_mean("val_loss", outputs)
  177. dice = 100 * self.metric_mean("val_dice", outputs)
  178. dice_sum = torch.sum(dice)
  179. if dice_sum >= self.best_sum:
  180. self.best_sum = dice_sum
  181. self.best_sum_dice = dice[:]
  182. self.best_sum_epoch = self.current_epoch
  183. for i, dice_i in enumerate(dice):
  184. if dice_i > self.best_dice[i]:
  185. self.best_dice[i], self.best_epoch[i] = dice_i, self.current_epoch
  186. if is_main_process():
  187. metrics = {}
  188. metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
  189. metrics.update({"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
  190. metrics.update({f"L{i+1}": round(m.item(), 2) for i, m in enumerate(dice)})
  191. metrics.update({f"TOP_L{i+1}": round(m.item(), 2) for i, m in enumerate(self.best_sum_dice)})
  192. metrics.update({"val_loss": round(loss.item(), 4)})
  193. self.dllogger.log(step=self.current_epoch, data=metrics)
  194. self.dllogger.flush()
  195. self.log("val_loss", loss)
  196. self.log("dice_sum", dice_sum)
  197. def test_epoch_end(self, outputs):
  198. if self.args.exec_mode == "evaluate":
  199. self.eval_dice = 100 * self.metric_mean("val_dice", outputs)
  200. def configure_optimizers(self):
  201. optimizer = {
  202. "sgd": torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=self.args.momentum),
  203. "adam": torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
  204. "adamw": torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
  205. "radam": optim.RAdam(self.parameters(), lr=self.learning_rate, weight_decay=self.args.weight_decay),
  206. }[self.args.optimizer.lower()]
  207. scheduler = {
  208. "none": None,
  209. "multistep": torch.optim.lr_scheduler.MultiStepLR(optimizer, self.args.steps, gamma=self.args.factor),
  210. "cosine": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.args.max_epochs),
  211. "plateau": torch.optim.lr_scheduler.ReduceLROnPlateau(
  212. optimizer, factor=self.args.factor, patience=self.args.lr_patience
  213. ),
  214. }[self.args.scheduler.lower()]
  215. opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
  216. if scheduler is not None:
  217. opt_dict.update({"lr_scheduler": scheduler})
  218. return opt_dict
  219. def save_mask(self, pred, fname):
  220. fname = str(fname[0].cpu().detach().numpy(), "utf-8").replace("_x", "_pred")
  221. pred = nn.functional.softmax(torch.tensor(pred), dim=1)
  222. pred = pred.squeeze().cpu().detach().numpy()
  223. np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)