|
|
@@ -222,6 +222,8 @@ class Seq2SeqTrainer:
|
|
|
|
|
|
batch_size = data_loader.batch_size
|
|
|
|
|
|
+ if self.device.type == 'cuda':
|
|
|
+ torch.cuda.synchronize()
|
|
|
end = time.time()
|
|
|
for i, (src, tgt) in enumerate(data_loader):
|
|
|
self.save_counter += 1
|
|
|
@@ -241,12 +243,14 @@ class Seq2SeqTrainer:
|
|
|
losses_per_sentence.update(loss_per_sentence, batch_size)
|
|
|
|
|
|
# measure elapsed time
|
|
|
+ if self.device.type == 'cuda':
|
|
|
+ torch.cuda.synchronize()
|
|
|
elapsed = time.time() - end
|
|
|
batch_time.update(elapsed)
|
|
|
- src_tok_time.update(num_toks['src'] / elapsed)
|
|
|
- tgt_tok_time.update(num_toks['tgt'] / elapsed)
|
|
|
+ src_tok_time.update(num_toks['src'] / elapsed, elapsed)
|
|
|
+ tgt_tok_time.update(num_toks['tgt'] / elapsed, elapsed)
|
|
|
tot_num_toks = num_toks['tgt'] + num_toks['src']
|
|
|
- tot_tok_time.update(tot_num_toks / elapsed)
|
|
|
+ tot_tok_time.update(tot_num_toks / elapsed, elapsed)
|
|
|
self.loss = losses_per_token.avg
|
|
|
|
|
|
if training and i in eval_iters:
|
|
|
@@ -298,6 +302,8 @@ class Seq2SeqTrainer:
|
|
|
if rank == 0:
|
|
|
self.save(identifier=identifier)
|
|
|
|
|
|
+ if self.device.type == 'cuda':
|
|
|
+ torch.cuda.synchronize()
|
|
|
end = time.time()
|
|
|
|
|
|
tot_tok_time.reduce('sum')
|