Przeglądaj źródła

Merge: [Transformer/PyTorch] Add torch.cuda.synchronize() calls

Krzysztof Kudrynski 3 lata temu
rodzic
commit
d46a356f06

+ 5 - 0
PyTorch/Translation/Transformer/fairseq/log_helper.py

@@ -8,6 +8,7 @@ from collections import OrderedDict
 import dllogger
 from dllogger import Backend, JSONStreamBackend
 from tensorboardX import SummaryWriter
+import torch
 
 
 class AverageMeter():
@@ -43,6 +44,7 @@ class PerformanceMeter():
 
     def reset(self):
         self.updated = False
+        torch.cuda.synchronize()
         self.start = time.time()
         self.n = 0
 
@@ -56,6 +58,7 @@ class PerformanceMeter():
 
     @property
     def elapsed_time(self):
+        torch.cuda.synchronize()
         return time.time() - self.start
 
 
@@ -70,6 +73,7 @@ class AggregatorBackend(Backend):
         self.metrics.flushed = True
         self.step = 0
         self.epoch = 0
+        torch.cuda.synchronize()
         self.start_time = time.time()
 
     @property
@@ -115,6 +119,7 @@ class AggregatorBackend(Backend):
                 result_string += _name + ' {:.3f} |'.format(agg.value)
                 agg.reset()
 
+        torch.cuda.synchronize()
         result_string += 'walltime {:.3f} |'.format(time.time() - self.start_time)
         self.metrics.flushed = True
         print(result_string)

+ 6 - 0
PyTorch/Translation/Transformer/fairseq/meters.py

@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
 import time
+import torch
 
 
 class AverageMeter(object):
@@ -33,12 +34,14 @@ class TimeMeter(object):
 
     def reset(self, init=0):
         self.init = init
+        torch.cuda.synchronize()
         self.start = time.time()
         self.n = 0
         self.last_update = time.time()
 
     def update(self, val=1):
         self.n += val
+        torch.cuda.synchronize()
         self.last_update = time.time()
 
     @property
@@ -47,6 +50,7 @@ class TimeMeter(object):
 
     @property
     def elapsed_time(self):
+        torch.cuda.synchronize()
         return self.init + (time.time() - self.start)
 
     @property
@@ -61,9 +65,11 @@ class StopwatchMeter(object):
         self.intervals = []
 
     def start(self):
+        torch.cuda.synchronize()
         self.start_time = time.time()
 
     def stop(self, n=1):
+        torch.cuda.synchronize()
         if self.start_time is not None:
             delta = time.time() - self.start_time
             self.intervals.append(delta)

+ 6 - 0
PyTorch/Translation/Transformer/inference.py

@@ -151,6 +151,7 @@ def main(args):
 
     use_cuda = torch.cuda.is_available() and not args.cpu
 
+    torch.cuda.synchronize()
     processing_start = time.time()
 
     # Load ensemble
@@ -229,7 +230,9 @@ def main(args):
             tokens = tokens.cuda()
             lengths = lengths.cuda()
 
+        torch.cuda.synchronize()
         translation_start = time.time()
+
         gen_timer.start()
         translations = translator.generate(
             tokens,
@@ -237,6 +240,8 @@ def main(args):
             maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
         )
         gen_timer.stop(sum(len(h[0]['tokens']) for h in translations))
+
+        torch.cuda.synchronize()
         dllogger.log(step='infer', data={'latency': time.time() - translation_start})
 
         return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
@@ -262,6 +267,7 @@ def main(args):
     if args.file:
         data_descriptor.close()
 
+    torch.cuda.synchronize()
     log_dict = {
                 'throughput': 1./gen_timer.avg,
                 'latency_avg': sum(gen_timer.intervals)/len(gen_timer.intervals),

+ 4 - 0
PyTorch/Translation/Transformer/train.py

@@ -164,6 +164,7 @@ def train(args, trainer, epoch_itr):
 
     max_update = args.max_update or math.inf
     num_batches = len(epoch_itr)
+    torch.cuda.synchronize()
     begin = time.time()
 
     # reset meters
@@ -189,6 +190,7 @@ def train(args, trainer, epoch_itr):
         if trainer.get_num_updates() >= max_update:
             break
 
+    torch.cuda.synchronize()
     print('Epoch time:', time.time() - begin)
 
     # Print epoch stats and reset training meters
@@ -235,6 +237,7 @@ def validate(args, trainer, datasets, subsets):
 
 def score(args, trainer, dataset, src_dict, tgt_dict, ref_file):
 
+    torch.cuda.synchronize()
     begin = time.time()
 
     src_dict = deepcopy(src_dict)  # This is necessary, generation of translations
@@ -324,6 +327,7 @@ def score(args, trainer, dataset, src_dict, tgt_dict, ref_file):
             float(args.distributed_world_size)/gen_timer.avg
             ))
 
+    torch.cuda.synchronize()
     print('| Eval completed in: {:.2f}s | {}CASED BLEU {:.2f}'.format(
         time.time()-begin,
         '' if args.test_cased_bleu else 'UN',