Przemek Strzelczyk 6 лет назад
Родитель
Сommit
3d46067af9
45 измененных файлов с 6572 добавлено и 0 удалено
  1. 3 0
      PyTorch/LanguageModeling/TransformerXL/.gitignore
  2. 201 0
      PyTorch/LanguageModeling/TransformerXL/LICENSE
  3. 9 0
      PyTorch/LanguageModeling/TransformerXL/NOTICE
  4. 1159 0
      PyTorch/LanguageModeling/TransformerXL/README.md
  5. 120 0
      PyTorch/LanguageModeling/TransformerXL/getdata.sh
  6. 62 0
      PyTorch/LanguageModeling/TransformerXL/prep_text8.py
  7. 2 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/.dockerignore
  8. 30 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/Dockerfile
  9. 333 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/data_utils.py
  10. 320 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/eval.py
  11. BIN
      PyTorch/LanguageModeling/TransformerXL/pytorch/img/model.png
  12. BIN
      PyTorch/LanguageModeling/TransformerXL/pytorch/img/training_loss.png
  13. 469 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/inference/mem_transformer_base_jit.py
  14. 141 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/inference/proj_adaptive_softmax_jit.py
  15. 247 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/lamb.py
  16. 842 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/mem_transformer.py
  17. 2 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/requirements.txt
  18. 41 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_enwik8_base.sh
  19. 41 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_enwik8_large.sh
  20. 43 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_lm1b_base.sh
  21. 43 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_lm1b_large.sh
  22. 41 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_text8_base.sh
  23. 38 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_text8_large.sh
  24. 58 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_wt103_base.sh
  25. 54 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/run_wt103_large.sh
  26. 17 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/docker/build.sh
  27. 17 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/docker/interactive.sh
  28. 39 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/inference_benchmark.sh
  29. 71 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/infer_bench.sh
  30. 6 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/reference_inference_throughput
  31. 10 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/reference_training_throughput
  32. 78 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_bench.sh
  33. 79 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_full.sh
  34. 80 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_long.sh
  35. 80 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_short.sh
  36. 810 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/train.py
  37. 16 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/__init__.py
  38. 90 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/adaptive_softmax.py
  39. 91 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/data_parallel.py
  40. 110 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/distributed.py
  41. 147 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/exp_utils.py
  42. 147 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/log_uniform_sampler.py
  43. 154 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/proj_adaptive_softmax.py
  44. 229 0
      PyTorch/LanguageModeling/TransformerXL/pytorch/utils/vocabulary.py
  45. 2 0
      README.md

+ 3 - 0
PyTorch/LanguageModeling/TransformerXL/.gitignore

@@ -0,0 +1,3 @@
+**/.DS_Store
+__pycache__/
+data/

+ 201 - 0
PyTorch/LanguageModeling/TransformerXL/LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 9 - 0
PyTorch/LanguageModeling/TransformerXL/NOTICE

@@ -0,0 +1,9 @@
+Transformer-XL for PyTorch
+
+This repository includes software from https://github.com/kimiyoung/transformer-xl licensed under the Apache License 2.0.
+
+This repository includes software from https://github.com/salesforce/awd-lstm-lm licensed under the BSD-3-Clause license.
+
+This repository includes software from https://github.com/cybertronai/transformer-xl licensed under the Apache License 2.0.
+
+This repository includes software from https://github.com/cybertronai/pytorch-lamb licensed under the MIT license.

+ 1159 - 0
PyTorch/LanguageModeling/TransformerXL/README.md

