Ver código fonte

[TXL/PyT] Fixed issue with AMP training together with gradient accumulation (#720)

Szymon Migacz 5 anos atrás
pai
commit
533f7444ae

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/.gitignore

@@ -2,6 +2,6 @@
 __pycache__/
 data/
 results/
+pytorch/LM-TFM/*
 *.out
 *.log
-*.json

+ 3 - 3
PyTorch/LanguageModeling/Transformer-XL/pytorch/data_utils.py

@@ -46,7 +46,7 @@ class LMOrderedIterator(object):
         data = data[:n_step * bsz]
 
         # Evenly divide the data across the bsz batches.
-        self.data = data.view(bsz, -1).t().contiguous()
+        self.data = data.view(bsz, -1).t().contiguous().pin_memory()
 
         if mem_len and warmup:
             self.warmup_batches = (mem_len + bptt - 1) // bptt
@@ -83,8 +83,8 @@ class LMOrderedIterator(object):
         end_idx = i + seq_len
         beg_idx = max(0, i - self.ext_len)
 
-        data = self.data[beg_idx:end_idx].to(self.device)
-        target = self.data[i+1:i+1+seq_len].to(self.device)
+        data = self.data[beg_idx:end_idx].to(self.device, non_blocking=True)
+        target = self.data[i+1:i+1+seq_len].to(self.device, non_blocking=True)
 
         if self.mem_len and self.warmup:
             warm = i >= self.warmup_elems

+ 4 - 4
PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

@@ -436,7 +436,7 @@ def evaluate(eval_iter, model, args):
 
 
 def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
-                    optimizer, device, args):
+                    optimizer, device, delay_unscale, args):
     cpu = torch.device('cpu')
     data_i = data_chunks[i].contiguous()
     target_i = target_chunks[i].contiguous()
@@ -456,7 +456,7 @@ def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
         if args.amp == 'pytorch':
             scaler.scale(loss).backward()
         elif args.amp == 'apex':
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
+            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:
                 scaled_loss.backward()
     else:
         loss.backward()
@@ -498,12 +498,12 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
                 with para_model.no_sync():
                     train_loss_chunk = train_iteration(
                         para_model, i, mems, data_chunks, target_chunks, scaler,
-                        optimizer, device, args
+                        optimizer, device, True, args
                     )
             else:
                 train_loss_chunk = train_iteration(
                     para_model, i, mems, data_chunks, target_chunks, scaler,
-                    optimizer, device, args
+                    optimizer, device, False, args
                 )
 
             train_loss += train_loss_chunk

+ 6 - 0
PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_base.yaml

@@ -44,6 +44,12 @@ default:
    eval:
       <<: *eval
 
+manual_eval:
+   train:
+      <<: *train
+   eval:
+      <<: *eval
+      manual_config: '{"n_token": 267735, "n_layer": 16, "n_head": 8, "d_model": 512, "d_head": 64, "d_inner": 2048, "dropout": 0.1, "dropatt": 0.0, "dtype": null, "tie_weight": true, "d_embed": 512, "div_val": 1, "tie_projs": [false, true, true, true], "pre_lnorm": false, "tgt_len": 192, "ext_len": 0, "mem_len": 192, "cutoffs": [19997, 39997, 199997], "same_length": false, "attn_type": 0, "clamp_len": -1, "sample_softmax": -1}'
 
 # Full training configs for NVIDIA DGX-1 (8x NVIDIA V100 16GB GPU)
 dgx1_8gpu_fp16: &dgx1_8gpu_fp16

+ 6 - 0
PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_large.yaml

@@ -55,6 +55,12 @@ default:
    eval:
       <<: *eval
 
+manual_eval:
+   train:
+      <<: *train
+   eval:
+      <<: *eval
+      manual_config: '{"n_token": 267735, "n_layer": 18, "n_head": 16, "d_model": 1024, "d_head": 64, "d_inner": 4096, "dropout": 0.2, "dropatt": 0.2, "dtype": null, "tie_weight": true, "d_embed": 1024, "div_val": 4, "tie_projs": [false, true, true, true], "pre_lnorm": false, "tgt_len": 384, "ext_len": 0, "mem_len": 384, "cutoffs": [19997, 39997, 199997], "same_length": false, "attn_type": 0, "clamp_len": -1, "sample_softmax": -1}'
 
 # Full training configs for NVIDIA DGX-1 (8x NVIDIA V100 16GB GPU)
 dgx1_8gpu_fp16: &dgx1_8gpu_fp16