Преглед изворни кода

[nnUNet/PyT] Add torch.cuda.synchronize call while benchmarking

Michal Futrega пре 3 година
родитељ
комит
9970904a94
1 измењених фајлова са 2 додато и 0 уклоњено
  1. 2 0
      PyTorch/Segmentation/nnUNet/utils/logger.py

+ 2 - 0
PyTorch/Segmentation/nnUNet/utils/logger.py

@@ -17,6 +17,7 @@ import time
 
 
 import dllogger as logger
 import numpy as np
+import torch
 from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
 from pytorch_lightning import Callback
 from pytorch_lightning.utilities import rank_zero_only
@@ -70,6 +71,7 @@ class LoggingCallback(Callback):
         if self.step > self.warmup_steps:
             self.step += 1
             return
+        torch.cuda.synchronize()
         self.timestamps.append(time.perf_counter())
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):