[TXL/PyT] Minor update for PyTorch Transformer-XL (#688)

Szymon Migacz 5 years ago
parent
commit
6b82d3acb3
29 changed files with 171 additions and 99 deletions
  1. +1 -1  PyTorch/LanguageModeling/Transformer-XL/Dockerfile
  2. +32 -32  PyTorch/LanguageModeling/Transformer-XL/README.md
  3. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/data_utils.py
  4. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/eval.py
  5. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/mem_transformer_jit.py
  6. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/proj_adaptive_softmax_jit.py
  7. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
  8. +3 -3  PyTorch/LanguageModeling/Transformer-XL/pytorch/mem_transformer.py
  9. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/run_multinode_wt103_large.sh
  10. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/run_wt103_base.sh
  11. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/run_wt103_large.sh
  12. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/docker/build.sh
  13. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/docker/interactive.sh
  14. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/inference_benchmark.sh
  15. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/infer_bench.sh
  16. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_bench.sh
  17. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_full.sh
  18. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_long.sh
  19. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_short.sh
  20. +107 -39  PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py
  21. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/__init__.py
  22. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/adaptive_softmax.py
  23. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py
  24. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/exp_utils.py
  25. +2 -2  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/proj_adaptive_softmax.py
  26. +1 -1  PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/vocabulary.py
  27. +1 -0  PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_base.yaml
  28. +3 -0  PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_large.yaml
  29. +1 -1  PyTorch/LanguageModeling/Transformer-XL/requirements.txt

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/Dockerfile

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 32 - 32
PyTorch/LanguageModeling/Transformer-XL/README.md

@@ -382,7 +382,7 @@ have the following components:
 * [PyTorch 20.06-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
 * GPU architecture:
   * [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
-  * [NVIDIA Turing](https://www.nvidia.com/pl-pl/geforce/turing/)
+  * [NVIDIA Turing](https://www.nvidia.com/en-us/geforce/turing/)
   * [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
 
 For more information about how to get started with NGC containers, see the
@@ -1387,44 +1387,44 @@ table summarizes the final perplexity on the test set.
 The Transformer-XL large model was trained for 25,000 training steps, starting
 from 10 different initial random seeds. After every 1,000 training steps, the
 model was evaluated on the validation dataset and validation perplexity was
-recorded. The training was performed in the pytorch-19.11-py3 NGC container on
+recorded. The training was performed in the pytorch-20.06-py3 NGC container on
 8x NVIDIA DGX-2H with 16x V100 32GB GPUs. The following table summarizes the
 perplexity of our validation dataset.
 
 |**Training step**|**Average perplexity**|**Standard deviation**|**Minimum**|**Maximum**|**Median**|
-|----------------:|---------------------:|---------------------:|----------:|----------:|---------:|
-| 1000  | 605.76 | 3.60068 | 598.00 | 610.41 | 606.04 |
-| 2000  | 142.91 | 1.12225 | 141.79 | 145.56 | 142.47 |
-| 3000  | 62.35  | 0.44710 | 61.75  | 63.25  | 62.37  |
-| 4000  | 40.35  | 0.27075 | 40.06  | 41.00  | 40.24  |
-| 5000  | 32.06  | 0.13979 | 31.85  | 32.25  | 32.12  |
-| 6000  | 28.11  | 0.12096 | 27.88  | 28.29  | 28.13  |
-| 7000  | 25.63  | 0.15906 | 25.44  | 25.89  | 25.59  |
-| 8000  | 24.20  | 0.07317 | 24.07  | 24.30  | 24.21  |
-| 9000  | 23.13  | 0.15848 | 22.90  | 23.46  | 23.09  |
-| 10000 | 22.96  | 0.16448 | 22.68  | 23.29  | 22.98  |
-| 11000 | 21.88  | 0.13801 | 21.74  | 22.20  | 21.90  |
-| 12000 | 21.67  | 0.13077 | 21.44  | 21.89  | 21.64  |
-| 13000 | 21.52  | 0.09049 | 21.43  | 21.72  | 21.50  |
-| 14000 | 21.26  | 0.09471 | 21.13  | 21.41  | 21.24  |
-| 15000 | 21.19  | 0.12189 | 21.07  | 21.47  | 21.15  |
-| 16000 | 21.15  | 0.11736 | 20.90  | 21.32  | 21.18  |
-| 17000 | 20.81  | 0.08846 | 20.68  | 20.97  | 20.83  |
-| 18000 | 20.33  | 0.08871 | 20.19  | 20.47  | 20.31  |
-| 19000 | 19.77  | 0.07522 | 19.68  | 19.93  | 19.77  |
-| 20000 | 19.16  | 0.11090 | 18.99  | 19.31  | 19.19  |
-| 21000 | 18.50  | 0.10299 | 18.34  | 18.71  | 18.49  |
-| 22000 | 18.18  | 0.04529 | 18.14  | 18.29  | 18.15  |
-| 23000 | 17.97  | 0.03982 | 17.92  | 18.04  | 17.96  |
-| 24000 | 17.89  | 0.03974 | 17.81  | 17.94  | 17.90  |
-| 25000 | 17.87  | 0.04264 | 17.80  | 17.92  | 17.88  |
+|----------------:|----------:|---------------------:|----------:|----------:|---------:|
+| 1000  | 608.09 | 3.80116 | 600.65 | 613.73 | 609.40 |
+| 2000  | 142.75 | 0.94452 | 141.21 | 143.84 | 143.07 |
+| 3000  | 62.19  | 0.44544 | 61.38  | 63.01  | 62.18  |
+| 4000  | 40.22  | 0.16397 | 39.93  | 40.54  | 40.20  |
+| 5000  | 32.00  | 0.15850 | 31.61  | 32.19  | 32.02  |
+| 6000  | 28.05  | 0.17854 | 27.81  | 28.41  | 28.05  |
+| 7000  | 25.65  | 0.10946 | 25.51  | 25.87  | 25.65  |
+| 8000  | 24.20  | 0.11385 | 23.98  | 24.36  | 24.20  |
+| 9000  | 23.18  | 0.14936 | 22.84  | 23.37  | 23.20  |
+| 10000 | 22.88  | 0.22752 | 22.54  | 23.33  | 22.94  |
+| 11000 | 21.99  | 0.16232 | 21.73  | 22.29  | 21.97  |
+| 12000 | 21.69  | 0.10824 | 21.46  | 21.81  | 21.73  |
+| 13000 | 21.42  | 0.09154 | 21.25  | 21.57  | 21.44  |
+| 14000 | 21.33  | 0.13821 | 21.15  | 21.55  | 21.27  |
+| 15000 | 21.24  | 0.15526 | 20.95  | 21.57  | 21.20  |
+| 16000 | 21.19  | 0.10521 | 21.01  | 21.44  | 21.18  |
+| 17000 | 20.89  | 0.18239 | 20.69  | 21.18  | 20.82  |
+| 18000 | 20.36  | 0.10715 | 20.21  | 20.53  | 20.34  |
+| 19000 | 19.74  | 0.12803 | 19.45  | 19.92  | 19.75  |
+| 20000 | 19.18  | 0.10020 | 19.05  | 19.39  | 19.15  |
+| 21000 | 18.49  | 0.06319 | 18.36  | 18.60  | 18.49  |
+| 22000 | 18.17  | 0.03674 | 18.11  | 18.22  | 18.16  |
+| 23000 | 17.98  | 0.03682 | 17.90  | 18.04  | 17.99  |
+| 24000 | 17.88  | 0.02880 | 17.84  | 17.92  | 17.89  |
+| 25000 | 17.85  | 0.02793 | 17.80  | 17.90  | 17.86  |
 
 After training, the models were evaluated on the test dataset. The following
 table summarizes the final perplexity on the test set.
 
 |**Average perplexity**|**Standard deviation**|**Minimum**|**Maximum**|**Median**|
-|---------------------:|---------------------:|----------:|----------:|---------:|
-| 18.29 | 0.05214 | 18.24 | 18.40 | 18.27 |
+|----------:|---------------------:|----------:|----------:|---------:|
+| 18.30 | 0.02747 | 18.24 | 18.33 | 18.30 |
 
 #### Training performance results
 
@@ -1463,7 +1463,7 @@ training iterations.
 | 2 | 8  | 27,304  | 40,222  | 1.473 | 1.883 | 1.866 |
 | 4 | 8  | 53,756  | 80,226  | 1.492 | 3.708 | 3.722 |
 | 8 | 8  | 106,651 | 159,185 | 1.493 | 7.357 | 7.385 |
-| 1 | 16 | N/A     | 25,084  | 1.730 | N/A   | 1     |
+| 1 | 16 | N/A     | 25,084  | 1.730 | N/A   | 1.000 |
 | 2 | 16 | N/A     | 48,562  | 1.779 | N/A   | 1.936 |
 | 4 | 16 | N/A     | 95,997  | 1.786 | N/A   | 3.827 |
 | 8 | 16 | N/A     | 191,148 | 1.792 | N/A   | 7.620 |
@@ -1513,7 +1513,7 @@ training iterations.
 | 2 | 2 | 6,153  | 11,272 | 1.832 | 1.729 | 1.632 |
 | 4 | 2 | 12,492 | 22,530 | 1.804 | 3.511 | 3.262 |
 | 8 | 2 | 24,595 | 40,920 | 1.664 | 6.913 | 5.925 |
-| 1 | 4 | N/A    | 10,210 | 2.870 | N/A   | 1     |
+| 1 | 4 | N/A    | 10,210 | 2.870 | N/A   | 1.000 |
 | 2 | 4 | N/A    | 17,984 | 2.923 | N/A   | 1.761 |
 | 4 | 4 | N/A    | 36,340 | 2.909 | N/A   | 3.559 |
 | 8 | 4 | N/A    | 66,716 | 2.713 | N/A   | 6.535 |

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/data_utils.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/eval.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/mem_transformer_jit.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/inference/proj_adaptive_softmax_jit.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 3 - 3
PyTorch/LanguageModeling/Transformer-XL/pytorch/mem_transformer.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -492,14 +492,14 @@ class AdaptiveEmbedding(nn.Module):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
 
                 mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
-                indices_i = mask_i.nonzero().squeeze()
+                indices_i = mask_i.nonzero(as_tuple=False).squeeze()
 
                 if indices_i.numel() == 0:
                     continue
 
                 inp_i = inp_flat.index_select(0, indices_i) - l_idx
                 emb_i = self.emb_layers[i](inp_i)
-                emb_i = F.linear(emb_i, self.emb_projs[i])
+                emb_i = F.linear(emb_i, self.emb_projs[i]).to(emb_flat.dtype)
 
                 emb_flat.index_copy_(0, indices_i, emb_i)
 

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/run_multinode_wt103_large.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/run_wt103_base.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/run_wt103_large.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/docker/build.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/docker/interactive.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/inference_benchmark.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/infer_bench.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_bench.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_full.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_long.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/scripts/tests/train_short.sh

@@ -1,6 +1,6 @@
 #!/bin/bash
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 # 
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 107 - 39
PyTorch/LanguageModeling/Transformer-XL/pytorch/train.py

@@ -1,6 +1,6 @@
 # coding: utf-8
 
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import os
 import shutil
 import sys
 import time
+import warnings
 
 import dllogger
 import numpy as np
@@ -30,7 +31,11 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 import yaml
-from apex import amp
+try:
+    from apex import amp
+except ModuleNotFoundError:
+    warnings.warn('APEX AMP is unavailable')
+
 from torch.nn.parallel import DistributedDataParallel
 
 import lamb
@@ -101,9 +106,11 @@ def parse_args():
                          help='Target training throughput (for benchmarking)')
     general.add_argument('--target_perplexity', type=float, default=None,
                          help='Target validation perplexity (for benchmarking)')
-    general.add_argument('--amp_mode', type=str, default='O2',
+    general.add_argument('--apex_amp_opt_level', type=str, default='O2',
                          choices=['O0', 'O1', 'O2', 'O3'],
                          help='Optimization level for apex amp')
+    general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
+                         help='Implementation of automatic mixed precision')
 
     dataset = parser.add_argument_group('dataset setup')
     dataset.add_argument('--data', type=str, default='../data/wikitext-103',
@@ -220,6 +227,8 @@ def parse_args():
                           help='Use the same attn length for all tokens')
     training.add_argument('--varlen', action='store_true',
                           help='Use variable length')
+    training.add_argument('--swap_mem', action='store_true',
+                          help='Swap memory tensors to cpu')
 
     val = parser.add_argument_group('validation setup')
     val.add_argument('--eval_tgt_len', type=int, default=192,
@@ -244,17 +253,28 @@ def parse_args():
     if args.d_embed < 0:
         args.d_embed = args.d_model
 
-    assert args.ext_len >= 0, 'extended context length must be non-negative'
-    assert args.batch_size % args.batch_chunk == 0
+    if args.ext_len < 0:
+        raise RuntimeError('Extended context length must be non-negative')
+
+    if args.batch_size % args.batch_chunk != 0:
+        raise RuntimeError('Batch size needs to be divisible by batch chunk')
+
+    if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
+        raise RuntimeError(
+            'APEX AMP unavailable, install APEX or switch to pytorch AMP'
+        )
 
     return args
 
 
-def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
-                    epoch, batch, last_iter, train_step, best_val_loss,
+def save_checkpoint(args, model, model_config, optimizer, scheduler, scaler,
+                    vocab, epoch, batch, last_iter, train_step, best_val_loss,
                     is_best, work_dir):
     if args.fp16:
-        amp_state = amp.state_dict()
+        if args.amp == 'pytorch':
+            amp_state = scaler.state_dict()
+        elif args.amp == 'apex':
+            amp_state = amp.state_dict()
     else:
         amp_state = None
 
@@ -415,10 +435,40 @@ def evaluate(eval_iter, model, args):
     return total_loss / total_len
 
 
+def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
+                    optimizer, device, args):
+    cpu = torch.device('cpu')
+    data_i = data_chunks[i].contiguous()
+    target_i = target_chunks[i].contiguous()
+
+    if args.swap_mem and mems[i] is not None:
+        mems[i] = mems[i].to(device, non_blocking=True)
+
+    enable_autocast = args.fp16 and args.amp == 'pytorch'
+    with torch.cuda.amp.autocast(enable_autocast):
+        loss, mems[i] = model(data_i, target_i, mems[i])
+        loss = loss.float().mean().type_as(loss) / args.batch_chunk
+
+    if args.swap_mem and mems[i] is not None:
+        mems[i] = mems[i].to(cpu, non_blocking=True)
+
+    if args.fp16:
+        if args.amp == 'pytorch':
+            scaler.scale(loss).backward()
+        elif args.amp == 'apex':
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+    else:
+        loss.backward()
+
+    train_loss = loss.float().item()
+    return train_loss
+
+
 def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
-          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
+          optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
           last_batch, last_iter, train_step, best_val_loss, meters,
-          timeout_handler, args):
+          timeout_handler, device, args):
     # Turn on training mode which enables dropout.
     model.train()
 
@@ -444,27 +494,36 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
         target_chunks = torch.chunk(target, args.batch_chunk, 1)
 
         for i in range(args.batch_chunk):
-            data_i = data_chunks[i].contiguous()
-            target_i = target_chunks[i].contiguous()
-            loss, mems[i] = para_model(data_i, target_i, mems[i])
-            loss = loss.float().mean().type_as(loss) / args.batch_chunk
-
-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
+            if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
+                with para_model.no_sync():
+                    train_loss_chunk = train_iteration(
+                        para_model, i, mems, data_chunks, target_chunks, scaler,
+                        optimizer, device, args
+                    )
             else:
-                loss.backward()
+                train_loss_chunk = train_iteration(
+                    para_model, i, mems, data_chunks, target_chunks, scaler,
+                    optimizer, device, args
+                )
 
-            train_loss += loss.float().item()
+            train_loss += train_loss_chunk
 
         if args.fp16:
-            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
+            if args.amp == 'pytorch':
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+            elif args.amp == 'apex':
+                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
         else:
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
 
-        optimizer.step()
-        if optimizer_sparse:
-            optimizer_sparse.step()
+        if args.fp16 and args.amp == 'pytorch':
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            optimizer.step()
+            if optimizer_sparse:
+                optimizer_sparse.step()
 
         # step-wise learning rate annealing
         train_step += 1
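The hunk above wraps every gradient-accumulation chunk except the last in `para_model.no_sync()`, so DDP performs its gradient all-reduce only once per optimizer step instead of once per chunk. A simplified sketch of that pattern (with a hypothetical toy model in place of Transformer-XL, and a flag standing in for the DDP check):

```python
# Gradient accumulation with deferred DDP synchronization: under no_sync(),
# backward() accumulates local gradients without the all-reduce; the final
# chunk runs outside it and triggers a single synchronized reduction.
import contextlib
import torch
import torch.nn as nn

def accumulate_step(model, data_chunks, target_chunks, optimizer, is_ddp=False):
    optimizer.zero_grad()
    n = len(data_chunks)
    for i, (data_i, target_i) in enumerate(zip(data_chunks, target_chunks)):
        # no_sync() exists only on DistributedDataParallel wrappers
        skip_sync = is_ddp and i < n - 1
        ctx = model.no_sync() if skip_sync else contextlib.nullcontext()
        with ctx:
            loss = nn.functional.mse_loss(model(data_i), target_i) / n
            loss.backward()
    optimizer.step()
    return loss

model = nn.Linear(8, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
chunks = [torch.randn(4, 8) for _ in range(3)]
targets = [torch.randn(4, 2) for _ in range(3)]
loss = accumulate_step(model, chunks, targets, opt)
```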
@@ -575,8 +634,8 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
                 is_best = True
 
             if not args.debug:
-                save_checkpoint(args, model, model_config, optimizer,
-                                scheduler, vocab, epoch, batch, last_iter,
+                save_checkpoint(args, model, model_config, optimizer, scheduler,
+                                scaler, vocab, epoch, batch, last_iter,
                                 train_step, best_val_loss, is_best,
                                 args.work_dir)
 
@@ -772,12 +831,16 @@ def main():
 
     model = model.to(device)
 
+    scaler = None
     if args.fp16:
-        model, optimizer = amp.initialize(
-            model,
-            optimizer,
-            opt_level=args.amp_mode,
-            )
+        if args.amp == 'pytorch':
+            scaler = torch.cuda.amp.GradScaler()
+        elif args.amp == 'apex':
+            model, optimizer = amp.initialize(
+                model,
+                optimizer,
+                opt_level=args.apex_amp_opt_level,
+                )
 
     if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
         para_model = DistributedDataParallel(model,
@@ -862,7 +925,10 @@ def main():
             optimizer.load_state_dict(checkpoint['optimizer_state'])
             scheduler.load_state_dict(checkpoint['scheduler_state'])
             if args.fp16:
-                amp.load_state_dict(checkpoint['amp_state'])
+                if args.amp == 'pytorch':
+                    scaler.load_state_dict(checkpoint['amp_state'])
+                elif args.amp == 'apex':
+                    amp.load_state_dict(checkpoint['amp_state'])
             train_step = checkpoint['train_step']
             start_epoch = checkpoint['epoch']
             last_batch = checkpoint['batch']
@@ -871,8 +937,8 @@ def main():
 
             if train_step >= args.max_step:
                 logging.info(f'Loaded checkpoint after {train_step} steps, but '
-                            f'this run was scheduled for a total of '
-                            f'{args.max_step} steps, exiting')
+                             f'this run was scheduled for a total of '
+                             f'{args.max_step} steps, exiting')
                 sys.exit(1)
 
             model.apply(functools.partial(update_dropout, args=args))
@@ -898,8 +964,8 @@ def main():
                 train_step, best_val_loss = train(
                     tr_iter, va_iter, model, para_model, model_config,
                     optimizer, optimizer_sparse, scheduler, scheduler_sparse,
-                    vocab, epoch, last_batch, last_iter, train_step,
-                    best_val_loss, meters, timeout_handler, args
+                    scaler, vocab, epoch, last_batch, last_iter, train_step,
+                    best_val_loss, meters, timeout_handler, device, args
                     )
 
                 last_batch = 0
@@ -984,9 +1050,11 @@ if __name__ == "__main__":
         pass
 
     # Before we do anything with models, we want to ensure that we get fp16
-    # execution of torch.einsum.
+    # execution of torch.einsum in APEX AMP.
     # Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
-    # Note that running `--amp_mode O2` will remove the need for this
+    # Note that running `--apex_amp_opt_level O2` will remove the need for this
     # code, but it is still valid.
-    amp.register_half_function(torch, 'einsum')
+    if 'apex' in sys.modules:
+        amp.register_half_function(torch, 'einsum')
+
     main()
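The train.py changes above add a native PyTorch AMP path alongside APEX: `autocast` around the forward pass, `GradScaler` for the backward pass, and `scaler.unscale_` before gradient clipping so the clip threshold applies to true fp32 gradient norms. A minimal sketch of that loop, with a hypothetical toy model and synthetic data standing in for the real training setup:

```python
# Native AMP training step as introduced in train.py: scaled backward,
# unscale before clipping, then scaler.step (which skips the optimizer
# step on overflow) and scaler.update. Falls back to plain fp32 on CPU.
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = device.type == 'cuda'  # autocast/GradScaler become no-ops on CPU

model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

data = torch.randn(8, 16, device=device)
target = torch.randn(8, 4, device=device)

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # grads back to fp32 so clipping sees real norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)
    scaler.step(optimizer)      # optimizer.step() unless grads overflowed
    scaler.update()
```

Unlike `apex.amp.initialize`, this path needs no model/optimizer rewrapping, which is why the diff can keep a single code path gated on `args.amp`.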

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/__init__.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/adaptive_softmax.py

@@ -54,7 +54,7 @@ class AdaptiveLogSoftmax(nn.Module):
             l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
 
             mask_i = (target >= l_idx) & (target < h_idx)
-            indices_i = mask_i.nonzero().squeeze()
+            indices_i = mask_i.nonzero(as_tuple=False).squeeze()
 
             if indices_i.numel() == 0:
                 continue

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/exp_utils.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 2 - 2
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/proj_adaptive_softmax.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -177,7 +177,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
                 l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
 
                 mask_i = (target >= l_idx) & (target < r_idx)
-                indices_i = mask_i.nonzero().squeeze()
+                indices_i = mask_i.nonzero(as_tuple=False).squeeze()
 
                 if indices_i.numel() == 0:
                     continue

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/vocabulary.py

@@ -1,4 +1,4 @@
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

+ 1 - 0
PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_base.yaml

@@ -257,3 +257,4 @@ trainbench: &trainbench
       <<: *train
       log_interval: 1
       max_step: 500
+      max_step_scheduler: 40000

+ 3 - 0
PyTorch/LanguageModeling/Transformer-XL/pytorch/wt103_large.yaml

@@ -118,6 +118,7 @@ dgx1_1gpu_fp32: &dgx1_1gpu_fp32
    train:
       <<: *train
       batch_chunk: 64
+      swap_mem: true
    eval:
       <<: *eval
 
@@ -339,9 +340,11 @@ trainbench: &trainbench
       <<: *train
       log_interval: 1
       max_step: 500
+      max_step_scheduler: 100000
 
 trainbench_multinode: &trainbench_multinode
    train:
       <<: *train_multinode
       log_interval: 1
       max_step: 500
+      max_step_scheduler: 25000

+ 1 - 1
PyTorch/LanguageModeling/Transformer-XL/requirements.txt

@@ -1,3 +1,3 @@
 pytorch-transformers==1.1.0
 sacremoses==0.0.35
-nvidia-ml-py3==7.352.0
+pynvml==8.0.4