@@ -0,0 +1,1159 @@
+# Transformer-XL For PyTorch
+
+This repository provides a script and recipe to train the Transformer-XL model
+to achieve state-of-the-art accuracy, and is tested and maintained by NVIDIA.
+
+## Table Of Contents
+
+<!-- TOC GFM -->
+
+* [Model overview](#model-overview)
+  * [Model architecture](#model-architecture)
+  * [Default configuration](#default-configuration)
+  * [Feature support matrix](#feature-support-matrix)
+    * [Features](#features)
+  * [Mixed precision training](#mixed-precision-training)
+    * [Enabling mixed precision](#enabling-mixed-precision)
+* [Setup](#setup)
+  * [Requirements](#requirements)
+* [Quick Start Guide](#quick-start-guide)
+* [Advanced](#advanced)
+  * [Scripts and sample code](#scripts-and-sample-code)
+  * [Parameters](#parameters)
+  * [Command-line options](#command-line-options)
+  * [Getting the data](#getting-the-data)
+    * [Dataset guidelines](#dataset-guidelines)
+    * [Multi-dataset](#multi-dataset)
+  * [Training process](#training-process)
+  * [Inference process](#inference-process)
+* [Performance](#performance)
+  * [Benchmarking](#benchmarking)
+    * [Training performance benchmark](#training-performance-benchmark)
+    * [Inference performance benchmark](#inference-performance-benchmark)
+  * [Results](#results)
+    * [Training accuracy results](#training-accuracy-results)
+      * [Training accuracy: NVIDIA DGX-1 (8x V100 16G)](#training-accuracy-nvidia-dgx-1-8x-v100-16g)
+      * [Training accuracy: NVIDIA DGX-2 (16x V100 32G)](#training-accuracy-nvidia-dgx-2-16x-v100-32g)
+      * [Training stability test](#training-stability-test)
+    * [Training performance results](#training-performance-results)
+      * [Training performance: NVIDIA DGX-1 (8x V100 16G)](#training-performance-nvidia-dgx-1-8x-v100-16g)
+      * [Training performance: NVIDIA DGX-2 (16x V100 32G)](#training-performance-nvidia-dgx-2-16x-v100-32g)
+    * [Inference performance results](#inference-performance-results)
+      * [Inference performance: NVIDIA DGX-1 (1x V100 16G)](#inference-performance-nvidia-dgx-1-1x-v100-16g)
+      * [Inference performance: NVIDIA T4](#inference-performance-nvidia-t4)
+* [Release notes](#release-notes)
+  * [Changelog](#changelog)
+  * [Known issues](#known-issues)
+
+<!-- /TOC -->
+
+## Model overview
+
+This repository provides an implementation of the Transformer-XL model in
+PyTorch from the paper [Transformer-XL: Attentive Language Models Beyond a
+Fixed-Length Context](https://arxiv.org/abs/1901.02860). Transformer-XL is
+a transformer-based language model with a segment-level recurrence and a novel
+relative positional encoding. Enhancements introduced in Transformer-XL help
+capture better long-term dependencies by attending to tokens from multiple
+previous segments.
+
+Our implementation is based on the
+[codebase](https://github.com/kimiyoung/transformer-xl) published by the
+authors of the Transformer-XL paper.
+Our implementation uses modified model architecture hyperparameters. Our
+modifications were made to achieve better hardware utilization and to take
+advantage of Tensor Cores. Similar modifications were also proposed in an
+implementation available from
+[github.com/cybertronai/transformer-xl](https://github.com/cybertronai/transformer-xl).
+Refer to the [Model architecture](#model-architecture) section for more
+details.
+
+This model is trained with mixed precision using Tensor Cores on NVIDIA Volta
+GPUs and evaluated on Volta and Turing GPUs. Therefore, researchers can get
+results up to 2.5x faster than training without Tensor Cores, while
+experiencing the benefits of mixed precision training. This model is tested
+against each NGC monthly container release to ensure consistent accuracy and
+performance over time.
+
+### Model architecture
+
+The Transformer-XL "base" model for WikiText-103 dataset available in this
+repository was modified to use the following values of hyperparameters:
+
+|**Hyperparameter**|**Description**|**Original setting**|**Our modification**|
+|------------------|---------------|-------------------:|-------------------:|
+| `d_model` | hidden size                                                      | 410  | 512  |
+| `n_head`  | number of attention heads                                        | 10   | 8    |
+| `d_head`  | size of each attention head                                      | 41   | 64   |
+| `d_inner` | hidden size in fully-connected layers                            | 2100 | 2048 |
+| `tgt_len` | number of tokens to predict during training                      | 150  | 192  |
+| `mem_len` | number of tokens cached from previous iterations during training | 150  | 192  |
+
+Changes described above were made to align certain hyperparameters with powers
+of two, with this modification, the model is able to achieve better hardware
+utilization, and therefore higher training throughput.
+
+
+The Transformer-XL model addresses the limitations of vanilla transformer-based
+language models, which are only able to use relatively short context, bounded
+by the segment length. The Transformer-XL introduces a recurrence mechanism,
+which is able to use a cached hidden state from previous segments. During
+training, the context consists of a concatenation of current segment's hidden
+state and cached states from previous iterations. Gradients are backpropagated
+only through the current segment, although the model is able to take advantage
+of the extra information stored in the cache and therefore is able to model
+long-term dependencies.
+
+An illustration of the recurrence mechanism taken from the [Transformer-XL
+paper](https://arxiv.org/abs/1901.02860) is shown below.
+![model](pytorch/img/model.png)
+
+
+### Default configuration
+
+The following features were implemented in this model:
+
+* general
+  * single-node, data-parallel multi-GPU training,
+  * training and inference with mixed precision using Tensor Cores,
+  * mixed precision training implemented using 
+    [Apex AMP](https://nvidia.github.io/apex/amp.html), with `O2` optimization
+    level and with a dynamic loss scaling,
+
+* model
+  * a 16-layer base Transformer-XL model with hidden size 512, 8 attention heads,
+    each head with hidden size 64,
+  * the model trained on
+    [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)
+    dataset, using word-level vocabulary and
+    adaptive softmax,
+  * embedding weights are tied with weights in the classifier,
+
+* training
+  * training with [LAMB](https://arxiv.org/abs/1904.00962) optimizer,
+  * linear learning rate warmup for 1000 iterations, followed by cosine
+    learning rate schedule, initial learning rate is set to 0.01, final
+    learning rate is set to 0.001,
+  * training for 40,000 steps, using batch size of 256,
+  * support for a training with a gradient accumulation,
+
+* inference
+  * support for multi-gpu inference,
+  * support for TorchScript and pure Python inference,
+  * target length is set to 64, length of memory is set to 640,
+  * positional embeddings are clamped after 400 time steps,
+  * each token is using the same size of the context from previous time steps.
+
+### Feature support matrix
+
+The following features are supported by this model:
+
+| **Feature** | **Transformer-XL** |
+|:------------|------------:|
+|[Apex AMP](https://nvidia.github.io/apex/amp.html) | Yes |
+|[Apex DistributedDataParallel](https://nvidia.github.io/apex/parallel.html#apex.parallel.DistributedDataParallel) | Yes |
+
+#### Features
+
+[Apex AMP](https://nvidia.github.io/apex/amp.html) - a tool that enables Tensor
+Core-accelerated training. Refer to the [Enabling mixed
+precision](#enabling-mixed-precision) section for more details.
+
+[Apex
+DistributedDataParallel](https://nvidia.github.io/apex/parallel.html#apex.parallel.DistributedDataParallel) -
+a module wrapper that enables easy multiprocess distributed data parallel
+training, similar to
+[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel).
+`DistributedDataParallel` is optimized for use with
+[NCCL](https://github.com/NVIDIA/nccl). It achieves high performance by
+overlapping communication with computation during `backward()` and bucketing
+smaller gradient transfers to reduce the total number of transfers required.
+
+### Mixed precision training
+
+Mixed precision is the combined use of different numerical precisions in a
+computational method.
+[Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant
+computational speedup by performing operations in half-precision format, while
+storing minimal information in single-precision to retain as much information
+as possible in critical parts of the network. Since the introduction of [Tensor
+Cores](https://developer.nvidia.com/tensor-cores) in the Volta and Turing
+architectures, significant training speedups are experienced by switching to
+mixed precision -- up to 3x overall speedup on the most arithmetically intense
+model architectures. Using mixed precision training previously required two
+steps:
+
+1. Porting the model to use the FP16 data type where appropriate.
+2. Manually adding loss scaling to preserve small gradient values.
+
+The ability to train deep learning networks with lower precision was introduced
+in the Pascal architecture and first supported in [CUDA
+8](https://devblogs.nvidia.com/parallelforall/tag/fp16/) in the NVIDIA Deep
+Learning SDK.
+
+For information about:
+* How to train using mixed precision, see the [Mixed Precision
+  Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed
+  Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
+  documentation.
+* Techniques used for mixed precision training, see the [Mixed-Precision
+  Training of Deep Neural
+  Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/)
+  blog.
+* APEX tools for mixed precision training, see the [NVIDIA Apex: Tools for Easy
+  Mixed-Precision Training in
+  PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/)
+  .
+
+#### Enabling mixed precision
+The `pytorch/train.py` training script launches mixed precision training
+with Tensor Cores if the flag `--fp16` is set.
+
+Mixed precision is enabled in PyTorch by using the Automatic Mixed Precision
+(AMP), library from [APEX](https://github.com/NVIDIA/apex) that casts variables
+to half-precision upon retrieval, while storing variables in single-precision
+format. Furthermore, to preserve small gradient magnitudes in backpropagation,
+a [loss
+scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling)
+step must be included when applying gradients. In PyTorch, loss scaling can be
+easily applied by using `scale_loss()` method provided by AMP. The scaling
+value to be used can be
+[dynamic](https://nvidia.github.io/apex/amp.html#apex.amp.initialize) or fixed.
+
+For an in-depth walk through on AMP, check out sample usage
+[here](https://nvidia.github.io/apex/amp.html#).
+[APEX](https://github.com/NVIDIA/apex) is a PyTorch extension that contains
+utility libraries, such as AMP, which require minimal network code changes to
+leverage Tensor Cores performance.
+
+The following steps were needed to enable mixed precision training in
+Transformer-XL:
+
+1. Import AMP from APEX:
+
+```
+from apex import amp
+```
+
+2. Initialize AMP and wrap the model and the optimizer before starting the
+  training:
+
+```
+model, optimizer = amp.initialize(
+    model,
+    optimizer,
+    opt_level='O2',
+    )
+```
+
+3. Apply `scale_loss` context manager:
+
+```
+with amp.scale_loss(loss, optimizer) as scaled_loss:
+    scaled_loss.backward()
+```
+
+4. Apply gradient clipping on single precision master weights:
+
+```
+torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
+```
+
+## Setup
+
+The following section lists the requirements that you need to meet in order to
+start training the Transformer-XL model.
+
+### Requirements
+
+This repository contains `Dockerfile` which extends the PyTorch NGC container
+and encapsulates some dependencies.  Aside from these dependencies, ensure you
+have the following components:
+
+* [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
+* [PyTorch 19.09-py3 NGC container](https://ngc.nvidia.com/registry/nvidia-pytorch)
+* [NVIDIA Volta](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
+  or [Turing](https://www.nvidia.com/pl-pl/geforce/turing/) based GPU
+
+For more information about how to get started with NGC containers, see the
+following sections from the NVIDIA GPU Cloud Documentation and the Deep
+Learning DGX Documentation:
+
+* [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html),
+* [Accessing And Pulling From The NGC container registry](https://docs.nvidia.com/deeplearning/dgx/user-guide/index.html#accessing_registry),
+* [Running PyTorch](https://docs.nvidia.com/deeplearning/dgx/pytorch-release-notes/running.html#running).
+
+For those unable to use the Pytorch NGC container, to set up the required
+environment or create your own container, see the versioned [NVIDIA Container
+Support
+Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
+
+## Quick Start Guide
+
+To train your model using mixed precision with Tensor Cores or using FP32,
+perform the following steps using the default parameters of the Transformer-XL
+base model on the
+[WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)
+dataset. 
+
+For the specifics concerning training
+and inference, see the [Advanced](#advanced) section.
+
+1. Clone the repository.
+
+```
+git clone https://github.com/NVIDIA/DeepLearningExamples
+cd DeepLearningExamples/PyTorch/LanguageModeling/Transformer-XL
+```
+
+2. Download and preprocess the dataset.
+
+```
+bash getdata.sh
+```
+
+3. Build the Transformer-XL PyTorch NGC container.
+
+From now on, all scripts should be executed from the `pytorch` directory.
+
+```
+cd pytorch
+bash scripts/docker/build.sh
+```
+
+4. Start an interactive session in the NGC container to run training/inference.
+
+```
+bash scripts/docker/interactive.sh
+```
+
+5. Start training.
+
+To start 8 GPU mixed precision training on DGX-1, run:
+
+```
+bash run_wt103_base.sh train 8 --vocab word --adaptive --fp16 --batch_chunk 1 
+```
+
+To start 8 GPU FP32 training on DGX-1, run:
+
+```
+bash run_wt103_base.sh train 8 --vocab word --adaptive --batch_chunk 2
+```
+
+To start 16 GPU mixed precision training on DGX-2, run:
+
+```
+bash run_wt103_base.sh train 16 --vocab word --adaptive --fp16 --batch_chunk 1 
+```
+
+To start 16 GPU FP32 training on DGX-2, run:
+
+```
+bash run_wt103_base.sh train 16 --vocab word --adaptive --batch_chunk 1
+```
+
+For more information on the available options, refer to the [Training
+process](#training-process) section.
+
+6. Start evaluation.
+
+To start mixed precision inference on the test set using `<#GPUs>` GPUs, run:
+
+```
+bash run_wt103_base.sh eval <#GPUs> [--fp16] [--type {pytorch, torchscript}]
+```
+
+The `--fp16` flag is optional, however, if it's specified, then the script
+launches mixed precision inference with Tensor Cores. If the flag is not
+present, then the script launches FP32 inference.
+By default, the script is loading the checkpoint from
+`LM-TFM/checkpoint_best.pt`, which contains the model corresponding to the
+lowest value of the validation loss from the previous training run. Path to the
+checkpoint can be customized by setting the `--model` flag.
+
+Inference can use pure Python execution or TorchScript from using the `--type`
+flag.
+
+Supported values for `<#GPUs>` are: 1, 2, 4, 8, 16.
+
+Additionally, one can pass the input text directly from the command-line using
+the `--manual` flag. This mode of operation supports only 1 GPU and batch size
+of 1. The script outputs average loss and perplexity for the provided input
+text.
+
+Examples:
+
+
+```
+bash run_wt103_base.sh eval 1 \
+  --model LM-TFM/checkpoint_best.pt \
+  --fp16 \
+  --manual "recognize speech"
+
+===============================================================================
+| test loss  6.20 | test ppl   494.291
+===============================================================================
+```
+
+```
+bash run_wt103_base.sh eval 1 \
+  --model LM-TFM/checkpoint_best.pt \
+  --fp16 \
+  --manual "wreck a nice beach"
+
+===============================================================================
+| test loss  8.04 | test ppl  3099.706
+===============================================================================
+```
+
+For more information on the available options, refer to the [Inference
+process](#inference-process) section.
+
+## Advanced
+
+The following sections provide greater details of the dataset, running training
+and inference, and the training results.
+
+### Scripts and sample code
+
+In the `pytorch` directory, the most important files are:
+
+* `Dockerfile`: container with the basic set of dependencies to run Transformer-XL
+* `data_utils.py`: data loading utilities
+* `eval.py`: serves as the entry point to launch the evaluation and inference
+* `lamb.py`: implementation of [LAMB](https://arxiv.org/abs/1904.00962) optimizer
+* `mem_transformer.py`: implementation of the Transformer-XL model
+* `requirements.txt`: set of extra requirements for running Transformer-XL
+* `train.py`: serves as the entry point to launch the training
+
+The `pytorch/utils` directory contains the following additional modules:
+
+* `adaptive_softmax.py`: implementation of adaptive softmax
+* `data_parallel.py`: implementation of `BalancedDataParallel` class
+* `distributed.py`: utility functions for running distributed training
+* `exp_utils.py`: utility functions for running training and benchmarking
+* `log_uniform_sampler.py`: implementation of log-uniform sampler
+* `proj_adaptive_softmax.py`: implementation of projected adaptive softmax
+* `vocabulary.py`: implementation of word-level vocabulary and BPE-based vocabulary
+
+### Parameters
+
+**Training**
+
+The complete list of available parameters for the `pytorch/train.py` training script
+contains:
+
+```
+general setup:
+  --work_dir WORK_DIR   Directory for the results (default: LM-TFM)
+  --append_dataset      Automatically append dataset name to work_dir
+                        (default: False)
+  --append_time         Automatically append current time to work_dir
+                        (default: False)
+  --cuda                Use CUDA (default: False)
+  --fp16                Run training in fp16/mixed precision (default: False)
+  --restart RESTART     Restart training from the saved checkpoint (default: )
+  --debug               Run in debug mode (do not create exp dir) (default:
+                        False)
+  --log_all_ranks       Enable logging from all distributed ranks (default:
+                        False)
+  --save-all            Save all checkpoints (default: False)
+  --log_interval LOG_INTERVAL
+                        Report interval (default: 10)
+  --target_throughput TARGET_THROUGHPUT
+                        Target training throughput (for benchmarking)
+                        (default: None)
+  --target_perplexity TARGET_PERPLEXITY
+                        Target validation perplexity (for benchmarking)
+                        (default: None)
+
+dataset setup:
+  --data DATA           Location of the data corpus (default:
+                        ../data/wikitext-103)
+  --dataset {wt103,lm1b,enwik8,text8}
+                        Dataset name (default: wt103)
+  --vocab {word,bpe}    Type of vocabulary (default: word)
+
+model setup:
+  --n_layer N_LAYER     Number of total layers (default: 16)
+  --n_head N_HEAD       Number of heads (default: 8)
+  --d_head D_HEAD       Head dimension (default: 64)
+  --d_embed D_EMBED     Embedding dimension (default: -1)
+  --d_model D_MODEL     Model dimension (default: 512)
+  --d_inner D_INNER     Inner dimension in feedforward layer (default: 2048)
+  --dropout DROPOUT     Global dropout rate (default: 0.1)
+  --dropatt DROPATT     Attention probability dropout rate (default: 0.0)
+  --pre_lnorm           Apply LayerNorm to the input instead of the output
+                        (default: False)
+  --attn_type ATTN_TYPE
+                        Attention type. 0 for ours, 1 for Shaw et al,2 for
+                        Vaswani et al, 3 for Al Rfou et al. (default: 0)
+  --not_tied            Do not tie the word embedding and softmax weights
+                        (default: False)
+  --clamp_len CLAMP_LEN
+                        Use the same pos embeddings after clamp_len (default:
+                        -1)
+  --adaptive            Use adaptive softmax (default: False)
+  --div_val DIV_VAL     Dividend value for adaptive input and softmax
+                        (default: 1)
+  --sample_softmax SAMPLE_SOFTMAX
+                        Number of samples in sampled softmax (default: -1)
+  --init INIT           Parameter initializer to use (default: normal)
+  --emb_init EMB_INIT   Parameter initializer to use (default: normal)
+  --init_range INIT_RANGE
+                        Parameters initialized by U(-init_range, init_range)
+                        (default: 0.1)
+  --emb_init_range EMB_INIT_RANGE
+                        Parameters initialized by U(-init_range, init_range)
+                        (default: 0.01)
+  --init_std INIT_STD   Parameters initialized by N(0, init_std) (default:
+                        0.02)
+  --proj_init_std PROJ_INIT_STD
+                        Parameters initialized by N(0, init_std) (default:
+                        0.01)
+
+optimizer setup:
+  --optim {adam,sgd,adagrad,lamb}
+                        Optimizer to use (default: lamb)
+  --lr LR               Initial learning rate (default: 0.01)
+  --mom MOM             Momentum for sgd (default: 0.0)
+  --scheduler {cosine,inv_sqrt,dev_perf,constant}
+                        LR scheduler to use (default: cosine)
+  --max_step_scheduler MAX_STEP_SCHEDULER
+                        Max number of training steps for LR scheduler
+                        (default: None)
+  --warmup_step WARMUP_STEP
+                        Number of iterations for LR warmup (default: 1000)
+  --decay_rate DECAY_RATE
+                        Decay factor when ReduceLROnPlateau is used (default:
+                        0.5)
+  --lr_min LR_MIN       Minimum learning rate during annealing (default: 0.0)
+  --clip CLIP           Gradient clipping (default: 0.25)
+  --weight_decay WEIGHT_DECAY
+                        Weight decay for adam|lamb (default: 0.0)
+  --clip_nonemb         Only clip the gradient of non-embedding params
+                        (default: False)
+  --patience PATIENCE   Patience (default: 0)
+  --eta_min ETA_MIN     Min learning rate for cosine scheduler (default:
+                        0.001)
+
+training setup:
+  --max_step MAX_STEP   Max number of training steps (default: 40000)
+  --batch_size BATCH_SIZE
+                        Global batch size (default: 256)
+  --batch_chunk BATCH_CHUNK
+                        Split batch into chunks to save memory (default: 1)
+  --roll                Enable random shifts within each data stream (default:
+                        False)
+  --tgt_len TGT_LEN     Number of tokens to predict (default: 192)
+  --ext_len EXT_LEN     Length of the extended context (default: 0)
+  --mem_len MEM_LEN     Length of the retained previous heads (default: 192)
+  --seed SEED           Random seed (default: 1111)
+  --multi_gpu {ddp,dp}  Use multiple GPU (default: None)
+  --gpu0_bsz GPU0_BSZ   Batch size on gpu 0 (for "dp" backend) (default: -1)
+  --same_length         Use the same attn length for all tokens (default:
+                        False)
+  --varlen              Use variable length (default: False)
+
+validation setup:
+  --eval_tgt_len EVAL_TGT_LEN
+                        Number of tokens to predict for evaluation (default:
+                        192)
+  --eval_batch_size EVAL_BATCH_SIZE
+                        Eval batch size (default: 16)
+  --eval_max_steps EVAL_MAX_STEPS
+                        Max eval steps (default: -1)
+  --eval_interval EVAL_INTERVAL
+                        Evaluation interval (default: 5000)
+```
+
+**Inference**
+
+The complete list of available parameters for the `eval.py` inference
+script contains:
+
+```
+  --work_dir WORK_DIR   experiment directory (default: LM-TFM)
+  --debug               run in debug mode (do not create exp dir) (default:
+                        False)
+  --data DATA           location of the data corpus (default:
+                        ../data/wikitext-103)
+  --manual MANUAL [MANUAL ...]
+                        run model on raw input data (default: None)
+  --dataset {wt103,lm1b,enwik8,text8}
+                        dataset name (default: wt103)
+  --split {all,valid,test}
+                        which split to evaluate (default: all)
+  --type {pytorch,torchscript,onnx}
+                        type of runtime to use (default: pytorch)
+  --batch_size BATCH_SIZE
+                        batch size (default: 16)
+  --tgt_len TGT_LEN     number of tokens to predict (default: 64)
+  --ext_len EXT_LEN     length of the extended context (default: 0)
+  --mem_len MEM_LEN     length of the retained previous heads (default: 640)
+  --clamp_len CLAMP_LEN
+                        max positional embedding index (default: -1)
+  --cuda                use CUDA (default: False)
+  --model MODEL         path to the checkpoint (default: )
+  --fp16                Run training in fp16/mixed precision (default: False)
+  --log_all_ranks       Enable logging for all distributed ranks (default:
+                        False)
+  --same_length         set same length attention with masking (default:
+                        False)
+  --target_perplexity TARGET_PERPLEXITY
+                        target perplexity (default: None)
+  --target_throughput TARGET_THROUGHPUT
+                        target throughput (default: None)
+  --save_data           save latency and throughput data to a file (default:
+                        False)
+  --repeat REPEAT       loop over the dataset REPEAT times (default: 1)
+  --max_size MAX_SIZE   run inference on up to MAX_SIZE batches (default:
+                        None)
+  --percentiles PERCENTILES [PERCENTILES ...]
+                        percentiles for latency confidence intervals (default:
+                        [90, 95, 99])
+  --save_torchscript SAVE_TORCHSCRIPT
+                        save torchscript model to a file (default: None)
+  --load_torchscript LOAD_TORCHSCRIPT
+                        load torchscript model from a file (default: None)
+```
+
+
+### Command-line options
+
+To see the full list of available options and their descriptions, use the `-h`
+or `--help` command-line option. For example, for training:
+
+```
+python3 train.py --help
+
+usage: train.py [-h] [--work_dir WORK_DIR] [--append_dataset] [--append_time]
+                [--cuda] [--fp16] [--restart RESTART] [--debug]
+                [--log_all_ranks] [--save-all] [--log_interval LOG_INTERVAL]
+                [--target_throughput TARGET_THROUGHPUT]
+                [--target_perplexity TARGET_PERPLEXITY] [--data DATA]
+                [--dataset {wt103,lm1b,enwik8,text8}] [--vocab {word,bpe}]
+                [--n_layer N_LAYER] [--n_head N_HEAD] [--d_head D_HEAD]
+                [--d_embed D_EMBED] [--d_model D_MODEL] [--d_inner D_INNER]
+                [--dropout DROPOUT] [--dropatt DROPATT] [--pre_lnorm]
+                [--attn_type ATTN_TYPE] [--not_tied] [--clamp_len CLAMP_LEN]
+                [--adaptive] [--div_val DIV_VAL]
+                [--sample_softmax SAMPLE_SOFTMAX] [--init INIT]
+                [--emb_init EMB_INIT] [--init_range INIT_RANGE]
+                [--emb_init_range EMB_INIT_RANGE] [--init_std INIT_STD]
+                [--proj_init_std PROJ_INIT_STD]
+                [--optim {adam,sgd,adagrad,lamb}] [--lr LR] [--mom MOM]
+                [--scheduler {cosine,inv_sqrt,dev_perf,constant}]
+                [--max_step_scheduler MAX_STEP_SCHEDULER]
+                [--warmup_step WARMUP_STEP] [--decay_rate DECAY_RATE]
+                [--lr_min LR_MIN] [--clip CLIP] [--weight_decay WEIGHT_DECAY]
+                [--clip_nonemb] [--patience PATIENCE] [--eta_min ETA_MIN]
+                [--max_step MAX_STEP] [--batch_size BATCH_SIZE]
+                [--batch_chunk BATCH_CHUNK] [--roll] [--tgt_len TGT_LEN]
+                [--ext_len EXT_LEN] [--mem_len MEM_LEN] [--seed SEED]
+                [--multi_gpu {ddp,dp}] [--gpu0_bsz GPU0_BSZ] [--same_length]
+                [--varlen] [--eval_tgt_len EVAL_TGT_LEN]
+                [--eval_batch_size EVAL_BATCH_SIZE]
+                [--eval_max_steps EVAL_MAX_STEPS]
+                [--eval_interval EVAL_INTERVAL] [--local_rank LOCAL_RANK]
+```
+
+For example, for inference:
+
+```
+python3 eval.py --help
+
+usage: eval.py [-h] [--work_dir WORK_DIR] [--debug] [--data DATA]
+               [--manual MANUAL [MANUAL ...]]
+               [--dataset {wt103,lm1b,enwik8,text8}]
+               [--split {all,valid,test}] [--type {pytorch,torchscript,onnx}]
+               [--batch_size BATCH_SIZE] [--tgt_len TGT_LEN]
+               [--ext_len EXT_LEN] [--mem_len MEM_LEN] [--clamp_len CLAMP_LEN]
+               [--cuda] [--model MODEL] [--fp16] [--log_all_ranks]
+               [--same_length] [--target_perplexity TARGET_PERPLEXITY]
+               [--target_throughput TARGET_THROUGHPUT] [--save_data]
+               [--repeat REPEAT] [--max_size MAX_SIZE]
+               [--percentiles PERCENTILES [PERCENTILES ...]]
+               [--save_torchscript SAVE_TORCHSCRIPT]
+               [--load_torchscript LOAD_TORCHSCRIPT] [--local_rank LOCAL_RANK]
+```
+
+
+### Getting the data
+
+The Transformer-XL base model was trained on the
+[WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)
+dataset. The WikiText-103 dataset is a collection of over 100 million tokens
+extracted from the set of verified
+[Good](https://en.wikipedia.org/wiki/Wikipedia:Good_articles) and
+[Featured](https://en.wikipedia.org/wiki/Wikipedia:Featured_articles) articles
+on Wikipedia.
+
+This repository contains the `getdata.sh` download script which
+automatically downloads and extracts the training, validation and test
+datasets. By default, data is downloaded to the `data` directory.
+
+In order to test with other datasets, the script needs to be customized
+accordingly.
+
+#### Dataset guidelines
+
+The WikiText-103 dataset was already pre-tokenized with word-level tokens. The
+dataset features a large vocabulary of 267,735 tokens and retains the original
+case, punctuation and numbers.
+
+The `getdata.sh` script downloads the data, extracts the archive and renames
+the training, validation and test set to `train.txt`, `valid.txt`, `test.txt`
+respectively.
+
+#### Multi-dataset
+
+Using other datasets requires changes in the following files:
+
+* `pytorch/train.py`:
+  * name of the new dataset should be added to the `dataset` argument in the `parse_args()` function
+  * desired values of cutoffs for adaptive softmax should be added in the
+    `main()` function, after the section which builds train/valid/test data
+    iterators
+* `pytorch/data_utils.py`:
+  * support for the new dataset needs to be added to the `Corpus` class: names
+    of files containing training, validation and test data, options for the
+    tokenizer, and dataset iterator
+
+The current codebase supports training with word-level vocabulary
+(automatically generated based on the provided dataset) and with BPE vocabulary
+(using pre-built vocabulary from pretrained GPT2 model imported from
+[github.com/huggingface/transformers](https://github.com/huggingface/transformers).
+
+Additionally, using other datasets may require changes in some hyperparameters
+(for example, batch size, learning rate, number of training steps,
+configuration of learning rate scheduler). 
+
+### Training process
+
+The default training configuration can be launched by running the
+`run_wt103_base.sh` script with the first argument set to `train`. By default,
+the training results are saved to the `LM-TFM` directory; this can be
+customized by setting the `--work_dir` parameter.
+
+The training script launches a single node data-parallel training with a fixed
+global batch size of 256, optionally with gradient accumulation to allow
+training on configurations with less than 8 GPUs. Logs from the training are
+automatically saved to the `LM-TFT/log.log` file.
+
+**Command-line**
+
+```
+bash run_wt103_base.sh train <#GPUs> --vocab word --adaptive [--fp16] [--batch_chunk CHUNK]
+```
+
+Launches training of Transformer-XL base model on WikiText-103 dataset with word-based vocabulary and adaptive softmax using `<#GPUs>` GPUs.
+
+The `--fp16` flag is optional, if it's specified, then the script launches mixed
+precision training with Tensor Cores, if the flag is not present, then the
+script launches FP32 training.
+
+The `--batch_chunk CHUNK` parameter controls gradient accumulation. With gradient
+accumulation the batch size is split into `CHUNK` chunks of equal size, the
+training script executes the forward and backward pass using each chunk and
+then executes the optimizer using accumulated gradients.
+
+**Examples**
+
+```
+bash run_wt103_base.sh train 16 --fp16 --vocab word --adaptive --batch_chunk 1
+```
+
+Launches mixed precision training of Transformer-XL base model on WikiText-103
+using 16 GPUs. Batch size per GPU is equal to the default global batch size of 256
+divided by the product of the number of GPUs times the number of chunks, in this
+case batch size per GPU is equal to `256 / (16 * 1) = 16`.
+
+```
+bash run_wt103_base.sh train 8 --vocab word --adaptive --batch_chunk 2
+```
+
+Launches a FP32 training using 8 GPUs, the batch size per GPU is equal to 16
+(`--batch_chunk` was set to `2` because a local batch size of 32 runs out
+of memory on a DGX-1 with Tesla V100 16G in FP32 training).
+
+A summary of the training progress is printed after every 10 training
+iterations; this can be customized by setting the `--log_interval` parameter.
+
+
+The summary is printed in the following format:
+
+```
+| epoch  18 step    36000 | batches    283 / 2101 | lr 1.220e-03 | ms/batch 185.1 | tok/s  265585 | loss  3.12 | ppl     22.71
+```
+
+which contains information about a current training epoch, current training
+step, number of batches processed within the current epoch, current learning
+rate, execution time in milliseconds per batch, throughput in tokens per
+second, current training loss and training perplexity.
+
+The script saves two checkpoints: `checkpoint_best.pt` which contains the model
+corresponding to the lowest value of the validation loss and
+`checkpoint_last.pt` which contains the model corresponding to the last
+execution of the validation step. By default, the validation is executed every
+5000 training steps, this can be customized by setting the `--eval_interval`
+parameter. The summary of results on the validation dataset is printed in the
+following format:
+
+```
+| Eval   7 at step    35000 | time:  1.37s | valid loss  3.14 | valid ppl    23.132
+```
+
+which contains information about the current epoch, current training step, time
+needed to execute the validation, current validation loss and validation
+perplexity.
+
+### Inference process
+
+Inference can be run by launching the `run_wt103_base.sh` script with the first
+argument set to `eval`. Running inference requires a pre-trained model
+checkpoint.
+
+The script supports single node multi-GPU inference, each batch is split
+equally among all GPUs running the inference and the loss is averaged over the
+global batch.
+
+**Command-line**
+
+```
+bash run_wt103_base.sh eval <#GPUs> --model <PATH TO THE CHECKPOINT> [--fp16] [--type {pytorch, torchscript}]
+```
+
+The `--fp16` flag is optional, if it's specified, then the script launches inference
+with Tensor Cores, if the flag is not present, then the script launches FP32
+inference.
+
+The `--type` flag selects between pure Python pytorch execution and TorchScript execution.
+
+Supported values for `<#GPUs>` are: 1, 2, 4, 8, 16.
+
+**Examples**
+
+```
+bash run_wt103_base.sh eval 8 --model LM-TFM/checkpoint_best.pt --fp16 --type torchscript
+```
+
+Launches TorchScript mixed precision inference on 8 GPUs using a checkpoint loaded from
+`LM-TFM/checkpoint_best.pt`.
+
+```
+bash run_wt103_base.sh eval 1 --model LM-TFM/checkpoint_best.pt --type pytorch
+```
+
+Launches pure Python FP32 inference on a single GPU using a checkpoint loaded from
+`LM-TFM/checkpoint_best.pt`.
+
+After the execution, the script prints a summary in the following format:
+
+```
+Evaluating with math fp16 type torchscript bsz 16 tgt_len 64 ext_len 0 mem_len 640 clamp_len 400
+Time : 5.29s, 22.05ms/segment
+====================================================================================================
+| test loss  3.15 | test ppl    23.304
+====================================================================================================
+```
+
+which contains information about runtime parameters, execution time, loss and
+perplexity on the test dataset.
+
+## Performance
+
+### Benchmarking
+
+The following section shows how to run benchmarks measuring the model
+performance in training and inference modes.
+
+#### Training performance benchmark
+
+To benchmark the training performance on a specific global batch size `<BS>`,
+with a specific number of GPUs `<#GPUs>` for a specific number of training
+iterations `<ITER>` run:
+
+```
+bash run_wt103_base.sh train <#GPUs> --batch_size <BS> --max_step <ITER> --vocab word --adaptive --log_interval 1 --debug [--fp16] [--batch_chunk CHUNK]
+```
+
+It's recommended to launch at least 500 training steps to get a reliable
+estimate of training performance. For more information about the available
+options, refer to the [Training process](#training-process) section.
+
+The training script prints information in the following format:
+
+```
+(...)
+| epoch   1 step      499 | batches    499 / 16802 | lr 4.990e-03 | ms/batch 219.9 | tok/s   27947 | loss  6.43 | ppl    620.80
+| epoch   1 step      500 | batches    500 / 16802 | lr 5.000e-03 | ms/batch 221.4 | tok/s   27747 | loss  6.42 | ppl    611.70
+-------------------------------------------------------------------------------
+(...)
+Training time: 1.81 minutes
+Training throughput: 28508.91 tok/s
+```
+
+The last two lines contain information on the total training time and on the
+average training throughput measured in tokens per second.
+
+#### Inference performance benchmark
+
+The inference performance and accuracy benchmarks require a checkpoint from a
+trained model.
+
+To benchmark the inference performance on a specific global batch size `<BS>`
+with a specific number of GPUs `<#GPUs>`, run:
+
+```
+bash run_wt103_base.sh eval <#GPUs> --model <CHECKPOINT> --batch_size <BS> --save_data [--fp16] [--type {pytorch, torchscript}]
+```
+
+The inference script prints information in the following format:
+
+```
+Evaluating with math fp16 type torchscript bsz 16 tgt_len 64 ext_len 0 mem_len 640 clamp_len 400
+Time : 5.25s, 21.88ms/segment
+====================================================================================================
+| test loss  3.15 | test ppl    23.304
+====================================================================================================
+Throughput Avg: 46316.64 tok/s
+Latency Avg: 22.09 ms
+Latency 90%: 22.22 ms
+Latency 95%: 22.25 ms
+Latency 99%: 22.37 ms
+====================================================================================================
+```
+
+The output contains information on the achieved test loss and test perplexity,
+average inference throughput (measured in tokens per second), average inference
+latency and latency at 90%, 95% and 99% confidence intervals (measured in
+milliseconds).
+
+The `scripts/inference_benchmark.sh` benchmarking script is provided for
+convenience, it automatically launches FP32 and FP16 inference for various
+batch sizes.
+
+### Results
+
+The following sections provide details on how we achieved our performance and
+accuracy in training and inference.
+
+#### Training accuracy results
+
+##### Training accuracy: NVIDIA DGX-1 (8x V100 16G)
+
+Our results were obtained by running the `pytorch/run_wt103_base.sh`
+training script in the in the pytorch-19.09-py3 NGC container on NVIDIA DGX-1
+with 8x V100 16G GPUs.
+
+|**GPUs**|**Batch Size / GPU**|**Accuracy - FP32 (perplexity)**|**Accuracy - Mixed precision (perplexity)**|**Time to Train - FP32 (minutes)**|**Time to Train - Mixed precision (minutes)**|**Time to Train Speedup (FP32 to Mixed precision)**|
+|-------:|-------------------:|-------------------------------:|------------------------------------------:|---------------------------------:|--------------------------------------------:|--------------------------------------------------:|
+| 1 | 16 | 23.24 | 23.42 | 2542.0 | 1037.8 | 2.45 |
+| 8 | 16 | 23.38 | 23.44 | 366.9  | 168.9  | 2.17 |
+| 1 | 32 | N/A   | 23.38 | N/A    | 894.3  | 2.84 |
+| 8 | 32 | N/A   | 23.38 | N/A    | 140.7  | 2.61 |
+
+##### Training accuracy: NVIDIA DGX-2 (16x V100 32G)
+
+Our results were obtained by running the `pytorch/run_wt103_base.sh`
+training script in the in the pytorch-19.09-py3 NGC container on NVIDIA DGX-2
+with 16x V100 32G GPUs.
+
+|**GPUs**|**Batch Size / GPU**|**Accuracy - FP32 (perplexity)**|**Accuracy - Mixed precision (perplexity)**|**Time to Train - FP32 (minutes)**|**Time to Train - Mixed precision (minutes)**|**Time to Train Speedup (FP32 to Mixed precision)**|
+|-------:|-------------------:|-------------------------------:|------------------------------------------:|---------------------------------:|--------------------------------------------:|--------------------------------------------------:|
+| 16 | 16 | 23.36 | 23.32 | 184.4 | 91.2 | 2.02 |
+
+![TrainingLoss](pytorch/img/training_loss.png)
+
+##### Training stability test
+
+The Transformer-XL model was trained for 40000 training steps, starting from 20
+different initial random seeds. After every 5000 training steps, the model was
+evaluated on the validation dataset and validation perplexity was recorded. The
+training was performed in the pytorch-19.09-py3 NGC container on NVIDIA
+DGX-1 with 8x V100 16G GPUs. The following table summarizes the perplexity
+on our validation dataset.
+
+|**Training step**|**Average**|**Standard deviation**|**Minimum**|**Maximum**|**Median**|
+|----------------:|----------:|---------------------:|----------:|----------:|---------:|
+| 5000  | 42.58 | 0.28639 | 41.98 | 43.11 | 42.62 |
+| 10000 | 32.39 | 0.19765 | 32.09 | 32.78 | 32.41 |
+| 15000 | 28.49 | 0.15000 | 28.28 | 28.78 | 28.49 |
+| 20000 | 26.22 | 0.11862 | 26.06 | 26.52 | 26.22 |
+| 25000 | 24.73 | 0.11190 | 24.45 | 24.88 | 24.74 |
+| 30000 | 23.88 | 0.10489 | 23.67 | 24.04 | 23.87 |
+| 35000 | 23.31 | 0.10010 | 23.09 | 23.45 | 23.33 |
+| 40000 | 23.10 | 0.09857 | 22.86 | 23.23 | 23.11 |
+
+After training, the models were evaluated on the test dataset. The following
+table summarizes the final perplexity on the test set.
+
+|**Average**|**Standard deviation**|**Minimum**|**Maximum**|**Median**|
+|----------:|---------------------:|----------:|----------:|---------:|
+| 23.39 | 0.06817 | 23.26 | 23.51 | 23.39 |
+
+#### Training performance results
+
+##### Training performance: NVIDIA DGX-1 (8x V100 16G)
+
+Our results were obtained by running the `pytorch/run_wt103_base.sh`
+training script in the pytorch-19.09-py3 NGC container on NVIDIA DGX-1 with 8x
+V100 16G GPUs. Performance numbers (in tokens per second) were averaged 500
+training iterations.
+
+|**GPUs**|**Batch Size / GPU**|**Throughput - FP32 (tok/s)**|**Throughput - Mixed precision (tok/s)**|**Throughput speedup (FP32 to Mixed precision)**|**Weak Scaling - FP32**|**Weak Scaling - Mixed precision**|
+|-------:|-------------------:|----------------------------:|----------------------------------------:|-----------------------------------------------:|------------------------:|-----------------------------------:|
+| 1 | 16 | 11,499.8 | 24,028.9  | 2.089 | 1.000 | 1.000 |
+| 2 | 16 | 19,574.0 | 40,001.7  | 2.044 | 1.702 | 1.665 |
+| 4 | 16 | 42,184.9 | 85,391.2  | 2.024 | 3.668 | 3.554 |
+| 8 | 16 | 84,803.6 | 159,122.2 | 1.876 | 7.374 | 6.622 |
+| 1 | 32 | N/A      | 31,072.4  | 2.702 | N/A   | 1.000 |
+| 2 | 32 | N/A      | 55,534.1  | 2.837 | N/A   | 1.787 |
+| 4 | 32 | N/A      | 117,200.6 | 2.778 | N/A   | 3.772 |
+| 8 | 32 | N/A      | 234,437.3 | 2.764 | N/A   | 7.545 |
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
+##### Training performance: NVIDIA DGX-2 (16x V100 32G)
+
+Our results were obtained by running the `pytorch/run_wt103_base.sh`
+training script in the pytorch-19.09-py3 NGC container on NVIDIA DGX-2 with 16x
+V100 32G GPUs. Performance numbers (in tokens per second) were averaged 500
+training iterations.
+
+|**GPUs**|**Batch Size / GPU**|**Throughput - FP32 (tok/s)**|**Throughput - Mixed precision (tok/s)**|**Throughput speedup (FP32 to Mixed precision)**|**Weak Scaling - FP32**|**Weak Scaling - Mixed precision**|
+|-------:|-------------------:|----------------------------:|---------------------------------------:|-----------------------------------------------:|----------------------:|---------------------------------:|
+| 1  | 16 | 12,204.0  | 25,337.5  | 2.076 | 1.000  | 1.000  |
+| 2  | 16 | 22,995.2  | 46,605.1  | 2.027 | 1.884  | 1.839  |
+| 4  | 16 | 45,321.1  | 91,537.4  | 2.020 | 3.714  | 3.613  |
+| 8  | 16 | 89,427.3  | 179,920.5 | 2.012 | 7.328  | 7.101  |
+| 16 | 16 | 177,245.0 | 357,343.6 | 2.016 | 14.524 | 14.103 |
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
+#### Inference performance results
+
+##### Inference performance: NVIDIA DGX-1 (1x V100 16G)
+
+Our results were obtained by running the
+`pytorch/scripts/inference_benchmark.sh` inferencing benchmarking script in the
+pytorch-19.09-py3 NGC container on NVIDIA DGX-1 with 1x V100 16G GPU.
+
+The command to launch the inference performance benchmark is provided in the
+[Inference performance benchmark](#inference-performance-benchmark) section.
+
+**FP16, pure Python**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 3,346.3  | 19.13 | 19.47 | 19.64 | 20.67 |
+| 2  | 64 | 640 | 6,486.5  | 19.74 | 20.04 | 20.20 | 21.52 |
+| 4  | 64 | 640 | 13,007.2 | 19.68 | 19.93 | 20.12 | 21.45 |
+| 8  | 64 | 640 | 24,783.1 | 20.65 | 20.95 | 21.22 | 22.78 |
+| 16 | 64 | 640 | 42,777.7 | 23.93 | 24.09 | 24.41 | 25.82 |
+| 32 | 64 | 640 | 52,961.6 | 38.64 | 38.86 | 39.77 | 41.25 |
+
+**FP16, TorchScript**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 5,117.4  | 12.52 | 12.76 | 12.87 | 13.39 |
+| 2  | 64 | 640 | 9,703.5  | 13.20 | 13.39 | 13.52 | 14.78 |
+| 4  | 64 | 640 | 18,259.5 | 14.02 | 14.22 | 14.38 | 15.72 |
+| 8  | 64 | 640 | 35,758.7 | 14.32 | 14.52 | 14.66 | 16.04 |
+| 16 | 64 | 640 | 50,159.9 | 20.41 | 20.44 | 20.69 | 21.98 |
+| 32 | 64 | 640 | 57,223.4 | 35.76 | 35.95 | 36.25 | 37.57 |
+
+**FP32, pure Python**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 3,216.4  | 19.91  | 20.31  | 20.45  | 21.61  |
+| 2  | 64 | 640 | 6,314.0  | 20.28  | 20.63  | 20.80  | 21.95  |
+| 4  | 64 | 640 | 10,991.5 | 23.28  | 23.56  | 23.74  | 25.20  |
+| 8  | 64 | 640 | 16,398.5 | 31.20  | 31.57  | 31.83  | 33.38  |
+| 16 | 64 | 640 | 18,845.5 | 54.29  | 54.71  | 54.89  | 56.05  |
+| 32 | 64 | 640 | 19,209.5 | 106.51 | 107.45 | 107.69 | 108.81 |
+
+**FP32, TorchScript**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 4,915.5  | 13.03  | 13.37  | 13.50  | 14.00  |
+| 2  | 64 | 640 | 8,644.5  | 14.81  | 15.10  | 15.19  | 16.39  |
+| 4  | 64 | 640 | 13,480.2 | 18.98  | 19.20  | 19.29  | 20.55  |
+| 8  | 64 | 640 | 17,075.5 | 29.96  | 30.18  | 30.26  | 31.64  |
+| 16 | 64 | 640 | 19,201.7 | 53.29  | 53.74  | 53.95  | 54.78  |
+| 32 | 64 | 640 | 19,724.4 | 103.73 | 104.40 | 104.59 | 105.73 |
+
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
+##### Inference performance: NVIDIA T4
+
+Our results were obtained by running the
+`pytorch/scripts/inference_benchmark.sh` inferencing benchmarking script in the
+pytorch-19.09-py3 NGC container on NVIDIA T4.
+
+The command to launch the inference performance benchmark is provided in the
+[Inference performance benchmark](#inference-performance-benchmark) section.
+
+**FP16, pure Python**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 4,067.0  | 15.75  | 16.26  | 16.43  | 16.74  |
+| 2  | 64 | 640 | 7,559.5  | 16.94  | 17.37  | 17.55  | 17.93  |
+| 4  | 64 | 640 | 13,203.1 | 19.38  | 19.91  | 20.02  | 20.46  |
+| 8  | 64 | 640 | 16,101.8 | 31.78  | 32.45  | 32.53  | 33.00  |
+| 16 | 64 | 640 | 17,375.8 | 58.89  | 59.87  | 60.23  | 60.63  |
+| 32 | 64 | 640 | 17,946.2 | 114.03 | 115.33 | 116.17 | 119.87 |
+
+**FP16, TorchScript**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 5,834.6  | 10.99  | 11.45  | 11.61  | 11.94  |
+| 2  | 64 | 640 | 11,167.8 | 11.47  | 11.87  | 12.08  | 12.80  |
+| 4  | 64 | 640 | 14,890.8 | 17.19  | 17.70  | 17.89  | 18.19  |
+| 8  | 64 | 640 | 16,862.5 | 30.35  | 31.09  | 31.32  | 31.81  |
+| 16 | 64 | 640 | 18,281.2 | 55.98  | 56.82  | 57.00  | 58.52  |
+| 32 | 64 | 640 | 18,912.0 | 108.21 | 109.54 | 110.20 | 113.80 |
+
+**FP32, pure Python**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 3,457.2 | 18.56  | 19.40  | 19.74  | 20.35  |
+| 2  | 64 | 640 | 4,746.1 | 26.98  | 27.84  | 28.12  | 28.56  |
+| 4  | 64 | 640 | 5,687.6 | 44.98  | 45.93  | 46.35  | 47.24  |
+| 8  | 64 | 640 | 6,223.5 | 82.21  | 83.37  | 83.72  | 84.22  |
+| 16 | 64 | 640 | 6,522.6 | 156.87 | 159.63 | 160.43 | 161.13 |
+| 32 | 64 | 640 | 6,608.2 | 309.63 | 313.21 | 314.07 | 315.32 |
+
+**FP32, TorchScript**
+
+|**Batch size**|**Sequence length**|**Memory length**|**Throughput Avg (tok/s)**|**Latency Avg (ms)**|**Latency 90% (ms)**|**Latency 95% (ms)**|**Latency 99% (ms)**|
+|-------------:|------------------:|----------------:|-------------------------:|-------------------:|-------------------:|-------------------:|-------------------:|
+| 1  | 64 | 640 | 3,859.7 | 16.64  | 17.71  | 17.98  | 18.53  |
+| 2  | 64 | 640 | 4,823.6 | 26.55  | 27.41  | 27.70  | 28.05  |
+| 4  | 64 | 640 | 5,790.0 | 44.18  | 45.07  | 45.30  | 45.91  |
+| 8  | 64 | 640 | 6,306.4 | 81.12  | 82.26  | 82.54  | 83.11  |
+| 16 | 64 | 640 | 6,599.5 | 155.04 | 157.54 | 158.15 | 159.88 |
+| 32 | 64 | 640 | 6,707.0 | 305.06 | 307.94 | 308.54 | 309.44 |
+
+To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
+
+## Release notes
+
+### Changelog
+
+* October 2019
+  * Initial release
+    *  Support for FP32 and mixed precision training on NVIDIA DGX-1, NVIDIA
+       DGX-2 and inference on NVIDIA Tesla V100 16G and NVIDIA T4
+
+### Known issues
+There are no known issues with this model.

+ 120 - 0
PyTorch/LanguageModeling/TransformerXL/getdata.sh

@@ -0,0 +1,120 @@
+# BSD 3-Clause License
+# 
+# Copyright (c) 2017, 
+# All rights reserved.
+# 
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# 
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+# 
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+# 
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+# 
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+echo "=== Acquiring datasets ==="
+echo "---"
+
+mkdir -p data
+cd data
+
+if [[ ! -d 'wikitext-2' ]]; then
+    echo "- Downloading WikiText-2 (WT2)"
+    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
+    unzip -q wikitext-2-v1.zip
+    cd wikitext-2
+    mv wiki.train.tokens train.txt
+    mv wiki.valid.tokens valid.txt
+    mv wiki.test.tokens test.txt
+    cd ..
+fi
+
+echo "- Downloading WikiText-103 (WT2)"
+if [[ ! -d 'wikitext-103' ]]; then
+    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
+    unzip -q wikitext-103-v1.zip
+    cd wikitext-103
+    mv wiki.train.tokens train.txt
+    mv wiki.valid.tokens valid.txt
+    mv wiki.test.tokens test.txt
+    cd ..
+fi
+
+echo "- Downloading enwik8 (Character)"
+if [[ ! -d 'enwik8' ]]; then
+    mkdir -p enwik8
+    cd enwik8
+    wget --continue http://mattmahoney.net/dc/enwik8.zip
+    wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
+    python3 prep_enwik8.py
+    cd ..
+fi
+
+echo "- Downloading text8 (Character)"
+if [[ ! -d 'text8' ]]; then
+    mkdir -p text8
+    cd text8
+    wget --continue http://mattmahoney.net/dc/text8.zip
+    python ../../prep_text8.py
+    cd ..
+fi
+
+echo "- Downloading Penn Treebank (PTB)"
+if [[ ! -d 'penn' ]]; then
+    wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+    tar -xzf simple-examples.tgz
+
+    mkdir -p penn
+    cd penn
+    mv ../simple-examples/data/ptb.train.txt train.txt
+    mv ../simple-examples/data/ptb.test.txt test.txt
+    mv ../simple-examples/data/ptb.valid.txt valid.txt
+    cd ..
+
+    echo "- Downloading Penn Treebank (Character)"
+    mkdir -p pennchar
+    cd pennchar
+    mv ../simple-examples/data/ptb.char.train.txt train.txt
+    mv ../simple-examples/data/ptb.char.test.txt test.txt
+    mv ../simple-examples/data/ptb.char.valid.txt valid.txt
+    cd ..
+
+    rm -rf simple-examples/
+fi
+
+echo "- Downloading 1B words"
+
+if [[ ! -d 'one-billion-words' ]]; then
+    mkdir -p one-billion-words
+    cd one-billion-words
+
+    wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
+    tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
+
+    path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
+    cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
+    cat ${path}/news.en.heldout-00000-of-00050 > test.txt
+
+    wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt
+
+    cd ..
+fi
+
+echo "---"
+echo "Happy language modeling :)"

+ 62 - 0
PyTorch/LanguageModeling/TransformerXL/prep_text8.py

@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2017,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import os
+import sys
+import zipfile
+
+from io import open
+
+if os.path.exists('train.txt'):
+    print('Tokenized text8 already exists - skipping processing')
+    sys.exit()
+
+data = zipfile.ZipFile('text8.zip').extractall()
+data = open('text8', 'r', encoding='utf-8').read()
+
+print('Length of text8: {}'.format(len(data)))
+
+num_test_chars = 5000000
+
+train_data = data[: -2 * num_test_chars]
+valid_data = data[-2 * num_test_chars: -num_test_chars]
+test_data = data[-num_test_chars:]
+
+for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]:
+    print('{} will have {} bytes'.format(fn, len(part)))
+    print('- Tokenizing...')
+    # Change space ' ' to underscore '_'
+    part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()])
+    print('- Writing...')
+    f = open(fn, 'w').write(part_str)
+    f = open(fn + '.raw', 'w', encoding='utf-8').write(part)

+ 2 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/.dockerignore

@@ -0,0 +1,2 @@
+LM-TFM*
+internal/result*

+ 30 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/Dockerfile

@@ -0,0 +1,30 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:19.09-py3
+FROM ${FROM_IMAGE_NAME}
+
+ENV LANG C.UTF-8
+ENV LC_ALL C.UTF-8
+
+WORKDIR /tmp/unique_for_apex
+RUN git clone https://github.com/NVIDIA/apex.git && cd apex && git reset --hard 3ae89c754d945e407a6674aa2006d5a0e35d540e
+RUN cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+
+WORKDIR /workspace/transformer-xl/pytorch
+
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+ADD . /workspace/transformer-xl/pytorch

+ 333 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/data_utils.py

@@ -0,0 +1,333 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import logging
+import os
+import re
+
+import numpy as np
+import sacremoses
+import torch
+
+import utils
+from utils.vocabulary import OpenAIVocab
+from utils.vocabulary import Vocab
+
+
+class LMOrderedIterator(object):
+    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
+        """
+            data -- LongTensor -- the LongTensor is strictly ordered
+        """
+        self.bsz = bsz
+        self.bptt = bptt
+        self.ext_len = ext_len if ext_len is not None else 0
+
+        self.device = device
+
+        # Work out how cleanly we can divide the dataset into bsz parts.
+        self.n_step = data.size(0) // bsz
+
+        # Trim off any extra elements that wouldn't cleanly fit (remainders).
+        data = data.narrow(0, 0, self.n_step * bsz)
+
+        # Evenly divide the data across the bsz batches.
+        self.data = data.view(bsz, -1).t().contiguous()
+
+        # Partition data for DistributedDataParallel
+        world_size = utils.distributed.get_world_size()
+        rank = utils.distributed.get_rank()
+        self.data = self.data.chunk(world_size, dim=1)[rank].to(device)
+
+        # Number of mini-batches
+        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
+
+    def roll(self):
+        for i in range(self.data.size(1)):
+            row = self.data[:, i]
+            shift = torch.randint(0, self.data.size(0), (1,))
+            row = torch.cat((row[shift:], row[:shift]))
+            self.data[:, i] = row
+
+    def get_batch(self, i, bptt=None):
+        if bptt is None:
+            bptt = self.bptt
+
+        seq_len = min(bptt, self.data.size(0) - 1 - i)
+
+        end_idx = i + seq_len
+        beg_idx = max(0, i - self.ext_len)
+
+        data = self.data[beg_idx:end_idx]
+        target = self.data[i+1:i+1+seq_len]
+
+        return data, target, seq_len
+
+    def get_fixlen_iter(self, start=0):
+        for i in range(start, self.data.size(0) - 1, self.bptt):
+            yield self.get_batch(i)
+
+    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
+        max_len = self.bptt + max_deviation * std
+        i = start
+        while True:
+            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
+            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
+            data, target, seq_len = self.get_batch(i, bptt)
+            i += seq_len
+            yield data, target, seq_len
+            if i >= self.data.size(0) - 2:
+                break
+
+    def __iter__(self):
+        return self.get_fixlen_iter()
+
+
+class LMShuffledIterator(object):
+    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
+        """
+            data -- list[LongTensor] -- there is no order among the LongTensors
+        """
+        self.data = data
+
+        self.bsz = bsz
+        self.bptt = bptt
+        self.ext_len = ext_len if ext_len is not None else 0
+
+        self.device = device
+        self.shuffle = shuffle
+
+    def get_sent_stream(self):
+        # index iterator
+        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
+            else np.array(range(len(self.data)))
+
+        # sentence iterator
+        for idx in epoch_indices:
+            yield self.data[idx]
+
+    def stream_iterator(self, sent_stream):
+        # streams for each data in the batch
+        streams = [None] * self.bsz
+
+        data = torch.LongTensor(self.bptt, self.bsz)
+        target = torch.LongTensor(self.bptt, self.bsz)
+
+        n_retain = 0
+
+        while True:
+            # data   : [n_retain+bptt x bsz]
+            # target : [bptt x bsz]
+            data[n_retain:].fill_(-1)
+            target.fill_(-1)
+
+            valid_batch = True
+
+            for i in range(self.bsz):
+                n_filled = 0
+                try:
+                    while n_filled < self.bptt:
+                        if streams[i] is None or len(streams[i]) <= 1:
+                            streams[i] = next(sent_stream)
+                        # number of new tokens to fill in
+                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
+                        # first n_retain tokens are retained from last batch
+                        data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
+                            streams[i][:n_new]
+                        target[n_filled:n_filled+n_new, i] = \
+                            streams[i][1:n_new+1]
+                        streams[i] = streams[i][n_new:]
+                        n_filled += n_new
+                except StopIteration:
+                    valid_batch = False
+                    break
+
+            if not valid_batch:
+                return
+
+            data = data.to(self.device)
+            target = target.to(self.device)
+
+            yield data, target, self.bptt
+
+            n_retain = min(data.size(0), self.ext_len)
+            if n_retain > 0:
+                data[:n_retain] = data[-n_retain:]
+            data.resize_(n_retain + self.bptt, data.size(1))
+
+    def __iter__(self):
+        # sent_stream is an iterator
+        sent_stream = self.get_sent_stream()
+
+        for batch in self.stream_iterator(sent_stream):
+            yield batch
+
+
+class LMMultiFileIterator(LMShuffledIterator):
+    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
+                 shuffle=False):
+
+        self.paths = paths
+        self.vocab = vocab
+
+        self.bsz = bsz
+        self.bptt = bptt
+        self.ext_len = ext_len if ext_len is not None else 0
+
+        self.device = device
+        self.shuffle = shuffle
+
+    def get_sent_stream(self, path):
+        sents = self.vocab.encode_file(path, add_double_eos=True)
+        if self.shuffle:
+            np.random.shuffle(sents)
+        sent_stream = iter(sents)
+
+        return sent_stream
+
+    def __iter__(self):
+        if self.shuffle:
+            np.random.shuffle(self.paths)
+
+        for path in self.paths:
+            # sent_stream is an iterator
+            sent_stream = self.get_sent_stream(path)
+            for batch in self.stream_iterator(sent_stream):
+                yield batch
+
+
+class Corpus(object):
+    def __init__(self, path, dataset, vocab, *args, **kwargs):
+        self.dataset = dataset
+        if vocab == 'word':
+            self.vocab = Vocab(*args, **kwargs)
+        elif vocab == 'bpe':
+            self.vocab = OpenAIVocab()
+        else:
+            raise RuntimeError('Unsupported vocab')
+
+        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
+            self.vocab.count_file(os.path.join(path, 'train.txt'))
+            self.vocab.count_file(os.path.join(path, 'valid.txt'))
+            self.vocab.count_file(os.path.join(path, 'test.txt'))
+        elif self.dataset == 'wt103':
+            self.vocab.count_file(os.path.join(path, 'train.txt'))
+        elif self.dataset == 'lm1b':
+            train_path_pattern = os.path.join(
+                path, '1-billion-word-language-modeling-benchmark-r13output',
+                'training-monolingual.tokenized.shuffled', 'news.en-*')
+            train_paths = glob.glob(train_path_pattern)
+            # the vocab will load from file when build_vocab() is called
+
+        self.vocab.build_vocab()
+
+        if self.dataset in ['ptb', 'wt2', 'wt103']:
+            self.train = self.vocab.encode_file(
+                os.path.join(path, 'train.txt'), ordered=True)
+            self.valid = self.vocab.encode_file(
+                os.path.join(path, 'valid.txt'), ordered=True)
+            self.test = self.vocab.encode_file(
+                os.path.join(path, 'test.txt'), ordered=True)
+        elif self.dataset in ['enwik8', 'text8']:
+            self.train = self.vocab.encode_file(
+                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
+            self.valid = self.vocab.encode_file(
+                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
+            self.test = self.vocab.encode_file(
+                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
+        elif self.dataset == 'lm1b':
+            self.train = train_paths
+            self.valid = self.vocab.encode_file(
+                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
+            self.test = self.vocab.encode_file(
+                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
+
+    def get_iterator(self, split, *args, **kwargs):
+        if split == 'train':
+            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
+                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
+            elif self.dataset == 'lm1b':
+                kwargs['shuffle'] = True
+                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
+        elif split in ['valid', 'test']:
+            data = self.valid if split == 'valid' else self.test
+            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
+                data_iter = LMOrderedIterator(data, *args, **kwargs)
+            elif self.dataset == 'lm1b':
+                data_iter = LMShuffledIterator(data, *args, **kwargs)
+
+        return data_iter
+
+
+def get_lm_corpus(datadir, dataset, vocab):
+    if vocab == 'word':
+        fn = os.path.join(datadir, 'cache.pt')
+    elif vocab == 'bpe':
+        fn = os.path.join(datadir, 'cache.pt.bpe')
+    else:
+        raise RuntimeError('Unsupported vocab')
+
+    if os.path.exists(fn):
+        logging.info('Loading cached dataset...')
+        corpus = torch.load(fn)
+    else:
+        logging.info('Producing dataset {}...'.format(dataset))
+        kwargs = {}
+        if dataset in ['wt103', 'wt2']:
+            kwargs['special'] = ['<eos>']
+            kwargs['lower_case'] = False
+        elif dataset == 'ptb':
+            kwargs['special'] = ['<eos>']
+            kwargs['lower_case'] = True
+        elif dataset == 'lm1b':
+            kwargs['special'] = []
+            kwargs['lower_case'] = False
+            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
+        elif dataset in ['enwik8', 'text8']:
+            pass
+
+        corpus = Corpus(datadir, dataset, vocab, **kwargs)
+        with utils.distributed.sync_workers() as rank:
+            if rank == 0:
+                torch.save(corpus, fn)
+
+    return corpus
+
+
+def tokenize_raw(text, lang='en'):
+    mt = sacremoses.MosesTokenizer(lang)
+    text = mt.tokenize(text, return_str=True)
+    text = re.sub(r'&quot;', '"', text)
+    text = re.sub(r'&apos;', "'", text)
+    text = re.sub(r'(\d)\.(\d)', r'\1 @.@ \2', text)
+    text = re.sub(r'(\d),(\d)', r'\1 @,@ \2', text)
+    text = re.sub(r'(\w)-(\w)', r'\1 @-@ \2', text)
+    return text
+
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser(description='unit test')
+    parser.add_argument('--datadir', type=str, default='../data/text8',
+                        help='location of the data corpus')
+    parser.add_argument('--dataset', type=str, default='text8',
+                        choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
+                        help='dataset name')
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+
+    corpus = get_lm_corpus(args.datadir, args.dataset, vocab='word')
+    logging.info('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))

+ 320 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/eval.py

@@ -0,0 +1,320 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import pickle
+import sys
+import time
+
+import numpy as np
+import torch
+
+import data_utils
+import utils
+from data_utils import get_lm_corpus
+from data_utils import tokenize_raw
+from utils.exp_utils import AverageMeter
+from utils.exp_utils import benchmark
+from utils.exp_utils import create_exp_dir
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='PyTorch Transformer Language Model',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument('--work_dir', default='LM-TFM', type=str,
+                        help='experiment directory')
+    parser.add_argument('--debug', action='store_true',
+                        help='run in debug mode (do not create exp dir)')
+    parser.add_argument('--data', type=str, default='../data/wikitext-103',
+                        help='location of the data corpus')
+    parser.add_argument('--manual', type=str, default=None, nargs='+',
+                        help='run model on raw input data')
+    parser.add_argument('--dataset', type=str, default='wt103',
+                        choices=['wt103', 'lm1b', 'enwik8', 'text8'],
+                        help='dataset name')
+    parser.add_argument('--split', type=str, default='all',
+                        choices=['all', 'valid', 'test'],
+                        help='which split to evaluate')
+    parser.add_argument('--type', type=str, default='pytorch',
+                        choices=['pytorch', 'torchscript', 'onnx'],
+                        help='type of runtime to use')
+    parser.add_argument('--batch_size', type=int, default=16,
+                        help='batch size')
+    parser.add_argument('--tgt_len', type=int, default=64,
+                        help='number of tokens to predict')
+    parser.add_argument('--ext_len', type=int, default=0,
+                        help='length of the extended context')
+    parser.add_argument('--mem_len', type=int, default=640,
+                        help='length of the retained previous heads')
+    parser.add_argument('--clamp_len', type=int, default=-1,
+                        help='max positional embedding index')
+    parser.add_argument('--cuda', action='store_true',
+                        help='use CUDA')
+    parser.add_argument('--model', type=str, default='',
+                        help='path to the checkpoint')
+    parser.add_argument('--fp16', action='store_true',
+                        help='Run training in fp16/mixed precision')
+    parser.add_argument('--log_all_ranks', action='store_true',
+                        help='Enable logging for all distributed ranks')
+    parser.add_argument('--same_length', action='store_true',
+                        help='set same length attention with masking')
+    parser.add_argument('--target_perplexity', type=float, default=None,
+                        help='target perplexity')
+    parser.add_argument('--target_throughput', type=float, default=None,
+                        help='target throughput')
+    parser.add_argument('--save_data', action='store_true',
+                        help='save latency and throughput data to a file')
+    parser.add_argument('--repeat', type=int, default=1,
+                        help='loop over the dataset REPEAT times')
+    parser.add_argument('--max_size', type=int, default=None,
+                        help='run inference on up to MAX_SIZE batches')
+    parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
+                        help='percentiles for latency confidence intervals')
+    parser.add_argument('--save_torchscript', default=None, type=str,
+                        help='save torchscript model to a file')
+    parser.add_argument('--load_torchscript', default=None, type=str,
+                        help='load torchscript model from a file')
+    parser.add_argument('--local_rank', default=0, type=int,
+                        help='Used for multi-process training. ' +
+                        'Can either be manually set ' +
+                        'or automatically set by using \'python -m multiproc\'.')
+    args = parser.parse_args()
+    assert args.ext_len >= 0, 'extended context length must be non-negative'
+    return args
+
+
+def load_checkpoint(path):
+    dst = f'cuda:{torch.cuda.current_device()}'
+    logging.info(f'Loading checkpoint from {path}')
+    checkpoint = torch.load(path, map_location=dst)
+    return checkpoint
+
+
+def format_log(loss, split, args):
+    if args.dataset in ['enwik8', 'text8']:
+        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
+            split, loss, loss / math.log(2))
+    else:
+        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
+            split, loss, math.exp(loss))
+    return log_str
+
+
+def evaluate(eval_iter, model, meters, max_size=None, repeat=1):
+    total_len, total_loss = 0, 0.
+    torch.cuda.synchronize()
+    start_time = time.time()
+    with torch.no_grad():
+        mems = None
+        for _ in range(repeat):
+            for idx, (data, target, seq_len) in enumerate(eval_iter):
+                if max_size and idx >= max_size:
+                    break
+                torch.cuda.synchronize()
+                start_iter = time.time()
+                ret = model(data, target, mems)
+                torch.cuda.synchronize()
+                elapsed = time.time() - start_iter
+                loss, mems = ret[0], ret[1:]
+                loss = loss.mean()
+                total_loss += seq_len * loss.item()
+                total_len += seq_len
+                meters['eval_latency'].update(elapsed)
+                target_tokens = target.numel()
+                throughput = target_tokens / elapsed
+                throughput = utils.distributed.all_reduce_item(throughput, op='sum')
+                meters['eval_throughput'].update(throughput)
+
+    utils.distributed.barrier()
+    torch.cuda.synchronize()
+    total_time = time.time() - start_time
+    logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
+            total_time, 1000 * total_time / (idx+1)))
+
+    avg_loss = total_loss / total_len
+    avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
+    return avg_loss
+
+
+def compile_model(model, device, args):
+    inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
+    tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
+    start = time.time()
+    with torch.no_grad():
+        mems = None
+        for _ in range(2):
+            ret = model(inp, tgt, mems)
+            _, mems = ret[0], ret[1:]
+    torch.cuda.synchronize()
+    stop = time.time()
+    logging.info(f'Building the model took {stop - start:.2f} seconds')
+
+
+def main():
+    args = parse_args()
+
+    if args.type == 'pytorch':
+        from mem_transformer import MemTransformerLM
+    else:
+        from inference.mem_transformer_base_jit import MemTransformerLM
+
+    torch.cuda.set_device(args.local_rank)
+    device = torch.device('cuda' if args.cuda else 'cpu')
+    utils.distributed.init_distributed(args.cuda)
+
+    with utils.distributed.sync_workers() as rank:
+        if rank == 0:
+            create_exp_dir(args.work_dir, debug=args.debug)
+
+    # Setup logging
+    if args.log_all_ranks:
+        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
+    else:
+        log_file = f'log.log'
+
+    log_file = os.path.join(args.work_dir, log_file)
+    if args.debug:
+        log_file = os.devnull
+
+    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
+                                  filename=log_file,
+                                  filemode='a',
+                                  )
+    logging.info(args)
+
+    if args.model:
+        model_path = args.model
+    elif args.work_dir:
+        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
+    else:
+        raise RuntimeError('Specify path to checkpoint using --model or --work_dir')
+
+    checkpoint = load_checkpoint(model_path)
+
+    if args.manual:
+        args.batch_size = 1
+        vocab = checkpoint['vocab']
+
+        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
+            vocab.unk_idx = vocab.sym2idx['<unk>']
+
+        text = " ".join(args.manual)
+        tokenized = tokenize_raw(text)
+        symbols = vocab.tokenize(tokenized, add_eos=True)
+        tensor = vocab.convert_to_tensor(symbols)
+
+        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
+                                            bptt=args.tgt_len, device=device,
+                                            ext_len=args.ext_len)
+    else:
+        # Load dataset
+        corpus = get_lm_corpus(args.data, args.dataset, checkpoint['args'].vocab)
+
+        if args.split == 'valid':
+            iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
+                                       device=device, ext_len=args.ext_len)
+        elif args.split == 'test':
+            iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
+                                       device=device, ext_len=args.ext_len)
+        else:
+            raise RuntimeError('Unknown split')
+
+    if args.fp16:
+        dtype = torch.float16
+        math_str = 'fp16'
+    else:
+        dtype = torch.float32
+        math_str = 'fp32'
+
+    if args.load_torchscript:
+        model = torch.jit.load(args.load_torchscript)
+
+    else:
+        checkpoint['model_config']['tgt_len'] = args.tgt_len
+        checkpoint['model_config']['ext_len'] = args.ext_len
+        checkpoint['model_config']['mem_len'] = args.mem_len
+        checkpoint['model_config']['clamp_len'] = args.clamp_len
+        checkpoint['model_config']['same_length'] = args.same_length
+        checkpoint['model_config']['dtype'] = dtype
+
+        model = MemTransformerLM(**checkpoint['model_config'])
+        model.load_state_dict(checkpoint['model_state'])
+
+    model = model.eval()
+    model = model.to(device)
+
+    model = model.float()
+    if args.fp16:
+        model = model.half()
+
+    if args.type != 'pytorch':
+        compile_model(model, device, args)
+
+    if args.type == 'torchscript' and args.save_torchscript:
+        torch.jit.save(model, args.save_torchscript)
+
+    logging.info(f'Evaluating with: math {math_str} type {args.type} '
+                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
+                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
+                 f'clamp_len {args.clamp_len}')
+
+    meters = {}
+    warmup = args.mem_len // args.tgt_len + 1
+    meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
+    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
+
+    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
+    perplexity = math.exp(loss)
+    log_str = format_log(loss, args.split, args)
+
+    logging.info('=' * 100)
+    logging.info(log_str)
+    logging.info('=' * 100)
+
+    if args.save_data:
+        latency_data = np.array(meters['eval_latency'].vals)
+        throughput_data = np.array(meters['eval_throughput'].vals)
+        precision = 'fp16' if args.fp16 else 'fp32'
+        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
+        data_path = os.path.join(args.work_dir, data_fname)
+        data = {
+            'args': args,
+            'throughput': throughput_data,
+            'latency': latency_data,
+            }
+        with open(data_path, 'wb') as f:
+            pickle.dump(data, f)
+        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
+        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
+        for p in args.percentiles:
+            logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')
+
+        logging.info('=' * 100)
+
+    passed = benchmark(target_perplexity=args.target_perplexity,
+                       test_perplexity=perplexity,
+                       target_throughput=args.target_throughput,
+                       test_throughput=meters['eval_throughput'].avg,
+                       )
+    if not passed:
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()

BIN
PyTorch/LanguageModeling/TransformerXL/pytorch/img/model.png


BIN
PyTorch/LanguageModeling/TransformerXL/pytorch/img/training_loss.png


+ 469 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/inference/mem_transformer_base_jit.py

@@ -0,0 +1,469 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from inference.proj_adaptive_softmax_jit import ProjectedAdaptiveLogSoftmax
+
+
+class PositionalEmbedding(torch.jit.ScriptModule):
+    def __init__(self, demb):
+        super(PositionalEmbedding, self).__init__()
+
+        self.demb = demb
+
+        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
+        self.register_buffer('inv_freq', inv_freq)
+
+    @torch.jit.script_method
+    def forward(self, pos_seq, bsz: Optional[int] = None):
+        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
+        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
+
+        if bsz is not None:
+            return pos_emb[:, None, :].expand(-1, bsz, -1)
+        else:
+            return pos_emb[:, None, :]
+
+
+class PositionwiseFF(torch.jit.ScriptModule):
+    __constants__ = ['pre_lnorm']
+
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(PositionwiseFF, self).__init__()
+
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.CoreNet = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+            nn.Dropout(dropout),
+            nn.Linear(d_inner, d_model),
+            nn.Dropout(dropout),
+        )
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.pre_lnorm = pre_lnorm
+
+    @torch.jit.script_method
+    def forward(self, inp):
+        if self.pre_lnorm:
+            # layer normalization + positionwise feed-forward
+            core_out = self.CoreNet(self.layer_norm(inp))
+
+            # residual connection
+            output = core_out + inp
+        else:
+            # positionwise feed-forward
+            core_out = self.CoreNet(inp)
+
+            # residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
+
+
+class RelMultiHeadAttn(torch.jit.ScriptModule):
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+        super(RelMultiHeadAttn, self).__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_head = d_head
+        self.dropout = dropout
+
+        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.dropatt = nn.Dropout(dropatt)
+        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.scale = 1 / (d_head ** 0.5)
+
+        self.pre_lnorm = pre_lnorm
+
+    def _parallelogram_mask(self, h, w, left=False):
+        mask = torch.ones((h, w)).byte()
+        m = min(h, w)
+        mask[:m, :m] = torch.triu(mask[:m, :m])
+        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
+
+        if left:
+            return mask.bool()
+        else:
+            return mask.flip(0).bool()
+
+    def _shift(self, x, qlen, klen, mask, left=False):
+        if qlen > 1:
+            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
+                                   device=x.device, dtype=x.dtype)
+        else:
+            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
+
+        if left:
+            mask = mask.flip(1)
+            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
+        else:
+            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
+
+        x = x_padded.masked_select(mask[:, :, None, None]) \
+                    .view(qlen, klen, x.size(2), x.size(3))
+
+        return x
+
+    @torch.jit.script_method
+    def _rel_shift(self, x, zero_triu: bool = False):
+        zero_pad = torch.zeros((x.size(0), x.size(1), 1, x.size(3)),
+                               device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=2)
+
+        x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1), x.size(3))
+
+        x = x_padded[:, 1:].view_as(x)
+
+        if zero_triu:
+            ones = torch.ones((x.size(0), x.size(1)))
+            x = x * torch.tril(ones, x.size(1) - x.size(0))[None, :, :, None]
+
+        return x
+
+    @torch.jit.script_method
+    def forward(self, w, r, attn_mask, mems=None):
+        raise NotImplementedError
+
+
+class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
+    __constants__ = ['pre_lnorm', 'n_head', 'd_head', 'scale']
+
+    def __init__(self, *args, **kwargs):
+        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
+
+        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
+
+    @torch.jit.script_method
+    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask,
+                mems: Optional[torch.Tensor] = None):
+        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
+
+        if mems is not None:
+            cat = torch.cat([mems, w], 0)
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(cat))
+            else:
+                w_heads = self.qkv_net(cat)
+            r_head_k = self.r_net(r)
+
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+            w_head_q = w_head_q[-qlen:]
+        else:
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(w))
+            else:
+                w_heads = self.qkv_net(w)
+            r_head_k = self.r_net(r)
+
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+
+        klen = w_head_k.size(0)
+
+        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
+        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
+        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
+
+        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # qlen x n_head x d_head
+
+        # compute attention score
+        rw_head_q = w_head_q + r_w_bias                                # qlen x bsz x n_head x d_head
+
+        # AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k))    # bsz x qlen x klen x n_head
+        rw_head_q = rw_head_q.view(qlen, bsz * self.n_head, self.d_head).permute(1, 0, 2)
+        w_head_k = w_head_k.reshape(klen, bsz * self.n_head, self.d_head).permute(1, 2, 0)
+        AC = torch.bmm(rw_head_q, w_head_k).view(bsz, self.n_head, qlen, klen).permute(0, 2, 3 ,1)
+
+        rr_head_q = w_head_q + r_r_bias
+
+        # BD = torch.einsum('ibnd,jnd->bijn', (rr_head_q, r_head_k))     # bsz x qlen x klen x n_head
+        rr_head_q = rr_head_q.permute(2, 1, 0, 3).reshape(self.n_head, bsz * qlen, self.d_head)
+        r_head_k = r_head_k.permute(1, 2, 0).view(self.n_head, self.d_head, klen)
+        BD = torch.bmm(rr_head_q, r_head_k).permute(1, 2, 0).view(bsz, qlen, klen, self.n_head)
+
+        BD = self._rel_shift(BD, False)
+
+        # [bsz x qlen x klen x n_head]
+        attn_score = AC + BD
+        attn_score.mul_(self.scale)
+
+        # compute attention probability
+        if attn_mask is not None and attn_mask.any():
+            if attn_mask.dim() == 2:
+                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
+            elif attn_mask.dim() == 3:
+                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
+
+        # [bsz x qlen x klen x n_head]
+        attn_prob = F.softmax(attn_score, dim=2)
+        attn_prob = self.dropatt(attn_prob)
+
+        # compute attention vector
+        # attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
+        attn_prob = attn_prob.permute(0, 3, 1 ,2).reshape(bsz * self.n_head, qlen, klen)
+        w_head_v = w_head_v.permute(1, 2, 0, 3).reshape(bsz * self.n_head, klen, self.d_head)
+        attn_vec = torch.bmm(attn_prob, w_head_v).permute(1, 0, 2).view(qlen, bsz, self.n_head, self.d_head)
+
+        # [qlen x bsz x n_head x d_head]
+        attn_vec = attn_vec.reshape(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+
+        # linear projection
+        attn_out = self.o_net(attn_vec)
+        attn_out = self.drop(attn_out)
+
+        if self.pre_lnorm:
+            # residual connection
+            output = w + attn_out
+        else:
+            # residual connection + layer normalization
+            output = self.layer_norm(w + attn_out)
+            output = output.type_as(w)
+
+        return output
+
+
+class RelPartialLearnableDecoderLayer(torch.jit.ScriptModule):
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+                 **kwargs):
+        super(RelPartialLearnableDecoderLayer, self).__init__()
+
+        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
+                                                         d_head, dropout,
+                                                         **kwargs)
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                     pre_lnorm=kwargs.get('pre_lnorm'))
+
+    @torch.jit.script_method
+    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask,
+                mems: Optional[torch.Tensor] = None
+                ):
+
+        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
+                               attn_mask=dec_attn_mask,
+                               mems=mems)
+        output = self.pos_ff(output)
+
+        return output
+
+
+class AdaptiveEmbedding(torch.jit.ScriptModule):
+    __constants__ = ['div_val', 'd_proj', 'd_embed', 'emb_scale']
+
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+                 sample_softmax=False):
+        super(AdaptiveEmbedding, self).__init__()
+
+        self.n_token = n_token
+        self.d_embed = d_embed
+
+        self.cutoffs = cutoffs + [n_token]
+        self.div_val = div_val
+        self.d_proj = d_proj
+
+        self.emb_scale = d_proj ** 0.5
+
+        self.cutoff_ends = [0] + self.cutoffs
+
+        self.emb_layers = nn.ModuleList()
+        self.emb_projs = nn.ParameterList()
+        if div_val != 1:
+            raise RuntimeError('TorchScripted model supports only div_val == 1')
+        if d_proj != d_embed:
+            raise RuntimeError('TorchScripted model supports only d_proj == d_embed')
+        self.emb_layers.append(nn.Embedding(n_token, d_embed))
+
+    @torch.jit.script_method
+    def forward(self, x):
+        for emb_layer in self.emb_layers:
+            x = emb_layer(x)
+
+        x.mul_(self.emb_scale)
+
+        return x
+
+
+class MemTransformerLM(torch.jit.ScriptModule):
+    __constants__ = ['same_length', 'mem_len', 'clamp_len', 'ext_len',
+                     'n_layer', 'dtype']
+
+    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
+                 dropout, dropatt, dtype, tie_weight=True, d_embed=None,
+                 div_val=1, tie_projs=[False], pre_lnorm=False,
+                 tgt_len=None, ext_len=None, mem_len=None,
+                 cutoffs=[], adapt_inp=False,
+                 same_length=False, attn_type=0, clamp_len=-1,
+                 sample_softmax=-1):
+        super(MemTransformerLM, self).__init__()
+        self.n_token = n_token
+
+        d_embed = d_model if d_embed is None else d_embed
+        self.d_embed = d_embed
+        self.d_model = d_model
+        self.n_head = n_head
+        self.d_head = d_head
+
+        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
+                                          div_val=div_val)
+
+        self.drop = nn.Dropout(dropout)
+
+        self.n_layer = n_layer
+
+        self.tgt_len = tgt_len
+        self.mem_len = mem_len
+        self.ext_len = ext_len
+        self.max_klen = tgt_len + ext_len + mem_len
+
+        self.dtype = dtype
+
+        self.attn_type = attn_type
+        if attn_type != 0:
+            raise RuntimeError('TorchScripted supports only attn_type == 0')
+
+        self.layers = nn.ModuleList()
+        # the default attention
+        for i in range(n_layer):
+            self.layers.append(
+                RelPartialLearnableDecoderLayer(
+                    n_head, d_model, d_head, d_inner, dropout,
+                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
+                    dropatt=dropatt, pre_lnorm=pre_lnorm)
+            )
+
+        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
+                                                cutoffs, div_val=div_val)
+
+        self.same_length = same_length
+        self.clamp_len = clamp_len
+
+        self._create_params()
+
+    def _create_params(self):
+        self.pos_emb = PositionalEmbedding(self.d_model)
+        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+
+    @torch.jit.script_method
+    def init_mems(self):
+        mems = []
+        for i in range(self.n_layer+1):
+            empty = torch.empty(0, dtype=self.dtype, device=torch.device('cuda'))
+            mems.append(empty)
+
+        return mems
+
+    def _update_mems(self, hids: List[torch.Tensor], mems: List[torch.Tensor],
+                     qlen: int, mlen: int):
+        assert len(hids) == len(mems), 'len(hids) != len(mems)'
+
+        # There are `mlen + qlen` steps that can be cached into mems
+        # For the next step, the last `ext_len` of the `qlen` tokens
+        # will be used as the extended context. Hence, we only cache
+        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
+        # to `mlen + qlen - self.ext_len`.
+        new_mems = []
+        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+        beg_idx = max(0, end_idx - self.mem_len)
+        for i in range(len(hids)):
+
+            cat = torch.cat([mems[i], hids[i]], dim=0)
+            new_mems.append(cat[beg_idx:end_idx].detach())
+
+        return new_mems
+
+    @torch.jit.script_method
+    def _forward(self, dec_inp, mems: List[torch.Tensor]):
+        qlen, bsz = dec_inp.size()
+
+        word_emb = self.word_emb(dec_inp)
+
+        mlen = mems[0].size(0)
+        klen = mlen + qlen
+        if self.same_length:
+            # all_ones = word_emb.new_ones(qlen, klen)
+            all_ones = torch.ones((qlen, klen), device=torch.device('cuda'), dtype=torch.float32)
+            mask_len = klen - self.mem_len
+            if mask_len > 0:
+                mask_shift_len = qlen - mask_len
+            else:
+                mask_shift_len = qlen
+            dec_attn_mask = (torch.triu(all_ones, 1+mlen) + torch.tril(all_ones, -mask_shift_len)).to(torch.bool)
+        else:
+            all_ones = torch.ones((qlen, klen), device=torch.device('cuda'), dtype=torch.float32)
+            dec_attn_mask = torch.triu(
+                all_ones, diagonal=1+mlen).to(torch.bool)
+
+        hids = []
+        pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
+                               dtype=word_emb.dtype)
+        if self.clamp_len > 0:
+            pos_seq.clamp_(max=self.clamp_len)
+        pos_emb = self.pos_emb(pos_seq)
+
+        core_out = self.drop(word_emb)
+        pos_emb = self.drop(pos_emb)
+
+        hids.append(core_out)
+        i = 0
+        for layer in self.layers:
+            mems_i = None if mems is None else mems[i]
+            core_out = layer(core_out, pos_emb, self.r_w_bias,
+                             self.r_r_bias, dec_attn_mask=dec_attn_mask,
+                             mems=mems_i)
+            hids.append(core_out)
+            i += 1
+
+        core_out = self.drop(core_out)
+
+        new_mems = self._update_mems(hids, mems, qlen, mlen)
+
+        return core_out, new_mems
+
+    @torch.jit.script_method
+    def forward(self, data, target, mems: Optional[List[torch.Tensor]]):
+        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
+        # So, have to initialize size(0) mems inside the model forward.
+        # Moreover, have to return new_mems to allow nn.DataParallel to piece
+        # them together.
+        if mems is None:
+            mems = self.init_mems()
+
+        tgt_len = target.size(0)
+        hidden, new_mems = self._forward(data, mems=mems)
+
+        pred_hid = hidden[-tgt_len:]
+        loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
+        loss = loss.view(tgt_len, -1)
+
+        if new_mems is None:
+            return [loss]
+        else:
+            return [loss] + new_mems

+ 141 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/inference/proj_adaptive_softmax_jit.py

@@ -0,0 +1,141 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ProjectedAdaptiveLogSoftmax(torch.jit.ScriptModule):
+    __constants__ = ['n_clusters', 'cutoffs', 'cutoff_ends', 'keep_order']
+
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+                 keep_order=False):
+        super().__init__()
+
+        self.n_token = n_token
+        self.d_embed = d_embed
+        self.d_proj = d_proj
+
+        self.cutoffs = cutoffs + [n_token]
+        self.cutoff_ends = type(self.cutoffs)([0]) + self.cutoffs
+        self.div_val = div_val
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        self.head_size = self.shortlist_size + self.n_clusters
+
+        if self.n_clusters > 0:
+            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
+            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
+
+        self.out_layers = nn.ModuleList()
+
+        if d_proj != d_embed:
+            raise RuntimeError('TorchScripted module requires d_proj == d_embed')
+        if div_val != 1:
+            raise RuntimeError('TorchScripted module requires div_val == 1')
+
+        self.out_layers.append(nn.Linear(d_embed, n_token))
+
+        self.keep_order = keep_order
+
+    @torch.jit.script_method
+    def _compute_logit(self, hidden, weight, bias, proj: Optional[torch.Tensor]):
+        if proj is not None:
+            raise RuntimeError('TorchScripted module requires proj == None')
+        logit = F.linear(hidden, weight, bias=bias)
+        return logit
+
+    @torch.jit.script_method
+    def forward(self, hidden, target, keep_order: bool = False):
+        '''
+            hidden :: [len*bsz x d_proj]
+            target :: [len*bsz]
+        '''
+
+        if hidden.size(0) != target.size(0):
+            raise RuntimeError('Input and target should have the same size '
+                               'in the batch dimension.')
+
+        if self.n_clusters == 0:
+            for out_layer in self.out_layers:
+                hidden = self._compute_logit(hidden, out_layer.weight,
+                                             out_layer.bias, None)
+            nll = -F.log_softmax(hidden, dim=-1) \
+                    .gather(1, target.unsqueeze(1)).squeeze(1)
+        else:
+            # construct weights and biases
+            weights, biases = [], []
+            for i in range(len(self.cutoffs)):
+                for out_layer in self.out_layers:
+                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+                    weight_i = out_layer.weight[l_idx:r_idx]
+                    bias_i = out_layer.bias[l_idx:r_idx]
+
+                    if i == 0:
+                        weight_i = torch.cat(
+                            [weight_i, self.cluster_weight], dim=0)
+                        bias_i = torch.cat(
+                            [bias_i, self.cluster_bias], dim=0)
+
+                    weights.append(weight_i)
+                    biases.append(bias_i)
+
+            head_weight, head_bias, head_proj = weights[0], biases[0], None
+
+            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
+            head_logprob = F.log_softmax(head_logit, dim=1)
+
+            nll = torch.zeros_like(target, layout=torch.strided,
+                                   dtype=hidden.dtype, device=hidden.device)
+
+            offset = 0
+            cutoff_values = [0] + self.cutoffs
+            for i in range(len(cutoff_values) - 1):
+                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
+
+                mask_i = (target >= l_idx) & (target < r_idx)
+                indices_i = mask_i.nonzero().squeeze()
+
+                if indices_i.numel() == 0:
+                    continue
+
+                target_i = target.index_select(0, indices_i) - l_idx
+                head_logprob_i = head_logprob.index_select(0, indices_i)
+
+                if i == 0:
+                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+                else:
+                    weight_i, bias_i, proj_i = weights[i], biases[i], None
+
+                    hidden_i = hidden.index_select(0, indices_i)
+
+                    tail_logit_i = self._compute_logit(hidden_i, weight_i,
+                                                       bias_i, proj_i)
+                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
+
+                    logprob_i = head_logprob_i[:, -i] \
+                        + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+
+                if self.keep_order or keep_order:
+                    nll.index_copy_(0, indices_i, -logprob_i)
+                else:
+                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
+
+                offset += logprob_i.size(0)
+
+        return nll

+ 247 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/lamb.py

@@ -0,0 +1,247 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2019 cybertronai
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Lamb optimizer."""
+
+import collections
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class Lamb(Optimizer):
+    r"""Implements Lamb algorithm.
+
+    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        adam (bool, optional): always use trust ratio = 1, which turns this into
+            Adam. Useful for comparison purposes.
+
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+                 weight_decay=0, adam=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay)
+        self.adam = adam
+        super(Lamb, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients.')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                # Decay the first and second moment running average coefficient
+                # m_t
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                # v_t
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+
+                # Paper v3 does not use debiasing.
+                # bias_correction1 = 1 - beta1 ** state['step']
+                # bias_correction2 = 1 - beta2 ** state['step']
+                # Apply bias to lr to avoid broadcast.
+                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
+
+                weight_norm = p.data.norm(p=2).clamp_(0, 10)
+
+                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                if group['weight_decay'] != 0:
+                    adam_step.add_(group['weight_decay'], p.data)
+
+                adam_norm = adam_step.norm(p=2)
+
+                trust_ratio = weight_norm / (adam_norm + group['eps'])
+                state['weight_norm'] = weight_norm
+                state['adam_norm'] = adam_norm
+                state['trust_ratio'] = trust_ratio
+                if self.adam:
+                    trust_ratio = 1
+
+                p.data.add_(-step_size * trust_ratio, adam_step)
+
+        return loss
+
+
[email protected]
+def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float,
+                beta2: float, step_size: float, eps: float, weight_decay: float):
+    exp_avg = exp_avg * beta1 + (1 - beta1) * grad
+    exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad)
+
+    adam_step = exp_avg / (exp_avg_sq.sqrt() + eps)
+    adam_step = adam_step + weight_decay * param
+
+    weight_norm = param.norm(p=2).clamp_(0, 10)
+    adam_norm = adam_step.norm(p=2)
+
+    trust_ratio = weight_norm / (adam_norm + eps)
+
+    param = param - step_size * trust_ratio * adam_step
+    return param, exp_avg, exp_avg_sq
+
+
+class JITLamb(Optimizer):
+    r"""Implements Lamb algorithm.
+
+    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        adam (bool, optional): always use trust ratio = 1, which turns this into
+            Adam. Useful for comparison purposes.
+
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+                 weight_decay=0, adam=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay)
+        self.adam = adam
+        super().__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients.')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+                step_size = group['lr']
+
+                param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg,
+                                                         exp_avg_sq, beta1,
+                                                         beta2, step_size,
+                                                         group['eps'],
+                                                         group['weight_decay'],
+                                                         )
+                state['exp_avg'] = exp_avg
+                state['exp_avg_sq'] = exp_avg_sq
+                p.data = param
+
+        return loss

+ 842 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/mem_transformer.py

@@ -0,0 +1,842 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
+from utils.log_uniform_sampler import LogUniformSampler
+from utils.log_uniform_sampler import sample_logits
+
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, demb):
+        super(PositionalEmbedding, self).__init__()
+
+        self.demb = demb
+
+        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, pos_seq, bsz=None):
+        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
+        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
+
+        if bsz is not None:
+            return pos_emb[:, None, :].expand(-1, bsz, -1)
+        else:
+            return pos_emb[:, None, :]
+
+
+class PositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(PositionwiseFF, self).__init__()
+
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.CoreNet = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+            nn.Dropout(dropout),
+            nn.Linear(d_inner, d_model),
+            nn.Dropout(dropout),
+        )
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.pre_lnorm = pre_lnorm
+
+    def forward(self, inp):
+        if self.pre_lnorm:
+            # layer normalization + positionwise feed-forward
+            core_out = self.CoreNet(self.layer_norm(inp))
+
+            # residual connection
+            output = core_out + inp
+        else:
+            # positionwise feed-forward
+            core_out = self.CoreNet(inp)
+
+            # residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
+
+
+class MultiHeadAttn(nn.Module):
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+                 pre_lnorm=False):
+        super(MultiHeadAttn, self).__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_head = d_head
+        self.dropout = dropout
+
+        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
+        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.dropatt = nn.Dropout(dropatt)
+        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.scale = 1 / (d_head ** 0.5)
+
+        self.pre_lnorm = pre_lnorm
+
+    def forward(self, h, attn_mask=None, mems=None):
+        # multihead attention
+        # [hlen x bsz x n_head x d_head]
+
+        if mems is not None:
+            c = torch.cat([mems, h], 0)
+        else:
+            c = h
+
+        if self.pre_lnorm:
+            # layer normalization
+            c = self.layer_norm(c)
+
+        head_q = self.q_net(h)
+        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
+
+        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
+        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
+        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
+
+        # [bsz x qlen x klen x n_head]
+        attn_score = torch.einsum('ibnd,jbnd->bijn', (head_q, head_k))
+        attn_score.mul_(self.scale)
+        if attn_mask is not None and attn_mask.any().item():
+            if attn_mask.dim() == 2:
+                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
+            elif attn_mask.dim() == 3:
+                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
+
+        # [bsz x qlen x klen x n_head]
+        attn_prob = F.softmax(attn_score, dim=2)
+        attn_prob = self.dropatt(attn_prob)
+
+        # [bsz x qlen x klen x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
+        attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, head_v))
+        attn_vec = attn_vec.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+
+        # linear projection
+        attn_out = self.o_net(attn_vec)
+        attn_out = self.drop(attn_out)
+
+        if self.pre_lnorm:
+            # residual connection
+            output = h + attn_out
+        else:
+            # residual connection + layer normalization
+            output = self.layer_norm(h + attn_out)
+
+        return output
+
+
+class RelMultiHeadAttn(nn.Module):
+    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
+                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
+        super(RelMultiHeadAttn, self).__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_head = d_head
+        self.dropout = dropout
+
+        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.dropatt = nn.Dropout(dropatt)
+        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.scale = 1 / (d_head ** 0.5)
+
+        self.pre_lnorm = pre_lnorm
+
+    def _parallelogram_mask(self, h, w, left=False):
+        mask = torch.ones((h, w)).byte()
+        m = min(h, w)
+        mask[:m, :m] = torch.triu(mask[:m, :m])
+        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
+
+        if left:
+            return mask.bool()
+        else:
+            return mask.flip(0).bool()
+
+    def _shift(self, x, qlen, klen, mask, left=False):
+        if qlen > 1:
+            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
+                                   device=x.device, dtype=x.dtype)
+        else:
+            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
+
+        if left:
+            mask = mask.flip(1)
+            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
+        else:
+            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
+
+        x = x_padded.masked_select(mask[:, :, None, None]) \
+                    .view(qlen, klen, x.size(2), x.size(3))
+
+        return x
+
+    def _rel_shift(self, x, zero_triu=False):
+        zero_pad = torch.zeros((x.size(0), x.size(1), 1, x.size(3)),
+                               device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=2)
+
+        x_padded = x_padded.view(x.size(0), x.size(2) + 1, x.size(1), x.size(3))
+
+        x = x_padded[:, 1:].view_as(x)
+
+        if zero_triu:
+            ones = torch.ones((x.size(0), x.size(1)))
+            x = x * torch.tril(ones, x.size(1) - x.size(0))[None, :, :, None]
+
+        return x
+
+    def forward(self, w, r, attn_mask=None, mems=None):
+        raise NotImplementedError
+
+
+class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
+    def __init__(self, *args, **kwargs):
+        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
+
+        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
+
+    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
+        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
+
+        if mems is not None:
+            cat = torch.cat([mems, w], 0)
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(cat))
+            else:
+                w_heads = self.qkv_net(cat)
+            r_head_k = self.r_net(r)
+
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+            w_head_q = w_head_q[-qlen:]
+        else:
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(w))
+            else:
+                w_heads = self.qkv_net(w)
+            r_head_k = self.r_net(r)
+
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+
+        klen = w_head_k.size(0)
+
+        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
+        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
+        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
+
+        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # qlen x n_head x d_head
+
+        # compute attention score
+        rw_head_q = w_head_q + r_w_bias                                # qlen x bsz x n_head x d_head
+        AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k))    # bsz x qlen x klen x n_head
+
+        rr_head_q = w_head_q + r_r_bias
+        BD = torch.einsum('ibnd,jnd->bijn', (rr_head_q, r_head_k))     # bsz x qlen x klen x n_head
+        BD = self._rel_shift(BD)
+
+        # [bsz x qlen x klen x n_head]
+        attn_score = AC + BD
+        attn_score.mul_(self.scale)
+
+        # compute attention probability
+        if attn_mask is not None and attn_mask.any().item():
+            if attn_mask.dim() == 2:
+                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
+            elif attn_mask.dim() == 3:
+                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
+
+        # [bsz x qlen x klen x n_head]
+        attn_prob = F.softmax(attn_score, dim=2)
+        attn_prob = self.dropatt(attn_prob)
+
+        # compute attention vector
+        attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
+
+        # [qlen x bsz x n_head x d_head]
+        attn_vec = attn_vec.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+
+        # linear projection
+        attn_out = self.o_net(attn_vec)
+        attn_out = self.drop(attn_out)
+
+        if self.pre_lnorm:
+            # residual connection
+            output = w + attn_out
+        else:
+            # residual connection + layer normalization
+            output = self.layer_norm(w + attn_out)
+
+        return output
+
+
+class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
+    def __init__(self, *args, **kwargs):
+        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
+
+    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
+        # r_emb: [klen, n_head, d_head], used for term B
+        # r_w_bias: [n_head, d_head], used for term C
+        # r_bias: [klen, n_head], used for term D
+
+        qlen, bsz = w.size(0), w.size(1)
+
+        if mems is not None:
+            cat = torch.cat([mems, w], 0)
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(cat))
+            else:
+                w_heads = self.qkv_net(cat)
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+
+            w_head_q = w_head_q[-qlen:]
+        else:
+            if self.pre_lnorm:
+                w_heads = self.qkv_net(self.layer_norm(w))
+            else:
+                w_heads = self.qkv_net(w)
+            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
+
+        klen = w_head_k.size(0)
+
+        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
+        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
+        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
+
+        if klen > r_emb.size(0):
+            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
+            r_emb = torch.cat([r_emb_pad, r_emb], 0)
+            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
+            r_bias = torch.cat([r_bias_pad, r_bias], 0)
+        else:
+            r_emb = r_emb[-klen:]
+            r_bias = r_bias[-klen:]
+
+        # compute attention score
+        rw_head_q = w_head_q + r_w_bias[None]                        # qlen x bsz x n_head x d_head
+
+        AC = torch.einsum('ibnd,jbnd->bijn', (rw_head_q, w_head_k))  # bsz x qlen x klen x n_head
+        B_ = torch.einsum('ibnd,jnd->bijn', (w_head_q, r_emb))       # bsz x qlen x klen x n_head
+        D_ = r_bias[None, None, :, :]                                # 1   x 1    x klen x n_head
+        BD = self._rel_shift(B_ + D_)
+
+        # [bsz x qlen x klen x n_head]
+        attn_score = AC + BD
+        attn_score.mul_(self.scale)
+
+        # compute attention probability
+        if attn_mask is not None and attn_mask.any().item():
+            if attn_mask.dim() == 2:
+                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
+            elif attn_mask.dim() == 3:
+                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))
+
+        # [bsz x qlen x klen x n_head]
+        attn_prob = F.softmax(attn_score, dim=2)
+        attn_prob = self.dropatt(attn_prob)
+
+        # compute attention vector
+        attn_vec = torch.einsum('bijn,jbnd->ibnd', (attn_prob, w_head_v))
+
+        # [qlen x bsz x n_head x d_head]
+        attn_vec = attn_vec.contiguous().view(
+            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
+
+        # linear projection
+        attn_out = self.o_net(attn_vec)
+        attn_out = self.drop(attn_out)
+
+        if self.pre_lnorm:
+            # residual connection
+            output = w + attn_out
+        else:
+            # residual connection + layer normalization
+            output = self.layer_norm(w + attn_out)
+
+        return output
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
+        super(DecoderLayer, self).__init__()
+
+        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                     pre_lnorm=kwargs.get('pre_lnorm'))
+
+    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
+
+        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
+                               mems=mems)
+        output = self.pos_ff(output)
+
+        return output
+
+
+class RelLearnableDecoderLayer(nn.Module):
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+                 **kwargs):
+        super(RelLearnableDecoderLayer, self).__init__()
+
+        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head,
+                                                  dropout, **kwargs)
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                     pre_lnorm=kwargs.get('pre_lnorm'))
+
+    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
+
+        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
+                               attn_mask=dec_attn_mask,
+                               mems=mems)
+        output = self.pos_ff(output)
+
+        return output
+
+
+class RelPartialLearnableDecoderLayer(nn.Module):
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+                 **kwargs):
+        super(RelPartialLearnableDecoderLayer, self).__init__()
+
+        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
+                                                         d_head, dropout,
+                                                         **kwargs)
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
+                                     pre_lnorm=kwargs.get('pre_lnorm'))
+
+    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
+
+        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
+                               attn_mask=dec_attn_mask,
+                               mems=mems)
+        output = self.pos_ff(output)
+
+        return output
+
+
+class AdaptiveEmbedding(nn.Module):
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+                 sample_softmax=False):
+        super(AdaptiveEmbedding, self).__init__()
+
+        self.n_token = n_token
+        self.d_embed = d_embed
+
+        self.cutoffs = cutoffs + [n_token]
+        self.div_val = div_val
+        self.d_proj = d_proj
+
+        self.emb_scale = d_proj ** 0.5
+
+        self.cutoff_ends = [0] + self.cutoffs
+
+        self.emb_layers = nn.ModuleList()
+        self.emb_projs = nn.ParameterList()
+        if div_val == 1:
+            self.emb_layers.append(
+                nn.Embedding(n_token, d_embed, sparse=(sample_softmax > 0))
+            )
+            if d_proj != d_embed:
+                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
+        else:
+            for i in range(len(self.cutoffs)):
+                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
+                d_emb_i = d_embed // (div_val ** i)
+                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
+                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
+
+    def forward(self, inp):
+        if self.div_val == 1:
+            embed = self.emb_layers[0](inp)
+            if self.d_proj != self.d_embed:
+                embed = F.linear(embed, self.emb_projs[0])
+        else:
+            param = next(self.parameters())
+            inp_flat = inp.view(-1)
+            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
+                                   dtype=param.dtype, device=param.device)
+            for i in range(len(self.cutoffs)):
+                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+
+                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
+                indices_i = mask_i.nonzero().squeeze()
+
+                if indices_i.numel() == 0:
+                    continue
+
+                inp_i = inp_flat.index_select(0, indices_i) - l_idx
+                emb_i = self.emb_layers[i](inp_i)
+                emb_i = F.linear(emb_i, self.emb_projs[i])
+
+                emb_flat.index_copy_(0, indices_i, emb_i)
+
+            embed = emb_flat.view(*inp.size(), self.d_proj)
+
+        embed.mul_(self.emb_scale)
+
+        return embed
+
+
+class MemTransformerLM(nn.Module):
+    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
+                 dropout, dropatt, dtype, tie_weight=True, d_embed=None,
+                 div_val=1, tie_projs=[False], pre_lnorm=False,
+                 tgt_len=None, ext_len=None, mem_len=None,
+                 cutoffs=[], adapt_inp=False,
+                 same_length=False, attn_type=0, clamp_len=-1,
+                 sample_softmax=-1):
+        super(MemTransformerLM, self).__init__()
+        self.n_token = n_token
+
+        d_embed = d_model if d_embed is None else d_embed
+        self.d_embed = d_embed
+        self.d_model = d_model
+        self.n_head = n_head
+        self.d_head = d_head
+
+        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
+                                          div_val=div_val)
+
+        self.drop = nn.Dropout(dropout)
+
+        self.n_layer = n_layer
+
+        self.tgt_len = tgt_len
+        self.mem_len = mem_len
+        self.ext_len = ext_len
+        self.max_klen = tgt_len + ext_len + mem_len
+
+        self.attn_type = attn_type
+
+        self.layers = nn.ModuleList()
+        # the default attention
+        if attn_type == 0:
+            for i in range(n_layer):
+                self.layers.append(
+                    RelPartialLearnableDecoderLayer(
+                        n_head, d_model, d_head, d_inner, dropout,
+                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
+                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                )
+        # learnable embeddings
+        elif attn_type == 1:
+            for i in range(n_layer):
+                self.layers.append(
+                    RelLearnableDecoderLayer(
+                        n_head, d_model, d_head, d_inner, dropout,
+                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
+                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                )
+        # absolute embeddings
+        elif attn_type in [2, 3]:
+            for i in range(n_layer):
+                self.layers.append(
+                    DecoderLayer(
+                        n_head, d_model, d_head, d_inner, dropout,
+                        dropatt=dropatt, pre_lnorm=pre_lnorm)
+                )
+
+        self.sample_softmax = sample_softmax
+        # use sampled softmax
+        if sample_softmax > 0:
+            self.out_layer = nn.Linear(d_model, n_token)
+            if tie_weight:
+                self.out_layer.weight = self.word_emb.weight
+            self.tie_weight = tie_weight
+            self.sampler = LogUniformSampler(n_token, sample_softmax)
+
+        # use adaptive softmax (including standard softmax)
+        else:
+            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
+                                                    cutoffs, div_val=div_val)
+
+            if tie_weight:
+                for i in range(len(self.crit.out_layers)):
+                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
+
+            if tie_projs:
+                for i, tie_proj in enumerate(tie_projs):
+                    if tie_proj and div_val == 1 and d_model != d_embed:
+                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
+                    elif tie_proj and div_val != 1:
+                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]
+
+        self.same_length = same_length
+        self.clamp_len = clamp_len
+
+        self._create_params()
+
+    def backward_compatible(self):
+        self.sample_softmax = -1
+
+    def _create_params(self):
+        # default attention
+        if self.attn_type == 0:
+            self.pos_emb = PositionalEmbedding(self.d_model)
+            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
+        # learnable
+        elif self.attn_type == 1:
+            self.r_emb = nn.Parameter(torch.Tensor(
+                    self.n_layer, self.max_klen, self.n_head, self.d_head))
+            self.r_w_bias = nn.Parameter(torch.Tensor(
+                    self.n_layer, self.n_head, self.d_head))
+            self.r_bias = nn.Parameter(torch.Tensor(
+                    self.n_layer, self.max_klen, self.n_head))
+        # absolute standard
+        elif self.attn_type == 2:
+            self.pos_emb = PositionalEmbedding(self.d_model)
+        # absolute deeper SA
+        elif self.attn_type == 3:
+            self.r_emb = nn.Parameter(torch.Tensor(
+                    self.n_layer, self.max_klen, self.d_model))
+
+    def reset_length(self, tgt_len, ext_len, mem_len):
+        self.tgt_len = tgt_len
+        self.mem_len = mem_len
+        self.ext_len = ext_len
+
+    def init_mems(self):
+        if self.mem_len > 0:
+            mems = []
+            param = next(self.parameters())
+            for i in range(self.n_layer+1):
+                empty = torch.empty(0, dtype=param.dtype, device=param.device)
+                mems.append(empty)
+
+            return mems
+        else:
+            return None
+
+    def _update_mems(self, hids, mems, qlen, mlen):
+        # does not deal with None
+        if mems is None:
+            return None
+
+        # mems is not None
+        assert len(hids) == len(mems), 'len(hids) != len(mems)'
+
+        # There are `mlen + qlen` steps that can be cached into mems
+        # For the next step, the last `ext_len` of the `qlen` tokens
+        # will be used as the extended context. Hence, we only cache
+        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
+        # to `mlen + qlen - self.ext_len`.
+        with torch.no_grad():
+            new_mems = []
+            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+            beg_idx = max(0, end_idx - self.mem_len)
+            for i in range(len(hids)):
+
+                cat = torch.cat([mems[i], hids[i]], dim=0)
+                new_mems.append(cat[beg_idx:end_idx].detach())
+
+        return new_mems
+
+    def _forward(self, dec_inp, mems=None):
+        qlen, bsz = dec_inp.size()
+
+        word_emb = self.word_emb(dec_inp)
+
+        mlen = mems[0].size(0) if mems is not None else 0
+        klen = mlen + qlen
+        if self.same_length:
+            all_ones = word_emb.new_ones(qlen, klen)
+            mask_len = klen - self.mem_len
+            if mask_len > 0:
+                mask_shift_len = qlen - mask_len
+            else:
+                mask_shift_len = qlen
+            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+                             + torch.tril(all_ones, -mask_shift_len)).bool()
+        else:
+            dec_attn_mask = torch.triu(
+                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()
+
+        hids = []
+        # default
+        if self.attn_type == 0:
+            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
+                                   dtype=word_emb.dtype)
+            if self.clamp_len > 0:
+                pos_seq.clamp_(max=self.clamp_len)
+            pos_emb = self.pos_emb(pos_seq)
+
+            core_out = self.drop(word_emb)
+            pos_emb = self.drop(pos_emb)
+
+            hids.append(core_out)
+            for i, layer in enumerate(self.layers):
+                mems_i = None if mems is None else mems[i]
+                core_out = layer(core_out, pos_emb, self.r_w_bias,
+                                 self.r_r_bias, dec_attn_mask=dec_attn_mask,
+                                 mems=mems_i)
+                hids.append(core_out)
+        # learnable
+        elif self.attn_type == 1:
+            core_out = self.drop(word_emb)
+            hids.append(core_out)
+            for i, layer in enumerate(self.layers):
+                if self.clamp_len > 0:
+                    r_emb = self.r_emb[i][-self.clamp_len:]
+                    r_bias = self.r_bias[i][-self.clamp_len:]
+                else:
+                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]
+
+                mems_i = None if mems is None else mems[i]
+                core_out = layer(core_out, r_emb, self.r_w_bias[i],
+                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
+                hids.append(core_out)
+        # absolute
+        elif self.attn_type == 2:
+            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
+                                   dtype=word_emb.dtype)
+            if self.clamp_len > 0:
+                pos_seq.clamp_(max=self.clamp_len)
+            pos_emb = self.pos_emb(pos_seq)
+
+            core_out = self.drop(word_emb + pos_emb[-qlen:])
+
+            hids.append(core_out)
+            for i, layer in enumerate(self.layers):
+                mems_i = None if mems is None else mems[i]
+                if mems_i is not None and len(mems_i) and i == 0:
+                    mems_i += pos_emb[:mlen]
+                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                 mems=mems_i)
+                hids.append(core_out)
+        elif self.attn_type == 3:
+            core_out = self.drop(word_emb)
+
+            hids.append(core_out)
+            for i, layer in enumerate(self.layers):
+                mems_i = None if mems is None else mems[i]
+                if mems_i is not None and len(mems_i) and mlen > 0:
+                    cur_emb = self.r_emb[i][:-qlen]
+                    cur_size = cur_emb.size(0)
+                    if cur_size < mlen:
+                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
+                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
+                    else:
+                        cur_emb = cur_emb[-mlen:]
+                    mems_i += cur_emb.view(mlen, 1, -1)
+                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
+
+                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
+                                 mems=mems_i)
+                hids.append(core_out)
+
+        core_out = self.drop(core_out)
+
+        new_mems = self._update_mems(hids, mems, qlen, mlen)
+
+        return core_out, new_mems
+
+    def forward(self, data, target, mems):
+        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
+        # So, have to initialize size(0) mems inside the model forward.
+        # Moreover, have to return new_mems to allow nn.DataParallel to piece
+        # them together.
+        if mems is None:
+            mems = self.init_mems()
+
+        tgt_len = target.size(0)
+        hidden, new_mems = self._forward(data, mems=mems)
+
+        pred_hid = hidden[-tgt_len:]
+        if self.sample_softmax > 0 and self.training:
+            assert self.tie_weight
+            logit = sample_logits(self.word_emb, self.out_layer.bias, target,
+                                  pred_hid, self.sampler)
+            loss = -F.log_softmax(logit, -1)[:, :, 0]
+        else:
+            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
+            loss = loss.view(tgt_len, -1)
+
+        if new_mems is None:
+            return [loss]
+        else:
+            return [loss] + new_mems
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(description='unit test')
+
+    parser.add_argument('--n_layer', type=int, default=4, help='')
+    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
+    parser.add_argument('--n_head', type=int, default=2, help='')
+    parser.add_argument('--d_head', type=int, default=2, help='')
+    parser.add_argument('--d_model', type=int, default=200, help='')
+    parser.add_argument('--d_embed', type=int, default=200, help='')
+    parser.add_argument('--d_inner', type=int, default=200, help='')
+    parser.add_argument('--dropout', type=float, default=0.0, help='')
+    parser.add_argument('--cuda', action='store_true', help='')
+    parser.add_argument('--seed', type=int, default=1111, help='')
+    parser.add_argument('--multi_gpu', action='store_true', help='')
+
+    args = parser.parse_args()
+
+    device = torch.device("cuda" if args.cuda else "cpu")
+
+    B = 4
+    tgt_len, mem_len, ext_len = 36, 36, 0
+    data_len = tgt_len * 20
+    args.n_token = 10000
+
+    import data_utils
+
+    data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
+    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)
+
+    cutoffs = [args.n_token // 2]
+    tie_projs = [False] + [True] * len(cutoffs)
+
+    for div_val in [1, 2]:
+        for d_embed in [200, 100]:
+            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
+                                     args.d_model, args.d_head, args.d_inner, args.dropout,
+                                     dropatt=args.dropout, tie_weight=True,
+                                     d_embed=d_embed, div_val=div_val,
+                                     tie_projs=tie_projs, pre_lnorm=True,
+                                     tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
+                                     cutoffs=cutoffs, attn_type=0).to(device)
+
+            print(sum(p.numel() for p in model.parameters()))
+
+            mems = tuple()
+            for idx, (inp, tgt, seqlen) in enumerate(diter):
+                print('batch {}'.format(idx))
+                out = model(inp, tgt, *mems)
+                mems = out[1:]

+ 2 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/requirements.txt

@@ -0,0 +1,2 @@
+pytorch-transformers==1.1.0
+sacremoses==0.0.35

+ 41 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_enwik8_base.sh

@@ -0,0 +1,41 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/enwik8/ \
+        --dataset enwik8 \
+        --n_layer 12 \
+        --d_model 512 \
+        --n_head 8 \
+        --d_head 64 \
+        --d_inner 2048 \
+        --dropout 0.1 \
+        --dropatt 0.0 \
+        --optim adam \
+        --lr 0.00025 \
+        --warmup_step 0 \
+        --max_step 400000 \
+        --tgt_len 512 \
+        --mem_len 512 \
+        --eval_tgt_len 128 \
+        --batch_size 22 \
+        --multi_gpu \
+        --gpu0_bsz 4 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/enwik8/ \
+        --dataset enwik8 \
+        --tgt_len 80 \
+        --mem_len 2100 \
+        --clamp_len 820 \
+        --same_length \
+        --split test \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 41 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_enwik8_large.sh

@@ -0,0 +1,41 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/enwik8/ \
+        --dataset enwik8 \
+        --n_layer 24 \
+        --d_model 1024 \
+        --n_head 8 \
+        --d_head 128 \
+        --d_inner 3072 \
+        --dropout 0.15 \
+        --dropatt 0.15 \
+        --optim adam \
+        --lr 0.00025 \
+        --warmup_step 4000 \
+        --max_step 400000 \
+        --tgt_len 768 \
+        --mem_len 768 \
+        --eval_tgt_len 128 \
+        --batch_size 64 \
+        --multi_gpu \
+        --gpu0_bsz 0 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/enwik8/ \
+        --dataset enwik8 \
+        --tgt_len 128 \
+        --mem_len 3800 \
+        --clamp_len 1000 \
+        --same_length \
+        --split test \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 43 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_lm1b_base.sh

@@ -0,0 +1,43 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/one-billion-words/ \
+        --dataset lm1b \
+        --adaptive \
+        --n_layer 18 \
+        --d_model 1024 \
+        --div_val 4 \
+        --n_head 8 \
+        --d_head 128 \
+        --d_inner 4096 \
+        --dropout 0.0 \
+        --dropatt 0.0 \
+        --optim adam \
+        --warmup_step 20000 \
+        --max_step 500000 \
+        --lr 0.00025 \
+        --tgt_len 32 \
+        --mem_len 32 \
+        --eval_tgt_len 32 \
+        --batch_size 224 \
+        --multi_gpu \
+        --gpu0_bsz 32 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/one-billion-words/ \
+        --dataset lm1b \
+        --batch_size 64 \
+        --tgt_len 32 \
+        --mem_len 128 \
+        --split test \
+        --same_length \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 43 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_lm1b_large.sh

@@ -0,0 +1,43 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/one-billion-words/ \
+        --dataset lm1b \
+        --adaptive \
+        --div_val 4 \
+        --n_layer 24 \
+        --d_model 1280 \
+        --n_head 16 \
+        --d_head 80 \
+        --d_inner 8192 \
+        --dropout 0.05 \
+        --dropatt 0.05 \
+        --optim adam \
+        --warmup_step 30000 \
+        --max_step 1200000 \
+        --lr 0.00025 \
+        --tgt_len 32 \
+        --mem_len 32 \
+        --eval_tgt_len 32 \
+        --batch_size 512 \
+        --multi_gpu \
+        --gpu0_bsz 0 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/one-billion-words/ \
+        --dataset lm1b \
+        --batch_size 8 \
+        --tgt_len 32 \
+        --mem_len 128 \
+        --split test \
+        --same_length \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 41 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_text8_base.sh

@@ -0,0 +1,41 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/text8/ \
+        --dataset text8 \
+        --n_layer 12 \
+        --d_model 512 \
+        --n_head 8 \
+        --d_head 64 \
+        --d_inner 2048 \
+        --dropout 0.1 \
+        --dropatt 0.0 \
+        --optim adam \
+        --lr 0.00025 \
+        --warmup_step 0 \
+        --max_step 400000 \
+        --tgt_len 512 \
+        --mem_len 512 \
+        --eval_tgt_len 128 \
+        --batch_size 22 \
+        --multi_gpu \
+        --gpu0_bsz 4 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/text8/ \
+        --dataset text8 \
+        --tgt_len 80 \
+        --mem_len 2100 \
+        --clamp_len 820 \
+        --same_length \
+        --split test \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 38 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_text8_large.sh

@@ -0,0 +1,38 @@
+#!/bin/bash
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python train.py \
+        --cuda \
+        --data ../data/text8/ \
+        --dataset text8 \
+        --n_layer 24 \
+        --d_model 1024 \
+        --n_head 8 \
+        --d_head 128 \
+        --d_inner 3072 \
+        --dropout 0.15 \
+        --dropatt 0.15 \
+        --optim adam \
+        --lr 0.00025 \
+        --tgt_len 768 \
+        --mem_len 768 \
+        --eval_tgt_len 128 \
+        --batch_size 64 \
+        --max_step 400000 \
+        ${@:2}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python eval.py \
+        --cuda \
+        --data ../data/text8/ \
+        --dataset text8 \
+        --tgt_len 128 \
+        --mem_len 3800 \
+        --clamp_len 1000 \
+        --same_length \
+        --split test \
+        ${@:2}
+else
+    echo 'unknown argment 1'
+fi

+ 58 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_wt103_base.sh

@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python -m torch.distributed.launch --nproc_per_node=$2 train.py \
+        --cuda \
+        --data ../data/wikitext-103/ \
+        --dataset wt103 \
+        --n_layer 16 \
+        --d_model 512 \
+        --n_head 8 \
+        --d_head 64 \
+        --d_inner 2048 \
+        --dropout 0.1 \
+        --dropatt 0.0 \
+        --optim jitlamb \
+        --lr 0.01 \
+        --eta_min 0.001 \
+        --roll \
+        --warmup_step 1000 \
+        --max_step 40000 \
+        --tgt_len 192 \
+        --mem_len 192 \
+        --eval_tgt_len 192 \
+        --batch_size 256 \
+        --multi_gpu ddp \
+        --log_interval 10 \
+        --eval_interval 5000 \
+        ${@:3}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python -m torch.distributed.launch --nproc_per_node=$2 eval.py \
+        --cuda \
+        --data ../data/wikitext-103/ \
+        --dataset wt103 \
+        --tgt_len 64 \
+        --mem_len 640 \
+        --clamp_len 400 \
+        --same_length \
+        --split test \
+        ${@:3}
+else
+    echo 'unknown argment 1'
+fi

+ 54 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/run_wt103_large.sh

@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ $1 == 'train' ]]; then
+    echo 'Run training...'
+    python -m torch.distributed.launch --nproc_per_node=$2 train.py \
+        --cuda \
+        --data ../data/wikitext-103/ \
+        --dataset wt103 \
+        --n_layer 18 \
+        --d_model 1024 \
+        --n_head 16 \
+        --d_head 64 \
+        --d_inner 4096 \
+        --dropout 0.2 \
+        --dropatt 0.2 \
+        --optim adam \
+        --lr 0.00025 \
+        --warmup_step 16000 \
+        --max_step 4000000 \
+        --tgt_len 256 \
+        --mem_len 256 \
+        --eval_tgt_len 128 \
+        --batch_size 128 \
+        --multi_gpu ddp \
+        ${@:3}
+elif [[ $1 == 'eval' ]]; then
+    echo 'Run evaluation...'
+    python -m torch.distributed.launch --nproc_per_node=$2 eval.py \
+        --cuda \
+        --data ../data/wikitext-103/ \
+        --dataset wt103 \
+        --tgt_len 128 \
+        --mem_len 1600 \
+        --clamp_len 1000 \
+        --same_length \
+        --split test \
+        ${@:3}
+else
+    echo 'unknown argment 1'
+fi

+ 17 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/docker/build.sh

@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+docker build . --network=host --rm -t transformer-xl:latest

+ 17 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/docker/interactive.sh

@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+nvidia-docker run --init -it --rm --network=host --ipc=host -v $(dirname $PWD):/workspace/transformer-xl transformer-xl bash

+ 39 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/inference_benchmark.sh

@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MODEL=${MODEL:-"LM-TFM/checkpoint_best.pt"}
+
+BATCH_SIZES=(1 2 4 8 16 32)
+TYPES=("pytorch" "torchscript")
+# "empty" MATH corresponds to fp32
+MATHS=("" "--fp16")
+
+
+for (( i = 0; i < ${#TYPES[@]}; i++ )); do
+   for (( j = 0; j < ${#BATCH_SIZES[@]}; j++ )); do
+      for (( k = 0; k < ${#MATHS[@]}; k++ )); do
+         echo type: ${TYPES[i]} batch size: ${BATCH_SIZES[j]} math: ${MATHS[k]}
+
+         taskset -c 0 bash run_wt103_base.sh eval 1 \
+            --model "${MODEL}" \
+            --type "${TYPES[i]}" \
+            --batch_size "${BATCH_SIZES[j]}" \
+            "${MATHS[k]}" \
+            --save_data \
+            "${@:1}"
+      done
+   done
+done

+ 71 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/infer_bench.sh

@@ -0,0 +1,71 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
+REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_inference_throughput
+
+MATH=$1
+if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
+   echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
+   exit 1
+fi
+
+if [[ ${MATH} == 'fp16' ]]; then
+   MATH_OPT='--fp16'
+elif [[ ${MATH} == 'fp32' ]]; then
+   MATH_OPT=''
+fi
+
+TYPE=$2
+if [[ ${TYPE} != "pytorch" && ${TYPE} != "torchscript" ]]; then
+   echo "Unsupported option for TYPE, use either 'pytorch' or 'torchscript'"
+   exit 1
+fi
+
+PERF_TOLERANCE=0.9
+BATCH_SIZE=16
+
+GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
+echo 'GPU_NAME:' "${GPU_NAME}"
+GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
+echo 'GPU_COUNT:' "${GPU_COUNT}"
+GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
+echo 'GPU_MEM:' "${GPU_MEM}"
+
+REFERENCE_PERF=$(grep "${MATH},${BATCH_SIZE},${GPU_NAME}" \
+   ${REFERENCE_FILE} | \cut -f 4 -d ',')
+
+if [ -z "${REFERENCE_PERF}" ]; then
+   echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
+   TARGET_PERF=''
+else
+   PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
+   TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
+fi
+
+cd $REPO_DIR
+
+export CUDA_VISIBLE_DEVICES=0
+
+bash run_wt103_base.sh eval 1 \
+   --model checkpoint/checkpoint_best.pt \
+   --target_perplexity 23.4 \
+   --batch_size "${BATCH_SIZE}" \
+   --type "${TYPE}" \
+   "${MATH_OPT}" \
+   "${TARGET_PERF}"

+ 6 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/reference_inference_throughput

@@ -0,0 +1,6 @@
+fp16,16,Tesla V100-SXM2-16GB,40000
+fp32,16,Tesla V100-SXM2-16GB,18750
+fp16,16,Tesla V100-SXM2-32GB,40000
+fp32,16,Tesla V100-SXM2-32GB,18750
+fp16,16,Tesla V100-SXM3-32GB,40000
+fp32,16,Tesla V100-SXM3-32GB,18750

+ 10 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/reference_training_throughput

@@ -0,0 +1,10 @@
+fp16,4,Tesla V100-SXM2-16GB,126000
+fp32,4,Tesla V100-SXM2-16GB,45000
+fp16,4,Tesla V100-SXM2-32GB,126000
+fp32,4,Tesla V100-SXM2-32GB,45000
+fp16,8,Tesla V100-SXM2-16GB,233000
+fp32,8,Tesla V100-SXM2-16GB,88000
+fp16,8,Tesla V100-SXM2-32GB,233000
+fp32,8,Tesla V100-SXM2-32GB,88000
+fp16,16,Tesla V100-SXM3-32GB,356000
+fp32,16,Tesla V100-SXM3-32GB,176000

+ 78 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_bench.sh

@@ -0,0 +1,78 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
+REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
+
+MATH=$1
+if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
+   echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
+   exit 1
+fi
+
+if [[ ${MATH} == 'fp16' ]]; then
+   MATH_OPT='--fp16'
+elif [[ ${MATH} == 'fp32' ]]; then
+   MATH_OPT=''
+fi
+
+PERF_TOLERANCE=0.9
+GLOBAL_BATCH_SIZE=256
+
+GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
+echo 'GPU_NAME:' "${GPU_NAME}"
+GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
+echo 'GPU_COUNT:' "${GPU_COUNT}"
+GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
+echo 'GPU_MEM:' "${GPU_MEM}"
+
+if (( GPU_MEM > 16500 )); then
+   LOCAL_BATCH_SIZE=32
+else
+   if [[ ${MATH} == 'fp16' ]]; then
+      LOCAL_BATCH_SIZE=32
+   elif [[ ${MATH} == 'fp32' ]]; then
+      LOCAL_BATCH_SIZE=16
+   fi
+fi
+
+BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
+BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
+
+REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
+   ${REFERENCE_FILE} | \cut -f 4 -d ',')
+
+if [ -z "${REFERENCE_PERF}" ]; then
+   echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
+   TARGET_PERF=''
+else
+   PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
+   TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
+fi
+
+cd $REPO_DIR
+
+bash run_wt103_base.sh train "${GPU_COUNT}" \
+   --debug \
+   --max_step $((256 / GPU_COUNT)) \
+   --batch_chunk "${BATCH_CHUNK}" \
+   --log_interval 1 \
+   --adaptive \
+   --vocab word \
+   "${MATH_OPT}" \
+   "${TARGET_PERF}"

+ 79 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_full.sh

@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
+REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
+
+MATH=$1
+if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
+   echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
+   exit 1
+fi
+
+if [[ ${MATH} == 'fp16' ]]; then
+   MATH_OPT='--fp16'
+elif [[ ${MATH} == 'fp32' ]]; then
+   MATH_OPT=''
+fi
+
+PERF_TOLERANCE=0.9
+GLOBAL_BATCH_SIZE=256
+
+GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
+echo 'GPU_NAME:' "${GPU_NAME}"
+GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
+echo 'GPU_COUNT:' "${GPU_COUNT}"
+GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
+echo 'GPU_MEM:' "${GPU_MEM}"
+
+if (( GPU_MEM > 16500 )); then
+   LOCAL_BATCH_SIZE=32
+else
+   if [[ ${MATH} == 'fp16' ]]; then
+      LOCAL_BATCH_SIZE=32
+   elif [[ ${MATH} == 'fp32' ]]; then
+      LOCAL_BATCH_SIZE=16
+   fi
+fi
+
+BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
+BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
+
+REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
+   ${REFERENCE_FILE} | \cut -f 4 -d ',')
+
+if [ -z "${REFERENCE_PERF}" ]; then
+   echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
+   TARGET_PERF=''
+else
+   PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
+   TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
+fi
+
+cd $REPO_DIR
+
+bash run_wt103_base.sh train "${GPU_COUNT}" \
+   --debug \
+   --max_step 40000 \
+   --target_perplexity 23.4 \
+   --batch_chunk "${BATCH_CHUNK}" \
+   --log_interval 1 \
+   --adaptive \
+   --vocab word \
+   "${MATH_OPT}" \
+   "${TARGET_PERF}"

+ 80 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_long.sh

@@ -0,0 +1,80 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
+REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
+
+MATH=$1
+if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
+   echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
+   exit 1
+fi
+
+if [[ ${MATH} == 'fp16' ]]; then
+   MATH_OPT='--fp16'
+elif [[ ${MATH} == 'fp32' ]]; then
+   MATH_OPT=''
+fi
+
+PERF_TOLERANCE=0.9
+GLOBAL_BATCH_SIZE=256
+
+GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
+echo 'GPU_NAME:' "${GPU_NAME}"
+GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
+echo 'GPU_COUNT:' "${GPU_COUNT}"
+GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
+echo 'GPU_MEM:' "${GPU_MEM}"
+
+if (( GPU_MEM > 16500 )); then
+   LOCAL_BATCH_SIZE=32
+else
+   if [[ ${MATH} == 'fp16' ]]; then
+      LOCAL_BATCH_SIZE=32
+   elif [[ ${MATH} == 'fp32' ]]; then
+      LOCAL_BATCH_SIZE=16
+   fi
+fi
+
+BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
+BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
+
+REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
+   ${REFERENCE_FILE} | \cut -f 4 -d ',')
+
+if [ -z "${REFERENCE_PERF}" ]; then
+   echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
+   TARGET_PERF=''
+else
+   PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
+   TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
+fi
+
+cd $REPO_DIR
+
+bash run_wt103_base.sh train "${GPU_COUNT}" \
+   --debug \
+   --max_step 30000 \
+   --max_step_scheduler 40000 \
+   --target_perplexity 24.2 \
+   --batch_chunk "${BATCH_CHUNK}" \
+   --log_interval 1 \
+   --adaptive \
+   --vocab word \
+   "${MATH_OPT}" \
+   "${TARGET_PERF}"

+ 80 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/scripts/tests/train_short.sh

@@ -0,0 +1,80 @@
+#!/bin/bash
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+# 
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# 
+#       http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+
+REPO_DIR=${REPO_DIR:-"/workspace/transformer-xl/pytorch/"}
+REFERENCE_FILE=$REPO_DIR/scripts/tests/reference_training_throughput
+
+MATH=$1
+if [[ ${MATH} != "fp16" && ${MATH} != "fp32" ]]; then
+   echo "Unsupported option for MATH, use either 'fp16' or 'fp32'"
+   exit 1
+fi
+
+if [[ ${MATH} == 'fp16' ]]; then
+   MATH_OPT='--fp16'
+elif [[ ${MATH} == 'fp32' ]]; then
+   MATH_OPT=''
+fi
+
+PERF_TOLERANCE=0.9
+GLOBAL_BATCH_SIZE=256
+
+GPU_NAME=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |uniq)
+echo 'GPU_NAME:' "${GPU_NAME}"
+GPU_COUNT=$(nvidia-smi --query-gpu=gpu_name --format=csv,noheader |wc -l)
+echo 'GPU_COUNT:' "${GPU_COUNT}"
+GPU_MEM=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader |head -n 1 |cut -f 1 -d " ")
+echo 'GPU_MEM:' "${GPU_MEM}"
+
+if (( GPU_MEM > 16500 )); then
+   LOCAL_BATCH_SIZE=32
+else
+   if [[ ${MATH} == 'fp16' ]]; then
+      LOCAL_BATCH_SIZE=32
+   elif [[ ${MATH} == 'fp32' ]]; then
+      LOCAL_BATCH_SIZE=16
+   fi
+fi
+
+BATCH_CHUNK=$((GLOBAL_BATCH_SIZE / (GPU_COUNT * LOCAL_BATCH_SIZE)))
+BATCH_CHUNK=$((BATCH_CHUNK < 1 ? 1 : BATCH_CHUNK))
+
+REFERENCE_PERF=$(grep "${MATH},${GPU_COUNT},${GPU_NAME}" \
+   ${REFERENCE_FILE} | \cut -f 4 -d ',')
+
+if [ -z "${REFERENCE_PERF}" ]; then
+   echo "WARNING: COULD NOT FIND REFERENCE PERFORMANCE FOR EXECUTED CONFIG"
+   TARGET_PERF=''
+else
+   PERF_THRESHOLD=$(awk 'BEGIN {print ('"${REFERENCE_PERF}"' * '"${PERF_TOLERANCE}"')}')
+   TARGET_PERF='--target_throughput '${PERF_THRESHOLD}
+fi
+
+cd $REPO_DIR
+
+bash run_wt103_base.sh train "${GPU_COUNT}" \
+   --debug \
+   --max_step 5000 \
+   --max_step_scheduler 40000 \
+   --target_perplexity 43.5 \
+   --batch_chunk "${BATCH_CHUNK}" \
+   --log_interval 1 \
+   --adaptive \
+   --vocab word \
+   "${MATH_OPT}" \
+   "${TARGET_PERF}"

+ 810 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/train.py

@@ -0,0 +1,810 @@
+# coding: utf-8
+
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import itertools
+import logging
+import math
+import os
+import time
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from apex.parallel import DistributedDataParallel
+
+import lamb
+import utils
+from apex import amp
+from data_utils import get_lm_corpus
+from mem_transformer import MemTransformerLM
+from utils.data_parallel import BalancedDataParallel
+from utils.exp_utils import create_exp_dir
+from utils.exp_utils import benchmark
+from utils.exp_utils import AverageMeter
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='PyTorch Transformer-XL Language Model',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        )
+
+    general = parser.add_argument_group('general setup')
+    general.add_argument('--work_dir', default='LM-TFM', type=str,
+                         help='Directory for the results')
+    general.add_argument('--append_dataset', action='store_true',
+                         help='Automatically append dataset name to work_dir')
+    general.add_argument('--append_time', action='store_true',
+                         help='Automatically append current time to work_dir')
+    general.add_argument('--cuda', action='store_true',
+                         help='Use CUDA')
+    general.add_argument('--fp16', action='store_true',
+                         help='Run training in fp16/mixed precision')
+    general.add_argument('--restart', type=str, default='',
+                         help='Restart training from the saved checkpoint')
+    general.add_argument('--debug', action='store_true',
+                         help='Run in debug mode (do not create exp dir)')
+    general.add_argument('--log_all_ranks', action='store_true',
+                         help='Enable logging from all distributed ranks')
+    general.add_argument('--save-all', action='store_true',
+                         help='Save all checkpoints')
+    general.add_argument('--log_interval', type=int, default=10,
+                         help='Report interval')
+    general.add_argument('--target_throughput', type=float, default=None,
+                         help='Target training throughput (for benchmarking)')
+    general.add_argument('--target_perplexity', type=float, default=None,
+                         help='Target validation perplexity (for benchmarking)')
+
+    dataset = parser.add_argument_group('dataset setup')
+    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
+                         help='Location of the data corpus')
+    dataset.add_argument('--dataset', type=str, default='wt103',
+                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
+                         help='Dataset name')
+    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
+                         help='Type of vocabulary')
+
+    model = parser.add_argument_group('model setup')
+    model.add_argument('--n_layer', type=int, default=16,
+                       help='Number of total layers')
+    model.add_argument('--n_head', type=int, default=8,
+                       help='Number of heads')
+    model.add_argument('--d_head', type=int, default=64,
+                       help='Head dimension')
+    model.add_argument('--d_embed', type=int, default=-1,
+                       help='Embedding dimension')
+    model.add_argument('--d_model', type=int, default=512,
+                       help='Model dimension')
+    model.add_argument('--d_inner', type=int, default=2048,
+                       help='Inner dimension in feedforward layer')
+    model.add_argument('--dropout', type=float, default=0.1,
+                       help='Global dropout rate')
+    model.add_argument('--dropatt', type=float, default=0.0,
+                       help='Attention probability dropout rate')
+    model.add_argument('--pre_lnorm', action='store_true',
+                       help='Apply LayerNorm to the input instead of the output')
+    model.add_argument('--attn_type', type=int, default=0,
+                       help='Attention type. 0 for ours, 1 for Shaw et al,'
+                       '2 for Vaswani et al, 3 for Al Rfou et al.')
+    model.add_argument('--not_tied', action='store_true',
+                       help='Do not tie the word embedding and softmax weights')
+    model.add_argument('--clamp_len', type=int, default=-1,
+                       help='Use the same pos embeddings after clamp_len')
+    model.add_argument('--adaptive', action='store_true',
+                       help='Use adaptive softmax')
+    model.add_argument('--div_val', type=int, default=1,
+                       help='Dividend value for adaptive input and softmax')
+    model.add_argument('--sample_softmax', type=int, default=-1,
+                       help='Number of samples in sampled softmax')
+    model.add_argument('--init', default='normal', type=str,
+                       help='Parameter initializer to use')
+    model.add_argument('--emb_init', default='normal', type=str,
+                       help='Parameter initializer to use')
+    model.add_argument('--init_range', type=float, default=0.1,
+                       help='Parameters initialized by U(-init_range, init_range)')
+    model.add_argument('--emb_init_range', type=float, default=0.01,
+                       help='Parameters initialized by U(-init_range, init_range)')
+    model.add_argument('--init_std', type=float, default=0.02,
+                       help='Parameters initialized by N(0, init_std)')
+    model.add_argument('--proj_init_std', type=float, default=0.01,
+                       help='Parameters initialized by N(0, init_std)')
+
+    opt = parser.add_argument_group('optimizer setup')
+    opt.add_argument('--optim', default='jitlamb', type=str,
+                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
+                     help='Optimizer to use')
+    opt.add_argument('--lr', type=float, default=0.01,
+                     help='Initial learning rate')
+    opt.add_argument('--mom', type=float, default=0.0,
+                     help='Momentum for sgd')
+    opt.add_argument('--scheduler', default='cosine', type=str,
+                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
+                     help='LR scheduler to use')
+    opt.add_argument('--max_step_scheduler', type=int, default=None,
+                     help='Max number of training steps for LR scheduler')
+    opt.add_argument('--warmup_step', type=int, default=1000,
+                     help='Number of iterations for LR warmup')
+    opt.add_argument('--decay_rate', type=float, default=0.5,
+                     help='Decay factor when ReduceLROnPlateau is used')
+    opt.add_argument('--lr_min', type=float, default=0.0,
+                     help='Minimum learning rate during annealing')
+    opt.add_argument('--clip', type=float, default=0.25,
+                     help='Gradient clipping')
+    opt.add_argument('--weight_decay', type=float, default=0.0,
+                     help='Weight decay for adam|lamb')
+    opt.add_argument('--clip_nonemb', action='store_true',
+                     help='Only clip the gradient of non-embedding params')
+    opt.add_argument('--patience', type=int, default=0,
+                     help='Patience')
+    opt.add_argument('--eta_min', type=float, default=0.001,
+                     help='Min learning rate for cosine scheduler')
+
+    training = parser.add_argument_group('training setup')
+    training.add_argument('--max_step', type=int, default=40000,
+                          help='Max number of training steps')
+    training.add_argument('--batch_size', type=int, default=256,
+                          help='Global batch size')
+    training.add_argument('--batch_chunk', type=int, default=1,
+                          help='Split batch into chunks to save memory')
+    training.add_argument('--roll', action='store_true',
+                          help='Enable random shifts within each data stream')
+    training.add_argument('--tgt_len', type=int, default=192,
+                          help='Number of tokens to predict')
+    training.add_argument('--ext_len', type=int, default=0,
+                          help='Length of the extended context')
+    training.add_argument('--mem_len', type=int, default=192,
+                          help='Length of the retained previous heads')
+    training.add_argument('--seed', type=int, default=1111,
+                          help='Random seed')
+    training.add_argument('--multi_gpu', default=None, type=str,
+                          choices=['ddp', 'dp'],
+                          help='Use multiple GPU')
+    training.add_argument('--gpu0_bsz', type=int, default=-1,
+                          help='Batch size on gpu 0 (for "dp" backend)')
+    training.add_argument('--same_length', action='store_true',
+                          help='Use the same attn length for all tokens')
+    training.add_argument('--varlen', action='store_true',
+                          help='Use variable length')
+
+    val = parser.add_argument_group('validation setup')
+    val.add_argument('--eval_tgt_len', type=int, default=192,
+                     help='Number of tokens to predict for evaluation')
+    val.add_argument('--eval_batch_size', type=int, default=16,
+                     help='Eval batch size')
+    val.add_argument('--eval_max_steps', type=int, default=-1,
+                     help='Max eval steps')
+    val.add_argument('--eval_interval', type=int, default=5000,
+                     help='Evaluation interval')
+
+    dist = parser.add_argument_group('distributed setup')
+    dist.add_argument('--local_rank', default=0, type=int,
+                      help='Used for multi-process training. ' +
+                      'Can either be manually set ' +
+                      'or automatically set by using \'python -m multiproc\'')
+
+    args = parser.parse_args()
+    args.tied = not args.not_tied
+
+    if args.d_embed < 0:
+        args.d_embed = args.d_model
+
+    assert args.ext_len >= 0, 'extended context length must be non-negative'
+    assert args.batch_size % args.batch_chunk == 0
+
+    return args
+
+
+def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
+                    train_step, best_val_loss, work_dir, name='checkpoint.pt'):
+    if args.fp16:
+        amp_state = amp.state_dict()
+    else:
+        amp_state = None
+
+    state = {
+        'args': args,
+        'model_config': model_config,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict(),
+        'scheduler_state': scheduler.state_dict(),
+        'vocab': vocab,
+        'amp_state': amp_state,
+        'train_step': train_step,
+        'best_val_loss': best_val_loss,
+        }
+
+    with utils.distributed.sync_workers() as rank:
+        path = os.path.join(work_dir, name)
+        logging.info(f'Saving checkpoint to {path}')
+        if rank == 0:
+            torch.save(state, path)
+
+
+def load_checkpoint(path):
+    if os.path.isdir(path):
+        path = os.path.join(path, 'checkpoint_last.pt')
+
+    dst = f'cuda:{torch.cuda.current_device()}'
+    logging.info(f'Loading checkpoint from {path}')
+    checkpoint = torch.load(path, map_location=dst)
+    return checkpoint
+
+
+def init_weight(weight, args):
+    if args.init == 'uniform':
+        nn.init.uniform_(weight, -args.init_range, args.init_range)
+    elif args.init == 'normal':
+        nn.init.normal_(weight, 0.0, args.init_std)
+
+
+def init_bias(bias):
+    nn.init.constant_(bias, 0.0)
+
+
+def weights_init(m, args):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        if hasattr(m, 'weight') and m.weight is not None:
+            init_weight(m.weight, args)
+        if hasattr(m, 'bias') and m.bias is not None:
+            init_bias(m.bias)
+    elif classname.find('AdaptiveEmbedding') != -1:
+        if hasattr(m, 'emb_projs'):
+            for i in range(len(m.emb_projs)):
+                if m.emb_projs[i] is not None:
+                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
+    elif classname.find('Embedding') != -1:
+        if hasattr(m, 'weight'):
+            init_weight(m.weight, args)
+    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
+        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
+            init_weight(m.cluster_weight, args)
+        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
+            init_bias(m.cluster_bias)
+        if hasattr(m, 'out_projs'):
+            for i in range(len(m.out_projs)):
+                if m.out_projs[i] is not None:
+                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
+    elif classname.find('LayerNorm') != -1:
+        if hasattr(m, 'weight'):
+            nn.init.normal_(m.weight, 1.0, args.init_std)
+        if hasattr(m, 'bias') and m.bias is not None:
+            init_bias(m.bias)
+    elif classname.find('TransformerLM') != -1:
+        if hasattr(m, 'r_emb'):
+            init_weight(m.r_emb, args)
+        if hasattr(m, 'r_w_bias'):
+            init_weight(m.r_w_bias, args)
+        if hasattr(m, 'r_r_bias'):
+            init_weight(m.r_r_bias, args)
+        if hasattr(m, 'r_bias'):
+            init_bias(m.r_bias)
+
+
+def update_dropout(m, args):
+    classname = m.__class__.__name__
+    if classname.find('Dropout') != -1:
+        if hasattr(m, 'p'):
+            m.p = args.dropout
+
+
+def update_dropatt(m, args):
+    if hasattr(m, 'dropatt'):
+        m.dropatt.p = args.dropatt
+
+
+def evaluate(eval_iter, model, args):
+    # Turn on evaluation mode which disables dropout.
+    model.eval()
+
+    # If the model does not use memory at all, make the ext_len longer.
+    # Otherwise, make the mem_len longer and keep the ext_len the same.
+    if args.mem_len == 0:
+        model.reset_length(tgt_len=args.eval_tgt_len,
+                           ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
+                           mem_len=args.mem_len
+                           )
+    else:
+        model.reset_length(tgt_len=args.eval_tgt_len,
+                           ext_len=args.ext_len,
+                           mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
+                           )
+
+    # Evaluation
+    total_len, total_loss = 0, 0.
+    with torch.no_grad():
+        mems = None
+        for i, (data, target, seq_len) in enumerate(eval_iter):
+            if args.eval_max_steps > 0 and i >= args.eval_max_steps:
+                break
+            ret = model(data, target, mems)
+            loss, mems = ret[0], ret[1:]
+            loss = loss.mean()
+            total_loss += seq_len * loss.float().item()
+            total_len += seq_len
+
+    # Switch back to the training mode
+    model.reset_length(tgt_len=args.tgt_len,
+                       ext_len=args.ext_len,
+                       mem_len=args.mem_len
+                       )
+    model.train()
+
+    return total_loss / total_len
+
+
+def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
+          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch, train_step,
+          best_val_loss, meters, args):
+    # Turn on training mode which enables dropout.
+    model.train()
+
+    train_loss = 0
+    target_tokens = 0
+    log_step = 0
+    log_start_time = time.time()
+
+    mems = [None for _ in range(args.batch_chunk)]
+    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
+
+    for batch, (data, target, seq_len) in enumerate(train_iter):
+        log_step += 1
+        target_tokens += target.numel()
+
+        model.zero_grad()
+
+        data_chunks = torch.chunk(data, args.batch_chunk, 1)
+        target_chunks = torch.chunk(target, args.batch_chunk, 1)
+
+        for i in range(args.batch_chunk):
+            data_i = data_chunks[i].contiguous()
+            target_i = target_chunks[i].contiguous()
+            ret = para_model(data_i, target_i, mems[i])
+            loss, mems[i] = ret[0], ret[1:]
+            loss = loss.float().mean().type_as(loss) / args.batch_chunk
+
+            if args.fp16:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
+
+            train_loss += loss.float().item()
+
+        if args.fp16:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
+        else:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
+
+        optimizer.step()
+        if optimizer_sparse:
+            optimizer_sparse.step()
+
+        # step-wise learning rate annealing
+        train_step += 1
+        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
+            # linear warmup stage
+            if train_step < args.warmup_step:
+                curr_lr = args.lr * train_step / args.warmup_step
+                optimizer.param_groups[0]['lr'] = curr_lr
+                if optimizer_sparse:
+                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
+            else:
+                if args.scheduler == 'cosine':
+                    scheduler.step(train_step)
+                    if scheduler_sparse:
+                        scheduler_sparse.step(train_step)
+        elif args.scheduler == 'inv_sqrt':
+            scheduler.step(train_step)
+
+        if train_step % args.log_interval == 0:
+            cur_loss = train_loss / log_step
+            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
+            train_loss = 0
+
+            elapsed = time.time() - log_start_time
+            avg_elapsed = elapsed / log_step
+            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
+            log_start_time = time.time()
+            log_step = 0
+
+            lr = optimizer.param_groups[0]['lr']
+            throughput = target_tokens / elapsed
+            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
+            meters['train_throughput'].update(throughput)
+            target_tokens = 0
+
+            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
+                '| ms/batch {:5.1f} | tok/s {:>7d} | loss {:5.2f}'.format(
+                    epoch,
+                    train_step,
+                    batch+1,
+                    tr_iter.n_batch,
+                    lr,
+                    avg_elapsed * 1000,
+                    int(throughput),
+                    cur_loss,
+                    )
+
+            if args.dataset in ['enwik8', 'text8']:
+                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
+            else:
+                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
+
+            logging.info(log_str)
+
+        if train_step % args.eval_interval == 0:
+            eval_start_time = time.time()
+            val_loss = evaluate(va_iter, model, args)
+            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
+
+            logging.info('-' * 100)
+            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
+                      '| valid loss {:5.2f}'.format(
+                          train_step // args.eval_interval,
+                          train_step,
+                          (time.time() - eval_start_time),
+                          val_loss,
+                          )
+            if args.dataset in ['enwik8', 'text8']:
+                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
+            else:
+                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
+            logging.info(log_str)
+            logging.info('-' * 100)
+
+            # Save the model if the validation loss is the best we've seen so far.
+            if not best_val_loss or val_loss < best_val_loss:
+                best_val_loss = val_loss
+                if not args.debug:
+                    name = 'checkpoint_best.pt'
+                    save_checkpoint(args, model, model_config, optimizer,
+                                    scheduler, vocab, train_step,
+                                    best_val_loss, args.work_dir, name)
+
+            # Always save after eval if save_all is true and not debug
+            if not args.debug and args.save_all:
+                name = f'checkpoint_{train_step}.pt'
+                save_checkpoint(args, model, model_config, optimizer,
+                                scheduler, vocab, train_step, best_val_loss,
+                                args.work_dir, name)
+
+            # Save last checkpoint if not debug and not save_all
+            if not args.debug and not args.save_all:
+                name = 'checkpoint_last.pt'
+                save_checkpoint(args, model, model_config, optimizer,
+                                scheduler, vocab, train_step, best_val_loss,
+                                args.work_dir, name)
+
+            # dev-performance based learning rate annealing
+            if args.scheduler == 'dev_perf':
+                scheduler.step(val_loss)
+                if scheduler_sparse:
+                    scheduler_sparse.step(val_loss)
+
+            # subtract eval time from timers for training
+            log_start_time += time.time() - eval_start_time
+
+        if train_step == args.max_step:
+            break
+    return train_step, best_val_loss
+
+
+def main():
+    args = parse_args()
+
+    # Initialize device and distributed backend
+    torch.cuda.set_device(args.local_rank)
+    device = torch.device('cuda' if args.cuda else 'cpu')
+    utils.distributed.init_distributed(args.cuda)
+
+    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
+                                                        args.dataset,
+                                                        args.append_dataset,
+                                                        args.append_time,
+                                                        )
+
+    with utils.distributed.sync_workers() as rank:
+        if rank == 0:
+            create_exp_dir(args.work_dir,
+                           scripts_to_save=['train.py', 'mem_transformer.py'],
+                           debug=args.debug)
+
+    # Setup logging
+    if args.log_all_ranks:
+        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
+    else:
+        log_file = f'log.log'
+    log_file = os.path.join(args.work_dir, log_file)
+
+    if args.debug:
+        log_file = os.devnull
+
+    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
+                                  filename=log_file,
+                                  )
+    logging.info(args)
+
+    # Set the random seed manually for reproducibility.
+    np.random.seed(args.seed + utils.distributed.get_rank())
+    torch.manual_seed(args.seed + utils.distributed.get_rank())
+
+    ###########################################################################
+    # Load data
+    ###########################################################################
+    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
+    ntokens = len(corpus.vocab)
+    vocab = corpus.vocab
+    args.n_token = ntokens
+
+    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
+                                  device=device, ext_len=args.ext_len)
+    va_iter = corpus.get_iterator('valid', args.eval_batch_size, args.eval_tgt_len,
+                                  device=device, ext_len=args.ext_len)
+    te_iter = corpus.get_iterator('test', args.eval_batch_size, args.eval_tgt_len,
+                                  device=device, ext_len=args.ext_len)
+
+    # adaptive softmax / embedding
+    cutoffs, tie_projs = [], [False]
+    if args.adaptive:
+        assert args.dataset in ['wt103', 'lm1b']
+        if args.dataset == 'wt103':
+            cutoffs = [19997, 39997, 199997]
+            tie_projs += [True] * len(cutoffs)
+        elif args.dataset == 'lm1b':
+            cutoffs = [59997, 99997, 639997]
+            tie_projs += [False] * len(cutoffs)
+
+    ###########################################################################
+    # Build the model
+    ###########################################################################
+    model_config = {
+        'n_token': ntokens,
+        'n_layer': args.n_layer,
+        'n_head': args.n_head,
+        'd_model': args.d_model,
+        'd_head': args.d_head,
+        'd_inner': args.d_inner,
+        'dropout': args.dropout,
+        'dropatt': args.dropatt,
+        'dtype': None,
+        'tie_weight': args.tied,
+        'd_embed': args.d_embed,
+        'div_val': args.div_val,
+        'tie_projs': tie_projs,
+        'pre_lnorm': args.pre_lnorm,
+        'tgt_len': args.tgt_len,
+        'ext_len': args.ext_len,
+        'mem_len': args.mem_len,
+        'cutoffs': cutoffs,
+        'same_length': args.same_length,
+        'attn_type': args.attn_type,
+        'clamp_len': args.clamp_len,
+        'sample_softmax': args.sample_softmax,
+        }
+
+    model = MemTransformerLM(**model_config)
+
+    model.apply(functools.partial(weights_init, args=args))
+    # ensure embedding init is not overridden by out_layer in case of weight sharing
+    model.word_emb.apply(functools.partial(weights_init, args=args))
+
+    args.n_all_param = sum([p.nelement() for p in model.parameters()])
+    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
+
+    # optimizer
+    if args.optim.lower() == 'sgd':
+        if args.sample_softmax > 0:
+            dense_params, sparse_params = [], []
+            for param in model.parameters():
+                if param.size() == model.word_emb.weight.size():
+                    sparse_params.append(param)
+                else:
+                    dense_params.append(param)
+            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
+            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
+        else:
+            optimizer = optim.SGD(model.parameters(), lr=args.lr,
+                                  momentum=args.mom)
+            optimizer_sparse = None
+    elif args.optim.lower() == 'adam':
+        if args.sample_softmax > 0:
+            dense_params, sparse_params = [], []
+            for param in model.parameters():
+                if param.size() == model.word_emb.weight.size():
+                    sparse_params.append(param)
+                else:
+                    dense_params.append(param)
+            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
+            optimizer = optim.Adam(dense_params, lr=args.lr,
+                                   weight_decay=args.weight_decay)
+        else:
+            optimizer = optim.Adam(model.parameters(), lr=args.lr,
+                                   weight_decay=args.weight_decay)
+            optimizer_sparse = None
+    elif args.optim.lower() == 'adagrad':
+        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
+        optimizer_sparse = None
+    elif args.optim.lower() == 'lamb':
+        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
+                              weight_decay=args.weight_decay)
+        optimizer_sparse = None
+    elif args.optim.lower() == 'jitlamb':
+        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
+                                 weight_decay=args.weight_decay)
+        optimizer_sparse = None
+
+    model = model.to(device)
+
+    if args.fp16:
+        model, optimizer = amp.initialize(
+            model,
+            optimizer,
+            opt_level='O2',
+            )
+
+    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
+        para_model = DistributedDataParallel(model,
+                                             delay_allreduce=True,
+                                             )
+    elif args.multi_gpu == 'dp':
+        if args.gpu0_bsz >= 0:
+            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
+                                              model, dim=1).to(device)
+        else:
+            para_model = nn.DataParallel(model, dim=1).to(device)
+    else:
+        para_model = model
+
+    # scheduler
+    if args.scheduler == 'cosine':
+        if args.max_step_scheduler:
+            max_step = args.max_step_scheduler
+        else:
+            max_step = args.max_step
+
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, max_step, eta_min=args.eta_min
+            )
+        if args.sample_softmax > 0:
+            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
+                optimizer_sparse, max_step, eta_min=args.eta_min
+                )
+        else:
+            scheduler_sparse = None
+    elif args.scheduler == 'inv_sqrt':
+        # originally used for Transformer (in Attention is all you need)
+        def lr_lambda(step):
+            # return a multiplier instead of a learning rate
+            if step == 0 and args.warmup_step == 0:
+                return 1.
+            else:
+                return 1. / (step ** 0.5) if step > args.warmup_step \
+                    else step / (args.warmup_step ** 1.5)
+        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
+    elif args.scheduler == 'dev_perf':
+        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.decay_rate, patience=args.patience,
+            min_lr=args.lr_min,
+            )
+        if args.sample_softmax > 0:
+            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
+                min_lr=args.lr_min,
+                )
+        else:
+            scheduler_sparse = None
+    elif args.scheduler == 'constant':
+        pass
+
+    logging.info('=' * 100)
+    for k, v in args.__dict__.items():
+        logging.info('    - {} : {}'.format(k, v))
+    logging.info('=' * 100)
+    logging.info('#params = {}'.format(args.n_all_param))
+    logging.info('#non emb params = {}'.format(args.n_nonemb_param))
+
+    train_step = 0
+    best_val_loss = None
+
+    if args.restart:
+        checkpoint = load_checkpoint(args.restart)
+        model.load_state_dict(checkpoint['model_state'])
+        optimizer.load_state_dict(checkpoint['optimizer_state'])
+        scheduler.load_state_dict(checkpoint['scheduler_state'])
+        if args.fp16:
+            amp.load_state_dict(checkpoint['amp_state'])
+        train_step = checkpoint['train_step']
+        best_val_loss = checkpoint['best_val_loss']
+
+        model.apply(functools.partial(update_dropout, args=args))
+        model.apply(functools.partial(update_dropatt, args=args))
+
+    meters = {}
+    warmup = args.mem_len // args.tgt_len + 1
+    meters['train_throughput'] = AverageMeter(warmup=warmup)
+    ###########################################################################
+    # Train
+    ###########################################################################
+    # Loop over epochs.
+    # At any point you can hit Ctrl + C to break out of training early.
+    start_time = time.time()
+    try:
+        for epoch in itertools.count(start=1):
+            if args.roll:
+                tr_iter.roll()
+            train_step, best_val_loss = train(
+                tr_iter, va_iter, model, para_model, model_config, optimizer,
+                optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
+                train_step, best_val_loss, meters, args
+                )
+
+            if train_step == args.max_step:
+                logging.info('-' * 100)
+                logging.info('End of training')
+                break
+    except KeyboardInterrupt:
+        logging.info('-' * 100)
+        logging.info('Exiting from training early')
+    elapsed = time.time() - start_time
+
+    ###########################################################################
+    # Test
+    ###########################################################################
+    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
+    if not args.debug and os.path.exists(test_path):
+        # Load the best saved model.
+        checkpoint = load_checkpoint(test_path)
+        model.load_state_dict(checkpoint['model_state'])
+
+        # Run on test data.
+        test_start_time = time.time()
+        test_loss = evaluate(te_iter, model, args)
+        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
+
+        logging.info('=' * 100)
+        if args.dataset in ['enwik8', 'text8']:
+            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
+                time.time() - test_start_time, test_loss, test_loss / math.log(2)))
+        else:
+            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
+                time.time() - test_start_time, test_loss, math.exp(test_loss)))
+        logging.info('=' * 100)
+
+    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
+    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')
+
+    if best_val_loss:
+        val_perplexity = math.exp(best_val_loss)
+    else:
+        val_perplexity = None
+
+    passed = benchmark(
+        target_perplexity=args.target_perplexity,
+        test_perplexity=val_perplexity,
+        target_throughput=args.target_throughput,
+        test_throughput=meters['train_throughput'].avg
+        )
+    if not passed:
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()

+ 16 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import distributed
+from . import exp_utils

+ 90 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/adaptive_softmax.py

@@ -0,0 +1,90 @@
+from collections import defaultdict
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class AdaptiveLogSoftmax(nn.Module):
+    def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
+        super(AdaptiveLogSoftmax, self).__init__()
+
+        cutoffs = list(cutoffs)
+
+        if (cutoffs != sorted(cutoffs)) \
+                or (min(cutoffs) <= 0) \
+                or (max(cutoffs) >= (n_classes - 1)) \
+                or (len(set(cutoffs)) != len(cutoffs)) \
+                or any([int(c) != c for c in cutoffs]):
+
+            raise ValueError("cutoffs should be a sequence of unique, positive "
+                             "integers sorted in an increasing order, where "
+                             "each value is between 1 and n_classes-1")
+
+        self.in_features = in_features
+        self.n_classes = n_classes
+        self.cutoffs = cutoffs + [n_classes]
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        self.head_size = self.shortlist_size + self.n_clusters
+
+        self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
+        self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
+
+        self.keep_order = keep_order
+
+
+    def forward(self, hidden, target, weight, bias, keep_order=False):
+        if hidden.size(0) != target.size(0):
+            raise RuntimeError('Input and target should have the same size '
+                               'in the batch dimension.')
+
+        head_weight = torch.cat(
+            [weight[:self.shortlist_size], self.cluster_weight], dim=0)
+        head_bias = torch.cat(
+            [bias[:self.shortlist_size], self.cluster_bias], dim=0)
+
+        head_logit = F.linear(hidden, head_weight, bias=head_bias)
+        head_logprob = F.log_softmax(head_logit, dim=1)
+
+        nll = torch.zeros_like(target,
+                dtype=hidden.dtype, device=hidden.device)
+
+        offset = 0
+        cutoff_values = [0] + self.cutoffs
+        for i in range(len(cutoff_values) - 1):
+            l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]
+
+            mask_i = (target >= l_idx) & (target < h_idx)
+            indices_i = mask_i.nonzero().squeeze()
+
+            if indices_i.numel() == 0:
+                continue
+
+            target_i = target.index_select(0, indices_i) - l_idx
+            head_logprob_i = head_logprob.index_select(0, indices_i)
+
+            if i == 0:
+                logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
+            else:
+                weight_i = weight[l_idx:h_idx]
+                bias_i = bias[l_idx:h_idx]
+
+                hidden_i = hidden.index_select(0, indices_i)
+
+                tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
+                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
+
+                logprob_i = head_logprob_i[:, -i] \
+                          + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)
+
+            if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
+                nll.index_copy_(0, indices_i, -logprob_i)
+            else:
+                nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
+
+            offset += logprob_i.size(0)
+
+        return nll

+ 91 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/data_parallel.py

@@ -0,0 +1,91 @@
+
+from torch.nn.parallel import DataParallel
+import torch
+from torch.nn.parallel._functions import Scatter
+from torch.nn.parallel.parallel_apply import parallel_apply
+
+def scatter(inputs, target_gpus, chunk_sizes, dim=0):
+    r"""
+    Slices tensors into approximately equal chunks and
+    distributes them across given GPUs. Duplicates
+    references to objects that are not tensors.
+    """
+    def scatter_map(obj):
+        if isinstance(obj, torch.Tensor):
+            try:
+                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
+            except:
+                print('obj', obj.size())
+                print('dim', dim)
+                print('chunk_sizes', chunk_sizes)
+                quit()
+        if isinstance(obj, tuple) and len(obj) > 0:
+            return list(zip(*map(scatter_map, obj)))
+        if isinstance(obj, list) and len(obj) > 0:
+            return list(map(list, zip(*map(scatter_map, obj))))
+        if isinstance(obj, dict) and len(obj) > 0:
+            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+        return [obj for targets in target_gpus]
+
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
+
+def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
+    r"""Scatter with support for kwargs dictionary"""
+    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
+    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
+    if len(inputs) < len(kwargs):
+        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+    elif len(kwargs) < len(inputs):
+        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+    inputs = tuple(inputs)
+    kwargs = tuple(kwargs)
+    return inputs, kwargs
+
+class BalancedDataParallel(DataParallel):
+    def __init__(self, gpu0_bsz, *args, **kwargs):
+        self.gpu0_bsz = gpu0_bsz
+        super().__init__(*args, **kwargs)
+
+    def forward(self, *inputs, **kwargs):
+        if not self.device_ids:
+            return self.module(*inputs, **kwargs)
+        if self.gpu0_bsz == 0:
+            device_ids = self.device_ids[1:]
+        else:
+            device_ids = self.device_ids
+        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
+        if len(self.device_ids) == 1:
+            return self.module(*inputs[0], **kwargs[0])
+        replicas = self.replicate(self.module, self.device_ids)
+        if self.gpu0_bsz == 0:
+            replicas = replicas[1:]
+        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
+        return self.gather(outputs, self.output_device)
+
+    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
+        return parallel_apply(replicas, inputs, kwargs, device_ids)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        bsz = inputs[0].size(self.dim)
+        num_dev = len(self.device_ids)
+        gpu0_bsz = self.gpu0_bsz
+        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
+        if gpu0_bsz < bsz_unit:
+            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
+            delta = bsz - sum(chunk_sizes)
+            for i in range(delta):
+                chunk_sizes[i + 1] += 1
+            if gpu0_bsz == 0:
+                chunk_sizes = chunk_sizes[1:]
+        else:
+            return super().scatter(inputs, kwargs, device_ids)
+        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
+

+ 110 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/distributed.py

@@ -0,0 +1,110 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from contextlib import contextmanager
+
+import torch
+
+
+def init_distributed(cuda):
+    """
+    Initializes distributed backend.
+
+    :param cuda: (bool) if True initializes nccl backend, if False initializes
+        gloo backend
+    """
+    world_size = int(os.environ.get('WORLD_SIZE', 1))
+    distributed = (world_size > 1)
+    if distributed:
+        backend = 'nccl' if cuda else 'gloo'
+        torch.distributed.init_process_group(backend=backend,
+                                             init_method='env://')
+        assert torch.distributed.is_initialized()
+    return distributed
+
+
+def barrier():
+    """
+    Call torch.distributed.barrier() if distritubed is in use
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+
+def get_rank():
+    """
+    Gets distributed rank or returns zero if distributed is not initialized.
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        rank = torch.distributed.get_rank()
+    else:
+        rank = 0
+    return rank
+
+
+def get_world_size():
+    """
+    Gets total number of distributed workers or returns one if distributed is
+    not initialized.
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = 1
+    return world_size
+
+
+def all_reduce_item(value, op='sum'):
+    """
+    All-reduces single scalar value if distributed is in use
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if op == 'sum' or op == 'mean':
+            dop = torch.distributed.ReduceOp.SUM
+        elif op == 'min':
+            dop = torch.distributed.ReduceOp.MIN
+        elif op == 'max':
+            dop = torch.distributed.ReduceOp.MAX
+        elif op == 'product':
+            dop = torch.distributed.ReduceOp.PRODUCT
+        else:
+            raise RuntimeError('Unsupported reduce op')
+
+        backend = torch.distributed.get_backend()
+        if backend == torch.distributed.Backend.NCCL:
+            device = torch.device('cuda')
+        elif backend == torch.distributed.Backend.GLOO:
+            device = torch.device('cpu')
+        else:
+            raise RuntimeError('Unsupported distributed backend')
+
+        tensor = torch.tensor(value, device=device)
+        torch.distributed.all_reduce(tensor, dop)
+        if op == 'mean':
+            tensor /= get_world_size()
+        ret = tensor.item()
+    else:
+        ret = value
+    return ret
+
+
+@contextmanager
+def sync_workers():
+    """
+    Yields distributed rank and synchronizes all workers on exit.
+    """
+    rank = get_rank()
+    yield rank
+    barrier()

+ 147 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/exp_utils.py

@@ -0,0 +1,147 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import logging
+import os
+import shutil
+import sys
+import time
+
+import utils
+
+
+class AverageMeter:
+    """
+    Computes and stores the average and current value
+    """
+    def __init__(self, warmup=0, keep=False):
+        self.reset()
+        self.warmup = warmup
+        self.keep = keep
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+        self.iters = 0
+        self.vals = []
+
+    def update(self, val, n=1):
+        self.iters += 1
+        self.val = val
+
+        if self.iters > self.warmup:
+            self.sum += val * n
+            self.count += n
+            self.avg = self.sum / self.count
+            if self.keep:
+                self.vals.append(val)
+
+
+def benchmark(test_perplexity=None, target_perplexity=None,
+              test_throughput=None, target_throughput=None):
+    def test(achieved, target, name, higher_better=True):
+        passed = True
+        if target is not None and achieved is not None:
+            logging.info(f'{name} achieved: {achieved:.2f} '
+                         f'target: {target:.2f}')
+            if higher_better:
+                result = (achieved >= target)
+            else:
+                result = (achieved <= target)
+
+            if result:
+                logging.info(f'{name} test passed')
+            else:
+                logging.info(f'{name} test failed')
+                passed = False
+        return passed
+
+    passed = True
+    passed &= test(test_perplexity, target_perplexity, 'Perplexity', False)
+    passed &= test(test_throughput, target_throughput, 'Throughput')
+    return passed
+
+
+def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
+    """
+    Configures logging.
+    By default logs from all workers are printed to the console, entries are
+    prefixed with "N: " where N is the rank of the worker. Logs printed to the
+    console don't include timestaps.
+    Full logs with timestamps are saved to the log_file file.
+    """
+    class RankFilter(logging.Filter):
+        def __init__(self, rank, log_all_ranks):
+            self.rank = rank
+            self.log_all_ranks = log_all_ranks
+
+        def filter(self, record):
+            record.rank = self.rank
+            if self.log_all_ranks:
+                return True
+            else:
+                return (self.rank == 0)
+
+    rank = utils.distributed.get_rank()
+    rank_filter = RankFilter(rank, log_all_ranks)
+
+    if log_all_ranks:
+        logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s"
+    else:
+        logging_format = "%(asctime)s - %(levelname)s - %(message)s"
+    logging.basicConfig(level=logging.DEBUG,
+                        format=logging_format,
+                        datefmt="%Y-%m-%d %H:%M:%S",
+                        filename=filename,
+                        filemode=filemode)
+    console = logging.StreamHandler(sys.stdout)
+    console.setLevel(logging.INFO)
+    if log_all_ranks:
+        formatter = logging.Formatter('%(rank)s: %(message)s')
+    else:
+        formatter = logging.Formatter('%(message)s')
+    console.setFormatter(formatter)
+    logging.getLogger('').addHandler(console)
+    logging.getLogger('').addFilter(rank_filter)
+
+
+def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
+    if debug:
+        return
+
+    os.makedirs(dir_path, exist_ok=True)
+
+    print('Experiment dir : {}'.format(dir_path))
+    if scripts_to_save is not None:
+        script_path = os.path.join(dir_path, 'scripts')
+        os.makedirs(script_path, exist_ok=True)
+        for script in scripts_to_save:
+            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
+            shutil.copyfile(script, dst_file)
+
+
+def build_work_dir_name(work_dir, dataset, append_dataset, append_time):
+    if append_dataset:
+        work_dir = '{}-{}'.format(work_dir, dataset)
+
+    if append_time:
+        now = int(time.time())
+        now_max = utils.distributed.all_reduce_item(now, op='max')
+        now_str = datetime.datetime.fromtimestamp(now_max).strftime('%Y%m%d-%H%M%S')
+
+        work_dir = os.path.join(work_dir, now_str)
+    return work_dir

+ 147 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/log_uniform_sampler.py

@@ -0,0 +1,147 @@
+import torch
+from torch import nn
+import numpy as np
+
+class LogUniformSampler(object):
+    def __init__(self, range_max, n_sample):
+        """
+        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
+            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
+
+        expected count can be approximated by 1 - (1 - p)^n
+        and we use a numerically stable version -expm1(num_tries * log1p(-p))
+
+        Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
+        """
+        with torch.no_grad():
+            self.range_max = range_max
+            log_indices = torch.arange(1., range_max+2., 1.).log_()
+            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
+            # print('P', self.dist.numpy().tolist()[-30:])
+
+            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()
+
+        self.n_sample = n_sample
+
+    def sample(self, labels):
+        """
+            labels: [b1, b2]
+        Return
+            true_log_probs: [b1, b2]
+            samp_log_probs: [n_sample]
+            neg_samples: [n_sample]
+        """
+
+        # neg_samples = torch.empty(0).long()
+        n_sample = self.n_sample
+        n_tries = 2 * n_sample
+
+        with torch.no_grad():
+            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
+            device = labels.device
+            neg_samples = neg_samples.to(device)
+            true_log_probs = self.log_q[labels].to(device)
+            samp_log_probs = self.log_q[neg_samples].to(device)
+            return true_log_probs, samp_log_probs, neg_samples
+
+def sample_logits(embedding, bias, labels, inputs, sampler):
+    """
+        embedding: an nn.Embedding layer
+        bias: [n_vocab]
+        labels: [b1, b2]
+        inputs: [b1, b2, n_emb]
+        sampler: you may use a LogUniformSampler
+    Return
+        logits: [b1, b2, 1 + n_sample]
+    """
+    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
+    n_sample = neg_samples.size(0)
+    b1, b2 = labels.size(0), labels.size(1)
+    all_ids = torch.cat([labels.view(-1), neg_samples])
+    all_w = embedding(all_ids)
+    true_w = all_w[: -n_sample].view(b1, b2, -1)
+    sample_w = all_w[- n_sample:].view(n_sample, -1)
+
+    all_b = bias[all_ids]
+    true_b = all_b[: -n_sample].view(b1, b2)
+    sample_b = all_b[- n_sample:]
+
+    hit = (labels[:, :, None] == neg_samples).detach()
+
+    true_logits = torch.einsum('ijk,ijk->ij',
+        [true_w, inputs]) + true_b - true_log_probs
+    sample_logits = torch.einsum('lk,ijk->ijl',
+        [sample_w, inputs]) + sample_b - samp_log_probs
+    sample_logits.masked_fill_(hit, -1e30)
+    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)
+
+    return logits
+
+
+# class LogUniformSampler(object):
+#     def __init__(self, range_max, unique=False):
+#         """
+#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
+#             `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
+#         """
+#         self.range_max = range_max
+#         log_indices = torch.arange(1., range_max+2., 1.).log_()
+#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
+
+#         self.unique = unique
+
+#         if self.unique:
+#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
+
+#     def sample(self, n_sample, labels):
+#         pos_sample, new_labels = labels.unique(return_inverse=True)
+#         n_pos_sample = pos_sample.size(0)
+#         n_neg_sample = n_sample - n_pos_sample
+
+#         if self.unique:
+#             self.exclude_mask.index_fill_(0, pos_sample, 1)
+#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
+#             self.exclude_mask.index_fill_(0, pos_sample, 0)
+#         else:
+#             sample_dist = self.dist
+
+#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)
+
+#         sample = torch.cat([pos_sample, neg_sample])
+#         sample_prob = self.dist[sample]
+
+#         return new_labels, sample, sample_prob
+
+
+if __name__ == '__main__':
+    S, B = 3, 4
+    n_vocab = 10000
+    n_sample = 5
+    H = 32
+
+    labels = torch.LongTensor(S, B).random_(0, n_vocab)
+
+    # sampler = LogUniformSampler(n_vocab, unique=False)
+    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
+
+    sampler = LogUniformSampler(n_vocab, unique=True)
+    # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
+
+    # print('true_probs', true_probs.numpy().tolist())
+    # print('samp_probs', samp_probs.numpy().tolist())
+    # print('neg_samples', neg_samples.numpy().tolist())
+
+    # print('sum', torch.sum(sampler.dist).item())
+
+    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
+
+    embedding = nn.Embedding(n_vocab, H)
+    bias = torch.zeros(n_vocab)
+    inputs = torch.Tensor(S, B, H).normal_()
+
+    logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample)
+    print('logits', logits.detach().numpy().tolist())
+    print('logits shape', logits.size())
+    print('out_labels', out_labels.detach().numpy().tolist())
+    print('out_labels shape', out_labels.size())
+

+ 154 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/proj_adaptive_softmax.py

@@ -0,0 +1,154 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ProjectedAdaptiveLogSoftmax(nn.Module):
+    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
+                 keep_order=False):
+        super().__init__()
+
+        self.n_token = n_token
+        self.d_embed = d_embed
+        self.d_proj = d_proj
+
+        self.cutoffs = cutoffs + [n_token]
+        self.cutoff_ends = [0] + self.cutoffs
+        self.div_val = div_val
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        self.head_size = self.shortlist_size + self.n_clusters
+
+        if self.n_clusters > 0:
+            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
+            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))
+
+        self.out_layers = nn.ModuleList()
+
+        if div_val == 1:
+            if d_proj != d_embed:
+                self.out_projs = nn.ParameterList()
+                for i in range(len(self.cutoffs)):
+                    self.out_projs.append(
+                        nn.Parameter(torch.Tensor(d_proj, d_embed))
+                        )
+            else:
+                self.out_projs = [None] * len(self.cutoffs)
+
+            self.out_layers.append(nn.Linear(d_embed, n_token))
+        else:
+            self.out_projs = nn.ParameterList()
+            for i in range(len(self.cutoffs)):
+                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
+                d_emb_i = d_embed // (div_val ** i)
+
+                self.out_projs.append(
+                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
+                )
+
+                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))
+
+        self.keep_order = keep_order
+
+    def _compute_logit(self, hidden, weight, bias, proj):
+        if proj is None:
+            logit = F.linear(hidden, weight, bias=bias)
+        else:
+            logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
+            if bias is not None:
+                logit = logit + bias
+        return logit
+
+    def forward(self, hidden, target, keep_order=False):
+        '''
+            hidden :: [len*bsz x d_proj]
+            target :: [len*bsz]
+        '''
+
+        if hidden.size(0) != target.size(0):
+            raise RuntimeError('Input and target should have the same size '
+                               'in the batch dimension.')
+
+        if self.n_clusters == 0:
+            logit = self._compute_logit(hidden, self.out_layers[0].weight,
+                                        self.out_layers[0].bias, self.out_projs[0])
+            nll = -F.log_softmax(logit, dim=-1) \
+                    .gather(1, target.unsqueeze(1)).squeeze(1)
+        else:
+            # construct weights and biases
+            weights, biases = [], []
+            for i in range(len(self.cutoffs)):
+                if self.div_val == 1:
+                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
+                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
+                else:
+                    weight_i = self.out_layers[i].weight
+                    bias_i = self.out_layers[i].bias
+
+                if i == 0:
+                    weight_i = torch.cat(
+                        [weight_i, self.cluster_weight], dim=0)
+                    bias_i = torch.cat(
+                        [bias_i, self.cluster_bias], dim=0)
+
+                weights.append(weight_i)
+                biases.append(bias_i)
+
+            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
+
+            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
+            head_logprob = F.log_softmax(head_logit, dim=1)
+
+            nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)
+
+            offset = 0
+            cutoff_values = [0] + self.cutoffs
+            for i in range(len(cutoff_values) - 1):
+                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]
+
+                mask_i = (target >= l_idx) & (target < r_idx)
+                indices_i = mask_i.nonzero().squeeze()
+
+                if indices_i.numel() == 0:
+                    continue
+
+                target_i = target.index_select(0, indices_i) - l_idx
+                head_logprob_i = head_logprob.index_select(0, indices_i)
+
+                if i == 0:
+                    logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+                else:
+                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]
+
+                    hidden_i = hidden.index_select(0, indices_i)
+
+                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
+                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
+
+                    logprob_i = head_logprob_i[:, -i] \
+                        + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)
+
+                if self.keep_order or keep_order:
+                    nll.index_copy_(0, indices_i, -logprob_i)
+                else:
+                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)
+
+                offset += logprob_i.size(0)
+
+        return nll

+ 229 - 0
PyTorch/LanguageModeling/TransformerXL/pytorch/utils/vocabulary.py

@@ -0,0 +1,229 @@
+# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import os
+from collections import Counter, OrderedDict
+import utils
+
+import torch
+
+class Vocab(object):
+    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
+                 delimiter=None, vocab_file=None):
+        self.counter = Counter()
+        self.special = special
+        self.min_freq = min_freq
+        self.max_size = max_size
+        self.lower_case = lower_case
+        self.delimiter = delimiter
+        self.vocab_file = vocab_file
+
+    def tokenize(self, line, add_eos=False, add_double_eos=False):
+        line = line.strip()
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()
+
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)
+
+        if add_double_eos: # lm1b
+            return ['<S>'] + symbols + ['<S>']
+        elif add_eos:
+            return symbols + ['<eos>']
+        else:
+            return symbols
+
+    def count_file(self, path, verbose=False, add_eos=False):
+        if verbose: print('counting file {} ...'.format(path))
+        assert os.path.exists(path)
+
+        sents = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for idx, line in enumerate(f):
+                if verbose and idx > 0 and idx % 500000 == 0:
+                    print('    line {}'.format(idx))
+                symbols = self.tokenize(line, add_eos=add_eos)
+                self.counter.update(symbols)
+                sents.append(symbols)
+
+        return sents
+
+    def count_sents(self, sents, verbose=False):
+        """
+            sents : a list of sentences, each a list of tokenized symbols
+        """
+        if verbose: print('counting {} sents ...'.format(len(sents)))
+        for idx, symbols in enumerate(sents):
+            if verbose and idx > 0 and idx % 500000 == 0:
+                print('    line {}'.format(idx))
+            self.counter.update(symbols)
+
+    def _build_from_file(self, vocab_file):
+        self.idx2sym = []
+        self.sym2idx = OrderedDict()
+
+        with open(vocab_file, 'r', encoding='utf-8') as f:
+            for line in f:
+                symb = line.strip().split()[0]
+                self.add_symbol(symb)
+        self.unk_idx = self.sym2idx['<UNK>']
+
+    def build_vocab(self):
+        if self.vocab_file:
+            print('building vocab from {}'.format(self.vocab_file))
+            self._build_from_file(self.vocab_file)
+            print('final vocab size {}'.format(len(self)))
+        else:
+            print('building vocab with min_freq={}, max_size={}'.format(
+                self.min_freq, self.max_size))
+            self.idx2sym = []
+            self.sym2idx = OrderedDict()
+
+            for sym in self.special:
+                self.add_special(sym)
+
+            for sym, cnt in self.counter.most_common(self.max_size):
+                if cnt < self.min_freq: break
+                self.add_symbol(sym)
+
+            print('final vocab size {} from {} unique tokens'.format(
+                len(self), len(self.counter)))
+
+    def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
+            add_double_eos=False):
+        if verbose: print('encoding file {} ...'.format(path))
+        assert os.path.exists(path)
+        encoded = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for idx, line in enumerate(f):
+                if verbose and idx > 0 and idx % 500000 == 0:
+                    print('    line {}'.format(idx))
+                symbols = self.tokenize(line, add_eos=add_eos,
+                    add_double_eos=add_double_eos)
+                encoded.append(self.convert_to_tensor(symbols))
+
+        if ordered:
+            encoded = torch.cat(encoded)
+
+        return encoded
+
+    def encode_sents(self, sents, ordered=False, verbose=False):
+        if verbose: print('encoding {} sents ...'.format(len(sents)))
+        encoded = []
+        for idx, symbols in enumerate(sents):
+            if verbose and idx > 0 and idx % 500000 == 0:
+                print('    line {}'.format(idx))
+            encoded.append(self.convert_to_tensor(symbols))
+
+        if ordered:
+            encoded = torch.cat(encoded)
+
+        return encoded
+
+    def add_special(self, sym):
+        if sym not in self.sym2idx:
+            self.idx2sym.append(sym)
+            self.sym2idx[sym] = len(self.idx2sym) - 1
+            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
+
+    def add_symbol(self, sym):
+        if sym not in self.sym2idx:
+            self.idx2sym.append(sym)
+            self.sym2idx[sym] = len(self.idx2sym) - 1
+
+    def get_sym(self, idx):
+        assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
+        return self.idx2sym[idx]
+
+    def get_idx(self, sym):
+        if sym in self.sym2idx:
+            return self.sym2idx[sym]
+        else:
+            # print('encounter unk {}'.format(sym))
+            assert '<eos>' not in sym
+            assert hasattr(self, 'unk_idx')
+            return self.sym2idx.get(sym, self.unk_idx)
+
+    def get_symbols(self, indices):
+        return [self.get_sym(idx) for idx in indices]
+
+    def get_indices(self, symbols):
+        return [self.get_idx(sym) for sym in symbols]
+
+    def convert_to_tensor(self, symbols):
+        return torch.LongTensor(self.get_indices(symbols))
+
+    def convert_to_sent(self, indices, exclude=None):
+        if exclude is None:
+            return ' '.join([self.get_sym(idx) for idx in indices])
+        else:
+            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
+
+    def __len__(self):
+        return len(self.idx2sym)
+
+
+# Class OpenAIVocab has been adapted from
+# https://github.com/cybertronai/transformer-xl/blob/master/utils/vocabulary.py
+class OpenAIVocab(Vocab):
+    def __init__(self, max_size=None, vocab_file=None):
+        from pytorch_transformers import GPT2Tokenizer
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        self.EOT = self.tokenizer.encoder['<|endoftext|>']
+        self.max_size = max_size
+        self.vocab_file = vocab_file
+
+        pad = 8
+        vocab_size = len(self.tokenizer)
+        padded_vocab_size = (vocab_size + pad - 1) // pad * pad
+        for i in range(0, padded_vocab_size - vocab_size):
+            token = f'madeupword{i:09d}'
+            self.tokenizer.add_tokens([token])
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def count_file(self, path, verbose=False, add_eos=False):
+        # TODO: train from scratch, respect self.max_size
+        pass
+
+    def build_vocab(self):
+        pass
+
+    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False) -> torch.LongTensor:
+        cached = path + '.bpe'
+        if os.path.exists(cached):
+            return torch.load(cached)
+        print(f'encoding file {path} ...')
+        assert os.path.exists(path), f"{path} doesn't exist"
+
+        with open(path, encoding='utf-8') as f:
+            # Suppress warnings about length.
+            with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
+                out = torch.LongTensor(self.tokenizer.encode(f.read()) + [self.EOT])
+                with utils.distributed.sync_workers() as rank:
+                    if rank == 0:
+                        torch.save(out, cached)
+                return out
+
+    def tokenize(self, line, add_eos=False, add_double_eos=False):
+        return self.tokenizer.encode(line)
+
+    def convert_to_tensor(self, symbols):
+        return torch.LongTensor(symbols)

+ 2 - 0
README.md

@@ -26,6 +26,8 @@ The examples are organized first by framework, such as TensorFlow, PyTorch, etc.
 - __GNMT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/GNMT)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Translation/GNMT)]
 - __Transformer__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Translation/Transformer)]
 - __BERT__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)][[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT)]
+- __TransformerXL__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/TransformerXL)]
+
 
 ### Recommender Systems
 - __NCF__ [[PyTorch](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Recommendation/NCF)] [[TensorFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Recommendation/NCF)]