Explorar el Código

[nnUNet/PyT] Add torch.cuda.synchronize call while benchmarking

Michal Futrega hace 3 años
padre
commit
9970904a94
Se han modificado 1 ficheros con 2 adiciones y 0 borrados
  1. 2 0
      PyTorch/Segmentation/nnUNet/utils/logger.py

+ 2 - 0
PyTorch/Segmentation/nnUNet/utils/logger.py

@@ -17,6 +17,7 @@ import time
 
 import dllogger as logger
 import numpy as np
+import torch
 from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
 from pytorch_lightning import Callback
 from pytorch_lightning.utilities import rank_zero_only
@@ -70,6 +71,7 @@ class LoggingCallback(Callback):
         if self.step > self.warmup_steps:
             self.step += 1
             return
+        torch.cuda.synchronize()
         self.timestamps.append(time.perf_counter())
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):