Sfoglia il codice sorgente

[DLRM/PyT] Add synchronization for more reliable benchmarking

Tomasz Grel 3 anni fa
parent
commit
f5111db0d5

+ 3 - 5
PyTorch/Recommendation/DLRM/dlrm/scripts/main.py

@@ -627,10 +627,8 @@ def main(argv):
         batch_iter = prefetcher(iter(data_loader_train), data_stream)
 
         for step in range(len(data_loader_train)):
-            timer.click()
-
             numerical_features, categorical_features, click = next(batch_iter)
-            torch.cuda.synchronize()
+            timer.click(synchronize=(device == 'cuda'))
 
             global_step = steps_per_epoch * epoch + step
 
@@ -773,7 +771,7 @@ def dist_evaluate(model, data_loader):
         batch_iter = prefetcher(iter(data_loader), data_stream)
         loss_fn = torch.nn.BCELoss(reduction="mean")
 
-        timer.click()
+        timer.click(synchronize=(device=='cuda'))
         for step in range(len(data_loader)):
             numerical_features, categorical_features, click = next(batch_iter)
             torch.cuda.synchronize()
@@ -815,7 +813,7 @@ def dist_evaluate(model, data_loader):
             y_true.append(click)
             y_score.append(output)
 
-            timer.click()
+            timer.click(synchronize=(device == 'cuda'))
 
             if timer.measured is not None:
                 metric_logger.update(step_time=timer.measured)

+ 4 - 1
PyTorch/Recommendation/DLRM/dlrm/scripts/utils.py

@@ -210,8 +210,11 @@ class StepTimer():
         self._new = None
         self.measured = None
 
-    def click(self):
+    def click(self, synchronize=False):
         self._previous = self._new
+
+        if synchronize:
+            torch.cuda.synchronize()
         self._new = time.time()
 
         if self._previous is not None: