|
|
@@ -627,10 +627,8 @@ def main(argv):
|
|
|
batch_iter = prefetcher(iter(data_loader_train), data_stream)
|
|
|
|
|
|
for step in range(len(data_loader_train)):
|
|
|
- timer.click()
|
|
|
-
|
|
|
numerical_features, categorical_features, click = next(batch_iter)
|
|
|
- torch.cuda.synchronize()
|
|
|
+ timer.click(synchronize=(device == 'cuda'))
|
|
|
|
|
|
global_step = steps_per_epoch * epoch + step
|
|
|
|
|
|
@@ -773,7 +771,7 @@ def dist_evaluate(model, data_loader):
|
|
|
batch_iter = prefetcher(iter(data_loader), data_stream)
|
|
|
loss_fn = torch.nn.BCELoss(reduction="mean")
|
|
|
|
|
|
- timer.click()
|
|
|
+ timer.click(synchronize=(device=='cuda'))
|
|
|
for step in range(len(data_loader)):
|
|
|
numerical_features, categorical_features, click = next(batch_iter)
|
|
|
torch.cuda.synchronize()
|
|
|
@@ -815,7 +813,7 @@ def dist_evaluate(model, data_loader):
|
|
|
y_true.append(click)
|
|
|
y_score.append(output)
|
|
|
|
|
|
- timer.click()
|
|
|
+ timer.click(synchronize=(device == 'cuda'))
|
|
|
|
|
|
if timer.measured is not None:
|
|
|
metric_logger.update(step_time=timer.measured)
|