|
|
@@ -17,6 +17,7 @@ import time
|
|
|
|
|
|
import dllogger as logger
|
|
|
import numpy as np
|
|
|
+import torch
|
|
|
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
|
|
|
from pytorch_lightning import Callback
|
|
|
from pytorch_lightning.utilities import rank_zero_only
|
|
|
@@ -70,6 +71,7 @@ class LoggingCallback(Callback):
|
|
|
if self.step > self.warmup_steps:
|
|
|
self.step += 1
|
|
|
return
|
|
|
+ torch.cuda.synchronize()
|
|
|
self.timestamps.append(time.perf_counter())
|
|
|
|
|
|
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|