
[TXL/PyT] Added barriers when reporting time; switched to correct averaging when reporting average throughput

Szymon Migacz 3 years ago
parent
commit
8f82237113

+ 22 - 10
PyTorch/LanguageModeling/Transformer-XL/pytorch/eval.py

@@ -168,7 +168,9 @@ def format_log(loss, split, args):
     return log_str
 
 
-def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
+def evaluate(
+    eval_iter, model, device, meters, log_interval, max_size=None, repeat=1
+):
     total_len, total_loss = 0, 0.
     eval_step = 0
 
@@ -176,8 +178,9 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
     log_latency = 0
     log_loss = 0
 
-    torch.cuda.synchronize()
+    utils.distributed.barrier()
     start_time = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(repeat):
@@ -186,10 +189,12 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                     break
                 eval_step += 1
 
-                torch.cuda.synchronize()
+                utils.distributed.barrier()
                 start_iter = time.time()
+
                 loss, mems = model(data, target, mems)
-                torch.cuda.synchronize()
+
+                utils.distributed.barrier()
                 elapsed = time.time() - start_iter
 
                 loss = loss.float().mean()
@@ -204,7 +209,7 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                 target_tokens = target.numel()
                 throughput = target_tokens / elapsed
                 throughput = utils.distributed.all_reduce_item(throughput, op='sum')
-                meters['eval_throughput'].update(throughput)
+                meters['eval_throughput'].update(throughput, elapsed)
                 log_throughput += throughput
 
                 if eval_step % log_interval == 0:
@@ -238,8 +243,8 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                     log_loss = 0
 
     utils.distributed.barrier()
-    torch.cuda.synchronize()
     total_time = time.time() - start_time
+
     logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
             total_time, 1000 * total_time / (idx+1)))
 
@@ -251,13 +256,18 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
 def compile_model(model, device, args):
     inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
     tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
+
+    utils.distributed.barrier()
     start = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(2):
             _, mems = model(inp, tgt, mems)
-    torch.cuda.synchronize()
+
+    utils.distributed.barrier()
     stop = time.time()
+
     logging.info(f'Building the model took {stop - start:.2f} seconds')
 
 
@@ -450,7 +460,7 @@ def main():
     meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
     meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
 
-    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
+    loss = evaluate(iter, model, device, meters, args.log_interval, args.max_size, args.repeat)
     perplexity = math.exp(loss)
     log_str = format_log(loss, args.split, args)
 
@@ -476,7 +486,9 @@ def main():
             }
         with open(data_path, 'wb') as f:
             pickle.dump(data, f)
-        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
+
+        avg_throughput = meters['eval_throughput'].avg
+        logging.info(f'Throughput Avg: {avg_throughput:.2f} tok/s')
         logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
         for p in args.percentiles:
             logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')
@@ -484,7 +496,7 @@ def main():
         logging.info('=' * 100)
 
         summary.update({
-            'eval_throughput': throughput_data.mean(),
+            'eval_throughput': avg_throughput,
             'eval_avg_latency': 1000 * latency_data.mean(),
             })
         for p in args.percentiles:

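Why the meter change above fixes the reported average: update(throughput, elapsed) weights each sample by how long it took, so the running average becomes total tokens divided by total time rather than a plain mean of per-step rates, which over-weights short segments. A minimal sketch of a time-weighted meter, assuming an interface like the repo's AverageMeter (the real class lives elsewhere in utils; this one is illustrative):

class WeightedAverageMeter:
    """Running average where each update carries a weight (here: seconds)."""

    def __init__(self):
        self.sum = 0.0    # sum of value * weight; for throughput this is tokens
        self.count = 0.0  # sum of weights; for throughput this is seconds

    def update(self, value, weight=1):
        self.sum += value * weight
        self.count += weight

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0

# Two eval steps: 1000 tokens in 1 s, then 1000 tokens in 4 s.
meter = WeightedAverageMeter()
meter.update(1000 / 1.0, 1.0)  # 1000 tok/s sustained for 1 s
meter.update(1000 / 4.0, 4.0)  # 250 tok/s sustained for 4 s
assert meter.avg == 400.0      # 2000 tokens / 5 s; an unweighted mean gives 625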
+ 22 - 5
PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

@@ -513,6 +513,7 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
     cur_loss = float('inf')
     target_tokens = 0
     log_step = 0
+    utils.distributed.barrier()
     log_start_time = time.time()
 
     if args.varlen:
@@ -586,16 +587,18 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
             cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
             train_loss = 0
 
-            elapsed = time.time() - log_start_time
+            utils.distributed.barrier()
+            current_time = time.time()
+            elapsed = current_time - log_start_time
             avg_elapsed = elapsed / log_step
             avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
-            log_start_time = time.time()
+            log_start_time = current_time
             log_step = 0
 
             lr = optimizer.param_groups[0]['lr']
             throughput = target_tokens / elapsed
             throughput = utils.distributed.all_reduce_item(throughput, op='sum')
-            meters['train_throughput'].update(throughput)
+            meters['train_throughput'].update(throughput, elapsed)
             target_tokens = 0
 
             log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
@@ -634,21 +637,26 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
         interrupted = timeout_handler.interrupted
 
         if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
+            utils.distributed.barrier()
             eval_start_time = time.time()
+
             val_loss = evaluate(va_iter, model, args)
             val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
 
+            utils.distributed.barrier()
+            eval_elapsed = time.time() - eval_start_time
+
             logging.info('-' * 100)
             log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                       '| valid loss {:5.2f}'.format(
                           train_step // args.eval_interval,
                           train_step,
-                          (time.time() - eval_start_time),
+                          eval_elapsed,
                           val_loss,
                           )
 
             dllogger_data = {
-                'valid_elapsed': (time.time() - eval_start_time),
+                'valid_elapsed': eval_elapsed,
                 'valid_loss': val_loss,
                 }
 
@@ -683,6 +691,7 @@ def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
                     scheduler_sparse.step(val_loss)
 
             # subtract eval time from timers for training
+            utils.distributed.barrier()
             log_start_time += time.time() - eval_start_time
 
         if interrupted:
@@ -1022,7 +1031,10 @@ def main():
     ###########################################################################
     # Loop over epochs.
     # At any point you can hit Ctrl + C to break out of training early.
+
+    utils.distributed.barrier()
     start_time = time.time()
+
     with TimeoutHandler() as timeout_handler:
         try:
             for epoch in itertools.count(start=start_epoch):
@@ -1046,6 +1058,7 @@ def main():
         except KeyboardInterrupt:
             logging.info('-' * 100)
             logging.info('Exiting from training early')
+    utils.distributed.barrier()
     elapsed = time.time() - start_time
 
     ###########################################################################
@@ -1064,9 +1077,13 @@ def main():
         model.load_state_dict(checkpoint['model_state'])
 
         # Run on test data.
+        utils.distributed.barrier()
         test_start_time = time.time()
+
         test_loss = evaluate(te_iter, model, args)
         test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
+
+        utils.distributed.barrier()
         test_elapsed = time.time() - test_start_time
 
         logging.info('=' * 100)

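The barrier placement matters because CUDA kernels run asynchronously: without a synchronization point, time.time() taken right after a forward pass measures little more than kernel-launch overhead, and in multi-process runs each rank would read the clock at a different moment. A minimal sketch of the timing pattern the commit applies, assuming the repo's utils package is importable and model/data/target/mems stand in for the real objects:

import time

import torch

from utils import distributed  # the helper patched in this commit


def timed_forward(model, data, target, mems):
    # Every rank reaches this point and all queued GPU work drains
    # before the clock is read, so start is a meaningful timestamp.
    distributed.barrier()
    start = time.time()

    with torch.no_grad():
        loss, mems = model(data, target, mems)

    # Wait for the asynchronously launched kernels (and the slowest
    # rank) to finish before taking the end timestamp.
    distributed.barrier()
    elapsed = time.time() - start
    return loss, mems, elapsed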
+ 4 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py

@@ -37,10 +37,13 @@ def init_distributed(cuda):
 
 def barrier():
     """
-    Call torch.distributed.barrier() if distritubed is in use
+    Call torch.distributed.barrier() if distributed is in use; otherwise call
+    torch.cuda.synchronize() if CUDA is initialized.
     """
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         torch.distributed.barrier()
+    elif torch.cuda.is_available() and torch.cuda.is_initialized():
+        torch.cuda.synchronize()
 
 
 def get_rank():
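
With this fallback the same call site is valid in every execution mode, which is what lets eval.py and train.py above call barrier() unconditionally around each timed region. A quick usage sketch; run_inference is a hypothetical stand-in for the measured workload:

import time

from utils import distributed

# Correct in every mode:
#  - multi-process: torch.distributed.barrier() syncs all ranks
#  - single GPU:    torch.cuda.synchronize() drains the CUDA stream
#  - CPU-only:      barrier() is a no-op and wall-clock timing still works
distributed.barrier()
start = time.time()
run_inference()  # hypothetical workload being timed
distributed.barrier()
print(f'elapsed: {time.time() - start:.2f} s')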