Browse Source

[TFT/PyTorch] Added torch.cuda.synchronize() calls in perf meter

Jan Baczek 3 years ago
parent
commit
d190b25f31

+ 2 - 1
PyTorch/Forecasting/TFT/inference.py

@@ -76,7 +76,7 @@ def predict(args, config, model, data_loader, scalers, cat_encodings, extend_tar
     predictions = []
     targets = []
     ids = []
-    perf_meter = PerformanceMeter()
+    perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
     n_workers = args.distributed_world_size if hasattr(args, 'distributed_world_size') else 1
 
     for step, batch in enumerate(data_loader):
@@ -235,5 +235,6 @@ if __name__=='__main__':
     parser.add_argument('--save_predictions', action='store_true')
     parser.add_argument('--results', type=str, default='/results')
     parser.add_argument('--log_file', type=str, default='dllogger.json')
+    parser.add_argument("--disable_benchmark", action='store_true', help='Disable benchmarking mode')
     ARGS = parser.parse_args()
     main(ARGS)

+ 4 - 1
PyTorch/Forecasting/TFT/train.py

@@ -129,7 +129,7 @@ def main(args):
     train_loader, valid_loader, test_loader = load_dataset(args, config)
 
     global_step = 0
-    perf_meter = PerformanceMeter()
+    perf_meter = PerformanceMeter(benchmark_mode=not args.disable_benchmark)
 
     for epoch in range(args.epochs):
         start = time.time()
@@ -209,6 +209,7 @@ def validate(args, config, model, criterion, dataloader, global_step):
     model.eval()
 
     losses = []
+    torch.cuda.synchronize()
     validation_start = time.time()
     for batch in dataloader:
         with torch.no_grad():
@@ -219,6 +220,7 @@ def validate(args, config, model, criterion, dataloader, global_step):
             bs = next(t for t in batch.values() if t is not None).shape[0]
             losses.append((p_losses, bs))
 
+    torch.cuda.synchronize()
     validation_end = time.time()
 
     p_losses = sum([l[0]*l[1] for l in losses])/sum([l[1] for l in losses]) #takes into account that the last batch is not full
@@ -280,6 +282,7 @@ if __name__ == '__main__':
                                   'disabled'],
                          help='type of CPU affinity')
     parser.add_argument("--ema_decay", type=float, default=0.0, help='Use exponential moving average')
+    parser.add_argument("--disable_benchmark", action='store_true', help='Disable benchmarking mode')
 
 
     ARGS = parser.parse_args()

+ 14 - 1
PyTorch/Forecasting/TFT/utils.py

@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import time
+import torch.distributed as dist
+import torch
 
 class PerformanceMeter():
-    def __init__(self):
+    def __init__(self, benchmark_mode=True):
+        self.benchmark_mode = benchmark_mode
         self.reset()
 
     def reset(self):
+        if self.benchmark_mode:
+            torch.cuda.synchronize()
         self.avg = 0
         self.count = 0
         self.total_time = 0
@@ -26,6 +31,8 @@ class PerformanceMeter():
         self.intervals = []
 
     def update(self, n, exclude_from_total=False):
+        if self.benchmark_mode:
+            torch.cuda.synchronize()
         delta = time.time() - self.last_update_time
         self.intervals.append(delta)
         if not exclude_from_total:
@@ -37,6 +44,8 @@ class PerformanceMeter():
         return n/delta
 
     def reset_current_lap(self):
+        if self.benchmark_mode:
+            torch.cuda.synchronize()
         self.last_update_time = time.time()
 
     def p(self, i):
@@ -44,3 +53,7 @@ class PerformanceMeter():
         idx = int(len(self.intervals) * i / 100)
         return sorted(self.intervals)[idx]
 
+def print_once(*args, **kwargs):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs)
+