Browse Source

correct train.py script and add backward compatibility for dropblock

Michal Futrega 5 years ago
parent
commit
b741b3831b

+ 6 - 6
PyTorch/Segmentation/nnUNet/README.md

@@ -507,7 +507,7 @@ The following sections provide details on how to achieve the same performance an
 
 ##### Training accuracy: NVIDIA DGX A100 (8x A100 80G)
 
-Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs.
+Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX with (8x A100 80G) GPUs.
 
 | Dimension | GPUs | Batch size / GPU  | Accuracy - mixed precision | Accuracy - FP32 | Time to train - mixed precision | Time to train - TF32|  Time to train speedup (TF32 to mixed precision)        
 |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
@@ -519,7 +519,7 @@ Our results were obtained by running the `python scripts/train.py --gpus {1,8} -
 
 ##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
 
-Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} --batch_size <bsize> [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
+Our results were obtained by running the `python scripts/train.py --gpus {1,8} --fold {0,1,2,3,4} --dim {2,3} [--amp]` training scripts and averaging results in the PyTorch 21.02 NGC container on NVIDIA DGX-1 with (8x V100 16G) GPUs.
 
 | Dimension | GPUs | Batch size / GPU | Accuracy - mixed precision |  Accuracy - FP32 |  Time to train - mixed precision | Time to train - FP32  | Time to train speedup (FP32 to mixed precision)        
 |:-:|:-:|:--:|:-----:|:-----:|:-----:|:-----:|:----:|
@@ -580,7 +580,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre
 
 FP16
 
-| Dimension | Batch size |   Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |  Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 3198.8 | 20.01 | 24.1 | 30.5 | 33.75 |
 | 2 | 128 | 4x192x160 | 3587.89 | 35.68 | 36.0 | 36.08 | 36.16 |
@@ -591,7 +591,7 @@ FP16
 
 TF32
 
-| Dimension | Batch size |   Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |  Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 2353.27 | 27.2 | 27.43 | 27.53 | 27.7 |
 | 2 | 128 | 4x192x160 | 2492.78 | 51.35 | 51.54 | 51.59 | 51.73 |
@@ -610,7 +610,7 @@ Our results were obtained by running the `python scripts/benchmark.py --mode pre
 
 FP16
  
-| Dimension | Batch size |   Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |  Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1866.52 | 34.29 | 34.7 | 48.87 | 52.44 |
 | 2 | 128 | 4x192x160 | 2032.74 | 62.97 | 63.21 | 63.25 | 63.32 |
@@ -620,7 +620,7 @@ FP16
 
 FP32
  
-| Dimension | Batch size |   Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
+| Dimension | Batch size |  Resolution  | Throughput Avg [img/s] | Latency Avg [ms] | Latency 90% [ms] | Latency 95% [ms] | Latency 99% [ms] |
 |:----------:|:---------:|:-------------:|:----------------------:|:----------------:|:----------------:|:----------------:|:----------------:|
 | 2 | 64 | 4x192x160 | 1051.46 | 60.87 | 61.21 | 61.48 | 62.87 |
 | 2 | 128 | 4x192x160 | 1051.68 | 121.71 | 122.29 | 122.44 | 122.6 |

+ 3 - 29
PyTorch/Segmentation/nnUNet/models/loss.py

@@ -12,42 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from monai.losses import FocalLoss
-
-
-class DiceLoss(nn.Module):
-    def __init__(self, include_background=False, smooth=1e-5, eps=1e-7):
-        super(DiceLoss, self).__init__()
-        self.include_background = include_background
-        self.smooth = smooth
-        self.dims = (0, 2)
-        self.eps = eps
-
-    def forward(self, y_pred, y_true):
-        num_classes, batch_size = y_pred.size(1), y_true.size(0)
-        y_pred = y_pred.log_softmax(dim=1).exp()
-        y_true, y_pred = y_true.view(batch_size, -1), y_pred.view(batch_size, num_classes, -1)
-        y_true = F.one_hot(y_true.to(torch.int64), num_classes).permute(0, 2, 1)
-        if not self.include_background:
-            y_true, y_pred = y_true[:, 1:], y_pred[:, 1:]
-        intersection = torch.sum(y_true * y_pred, dim=self.dims)
-        cardinality = torch.sum(y_true + y_pred, dim=self.dims)
-        dice_loss = 1 - (2.0 * intersection + self.smooth) / (cardinality + self.smooth).clamp_min(self.eps)
-        mask = (y_true.sum(self.dims) > 0).to(dice_loss.dtype)
-        dice_loss *= mask.to(dice_loss.dtype)
-        dice_loss = dice_loss.sum() / mask.sum()
-        return dice_loss
+from monai.losses import DiceLoss, FocalLoss
 
 
 class Loss(nn.Module):
     def __init__(self, focal):
         super(Loss, self).__init__()
-        self.dice = DiceLoss()
-        self.cross_entropy = nn.CrossEntropyLoss()
+        self.dice = DiceLoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)
         self.focal = FocalLoss(gamma=2.0)
+        self.cross_entropy = nn.CrossEntropyLoss()
         self.use_focal = focal
 
     def forward(self, y_pred, y_true):

+ 2 - 0
PyTorch/Segmentation/nnUNet/models/nn_unet.py

@@ -42,6 +42,8 @@ class NNUnet(pl.LightningModule):
     def __init__(self, args):
         super(NNUnet, self).__init__()
         self.args = args
+        if not hasattr(self.args, "drop_block"):  # For backward compatibility
+            self.args.drop_block = False
         self.save_hyperparameters()
         self.build_nnunet()
         self.loss = Loss(self.args.focal)

+ 5 - 1
PyTorch/Segmentation/nnUNet/scripts/train.py

@@ -24,11 +24,15 @@ parser.add_argument("--fold", type=int, required=True, choices=[0, 1, 2, 3, 4],
 parser.add_argument("--dim", type=int, required=True, choices=[2, 3], help="Dimension of UNet")
 parser.add_argument("--amp", action="store_true", help="Enable automatic mixed precision")
 parser.add_argument("--tta", action="store_true", help="Enable test time augmentation")
+parser.add_argument("--results", type=str, default="/results", help="Path to results directory")
+parser.add_argument("--logname", type=str, default="log", help="Name of dlloger output")
 
 if __name__ == "__main__":
     args = parser.parse_args()
     path_to_main = os.path.join(dirname(dirname(os.path.realpath(__file__))), "main.py")
-    cmd = f"python {path_to_main} --exec_mode train --task {args.data} --deep_supervision --save_ckpt "
+    cmd = f"python {path_to_main} --exec_mode train --task {args.task} --deep_supervision --save_ckpt "
+    cmd += f"--results {args.results} "
+    cmd += f"--logname {args.logname} "
     cmd += f"--dim {args.dim} "
     cmd += f"--batch_size {2 if args.dim == 3 else 64} "
     cmd += f"--val_batch_size {4 if args.dim == 3 else 64} "