Browse source

[BART/PyT] Add synchronize for benchmarking

Bobby Chen, 3 years ago
Parent
commit
a7972142c3

+ 2 - 0
PyTorch/LanguageModeling/BART/run_eval.py

@@ -160,6 +160,7 @@ def generate_summaries_or_translations(
     results = []
     with torch.no_grad():
         for batch in tqdm(data_loader):
+            torch.cuda.synchronize()
             t0 = time.time()
 
             summaries = model.generate(
@@ -180,6 +181,7 @@ def generate_summaries_or_translations(
             if num_return_sequences > 1:
                 preds = chunks(preds, num_return_sequences)  # batch size chunks, each of size num_return_seq
 
+            torch.cuda.synchronize()
             eval_time = time.time() - t0
             for i, pred in enumerate(preds):
                 store_time = eval_time if i == 0 else None #only store latency for element 0 of every batch

+ 2 - 0
PyTorch/LanguageModeling/BART/training_base.py

@@ -410,9 +410,11 @@ def generic_train(
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
             local_step += 1
+            torch.cuda.synchronize()
             iter_start = time.time()
 
             total_loss, logs = train_one_step(args, trainer, optimizer, scheduler, batch, local_step, scaler)
+            torch.cuda.synchronize()
             train_perf = logs["bs"] * get_world_size() / (time.time() - iter_start)