소스 검색

Merge: [nnUNet/PyT] Add torch.cuda.synchronize call while benchmarking

Krzysztof Kudrynski 3 년 전
부모
커밋
ff2f1201d9
1개의 변경된 파일2개의 추가작업 그리고 0개의 파일을 삭제
  1. 2 0
      PyTorch/Segmentation/nnUNet/utils/logger.py

+ 2 - 0
PyTorch/Segmentation/nnUNet/utils/logger.py

@@ -17,6 +17,7 @@ import time
 
 
 import dllogger as logger
 import dllogger as logger
 import numpy as np
 import numpy as np
+import torch
 from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
 from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
 from pytorch_lightning import Callback
 from pytorch_lightning import Callback
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities import rank_zero_only
@@ -70,6 +71,7 @@ class LoggingCallback(Callback):
         if self.step > self.warmup_steps:
         if self.step > self.warmup_steps:
             self.step += 1
             self.step += 1
             return
             return
+        torch.cuda.synchronize()
         self.timestamps.append(time.perf_counter())
         self.timestamps.append(time.perf_counter())
 
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):