
[BERT/Paddle] Update base image and integrate cuDNN fused MHA

Shijie Wang 2 years ago
parent
commit
a5388a45f7

+ 1 - 1
PaddlePaddle/LanguageModeling/BERT/Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:22.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.06-py3
 FROM ${FROM_IMAGE_NAME}
 RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract


+ 39 - 25
PaddlePaddle/LanguageModeling/BERT/README.md

@@ -437,6 +437,7 @@ Advanced Training:
   --use-dynamic-loss-scaling
                         Enable dynamic loss scaling in AMP training, only applied when --amp is set. (default: False)
   --use-pure-fp16       Enable pure FP16 training, only applied when --amp is set. (default: False)
+  --fuse-mha            Enable multihead attention fusion. Requires cuDNN version >= 8.9.1.
 ```
 
 
 
 
@@ -463,6 +464,7 @@ Default arguments are listed below in the order `scripts/run_squad.sh` expects:
 -   Enable benchmark - The default is `false`.
 -   Benchmark steps - The default is `100`.
 -   Benchmark warmup steps - The default is `100`.
+-   Enable MHA fusion - The default is `true`.
 
 
 The script saves the final checkpoint to the `/results/bert-large-uncased/squad` folder.
 
 
@@ -593,7 +595,8 @@ bash run_pretraining.sh \
     <bert_config_file> \
     <enable_benchmark> \
     <benchmark_steps> \
-    <benchmark_warmup_steps>
+    <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
 
 
 Where:
@@ -627,6 +630,7 @@ Where:
 -   `masking` LDDL supports both static and dynamic masking. Refer to [LDDL's README](https://github.com/NVIDIA/LDDL/blob/main/README.md) for more information.
 -   `<bert_config_file>` is the path to the BERT config file.
 -   `<enable_benchmark>` a flag to enable benchmarking. The train process will warm up for `<benchmark_warmup_steps>` and then measure the throughput of the following `<benchmark_steps>`.
+-   `<fuse_mha>` a flag to enable cuDNN MHA fusion.
 
 
 Note that:
 - If users follow [Quick Start Guide](#quick-start-guide) to set up container and dataset, there is no need to set any parameters. For example:
@@ -670,6 +674,7 @@ python3 -m paddle.distributed.launch \
     --max-predictions-per-seq=20 \
     --gradient-merge-steps=32 \
     --amp \
+    --fuse-mha \
     --use-dynamic-loss-scaling \
     --optimizer=Lamb \
     --phase1 \
@@ -769,7 +774,8 @@ bash scripts/run_squad.sh \
     <max_steps> \
     <enable_benchmark> \
     <benchmark_steps> \
-    <benchmark_warmup_steps>
+    <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
  
  
 By default, the `mode` argument is set to `train eval`. Refer to the [Quick Start Guide](#quick-start-guide) for explanations of each positional argument.
@@ -812,7 +818,7 @@ bash scripts/run_pretraining.sh \
     None \
     /path/to/wikipedia/source \
     32 128 4 0.9 64 static \
-    None true 10 10
+    None true 10 10 true
 ```
 
 
 To benchmark the training performance on a specific batch size for SQuAD, refer to [Fine-tuning](#fine-tuning) and turn on the `<benchmark>` flags. An example call to run training for 200 steps (100 steps for warmup and 100 steps to measure), and generate throughput numbers:
@@ -825,7 +831,7 @@ bash scripts/run_squad.sh \
     results/checkpoints \
     train \
     bert_configs/bert-large-uncased.json \
-    -1 true 100 100
+    -1 true 100 100 true
 ```
  
  
 #### Inference performance benchmark
@@ -841,7 +847,8 @@ bash scripts/run_squad.sh \
     <results directory> \
     eval \
     <BERT config path> \
-    <max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps>
+    <max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
  
  
 An example call to run inference and generate throughput numbers:
@@ -854,7 +861,7 @@ bash scripts/run_squad.sh \
     results/checkpoints \
     eval \
     bert_configs/bert-large-uncased.json \
-    -1 true 100 100
+    -1 true 100 100 true
 ```
 
 
  
  
@@ -870,7 +877,7 @@ Our results were obtained by running the `scripts/run_squad.sh` and `scripts/run
 
 
 | DGX System         | GPUs / Node | Precision | Accumulated Batch size / GPU (Phase 1 and Phase 2) | Accumulation steps (Phase 1 and Phase 2) |     Final Loss    | Time to train(hours) | Time to train speedup (TF32 to mixed precision) |
 |--------------------|-------------|-----------|----------------------------------------------------|------------------------------------------|-------------------|----------------------|-------------------------------------------------|
-| 32 x DGX A100 80GB | 8           | AMP       | 256 and 128                                        | 1 and 4                                  |       1.409       |    ~ 1.2 hours       | 1.72                                            |
+| 32 x DGX A100 80GB | 8           | AMP       | 256 and 128                                        | 1 and 4                                  |       1.409       |    ~ 1.1 hours       | 2.27                                            |
 | 32 x DGX A100 80GB | 8           | TF32      | 128 and 16b                                        | 2 and 8                                  |       1.421       |    ~ 2.5 hours       | 1                                               |
 
 
 
 
@@ -914,28 +921,28 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd
 
 
 | GPUs | Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|------------------------------------|-----------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1    | 8192 and 8192                    | 64 and 32                          | 128             |  307                             |   633                                       | 2.06                                        | 1.00                | 1.00                           |
-| 8    | 8192 and 8192                    | 64 and 32                          | 128             | 2428                             |  4990                                       | 2.06                                        | 7.91                | 7.88                           |
-| 1    | 4096 and 4096                    | 256 and 128                        | 512             |  107                             |   219                                       | 2.05                                        | 1.00                | 1.00                           |
-| 8    | 4096 and 4096                    | 256 and 128                        | 512             |  851                             |  1724                                       | 2.26                                        | 7.95                | 7.87                           |
+| 1    | 8192 and 8192                    | 64 and 32                          | 128             |  307                             |   694                                       | 2.26                                        | 1.00                | 1.00                           |
+| 8    | 8192 and 8192                    | 64 and 32                          | 128             | 2428                             |  5541                                       | 2.28                                        | 7.91                | 7.98                           |
+| 1    | 4096 and 4096                    | 256 and 128                        | 512             |  107                             |   264                                       | 2.47                                        | 1.00                | 1.00                           |
+| 8    | 4096 and 4096                    | 256 and 128                        | 512             |  851                             |  2109                                       | 2.48                                        | 7.95                | 7.99                           |
 
 
 
 
 ###### Pre-training NVIDIA DGX A100 (8x A100 80GB) Multi-node Scaling
 
 
 | Nodes | GPUs / node | Batch size / GPU (TF32 and FP16) | Accumulated Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Mixed Precision Throughput | Mixed Precision Strong Scaling | TF32 Throughput | TF32 Strong Scaling | Speedup (Mixed Precision to TF32) |
 |-------|-------------|----------------------------------|------------------------------------|-----------------|----------------------------|--------------------------------|-----------------|---------------------|-----------------------------------|-----|
-| 1     | 8           | 126 and 256 | 8192 and 8192                    | 64 and 32             | 128             |   4990               | 1                              |   2428          |  1                  |  2.06               |
-| 2     | 8           | 126 and 256 | 4096 and 4096                    | 32 and 16             | 128             |   9581               | 1.92                           |   4638          |  1.91               |  2.07               |
-| 4     | 8           | 126 and 256 | 2048 and 2048                    | 16 and 8              | 128             |   19262              | 3.86                           |   9445          |  3.89               |  2.04               |
-| 8     | 8           | 126 and 256 | 1024 and 1024                    | 8 and 4               | 128             |   37526              | 7.52                           |   18335         |  7.55               |  2.05               |
-| 16    | 8           | 126 and 256 | 512 and 512                      | 4 and 2               | 128             |   71156              | 14.26                          |   35526         |  14.63              |  2.00               |
-| 32    | 8           | 126 and 256 | 256 and 256                      | 2 and 1               | 128             |   142087             | 28.47                          |   69701         |  28.71              |  2.04               |
-| 1     | 8           | 16  and 32  | 4096 and 4096                    | 256 and 128           | 512             |   1724               | 1                              |   851           |  1                  |  2.03               |
-| 2     | 8           | 16  and 32  | 2048 and 2048                    | 128 and 64            | 512             |   3305               | 1.92                           |   1601          |  1.88               |  2.06               |
-| 4     | 8           | 16  and 32  | 1024 and 1024                    | 64 and 32             | 512             |   6492               | 3.77                           |   3240          |  3.81               |  2.00               |
-| 8     | 8           | 16  and 32  | 512 and 512                      | 32 and 16             | 512             |   12884              | 7.47                           |   6329          |  7.44               |  2.04               |
-| 16    | 8           | 16  and 32  | 256 and 256                      | 16 and 8              | 512             |   25493              | 14.79                          |   12273         |  14.42              |  2.08               |
-| 32    | 8           | 16  and 32  | 128 and 128                      | 8 and 4               | 512             |   49307              | 28.60                          |   24047         |  28.26              |  2.05               |
+| 1     | 8           | 126 and 256 | 8192 and 8192                    | 64 and 32             | 128             |   5541               | 1                              |   2428          |  1                  |  2.28               |
+| 2     | 8           | 126 and 256 | 4096 and 4096                    | 32 and 16             | 128             |   10646              | 1.92                           |   4638          |  1.91               |  2.29               |
+| 4     | 8           | 126 and 256 | 2048 and 2048                    | 16 and 8              | 128             |   21389              | 3.86                           |   9445          |  3.89               |  2.26               |
+| 8     | 8           | 126 and 256 | 1024 and 1024                    | 8 and 4               | 128             |   41681              | 7.52                           |   18335         |  7.55               |  2.27               |
+| 16    | 8           | 126 and 256 | 512 and 512                      | 4 and 2               | 128             |   79023              | 14.26                          |   35526         |  14.63              |  2.22               |
+| 32    | 8           | 126 and 256 | 256 and 256                      | 2 and 1               | 128             |   157952             | 28.51                          |   69701         |  28.71              |  2.27               |
+| 1     | 8           | 16  and 32  | 4096 and 4096                    | 256 and 128           | 512             |   2109               | 1                              |   851           |  1                  |  2.48               |
+| 2     | 8           | 16  and 32  | 2048 and 2048                    | 128 and 64            | 512             |   4051               | 1.92                           |   1601          |  1.88               |  2.53               |
+| 4     | 8           | 16  and 32  | 1024 and 1024                    | 64 and 32             | 512             |   7972               | 3.78                           |   3240          |  3.81               |  2.46               |
+| 8     | 8           | 16  and 32  | 512 and 512                      | 32 and 16             | 512             |   15760              | 7.47                           |   6329          |  7.44               |  2.49               |
+| 16    | 8           | 16  and 32  | 256 and 256                      | 16 and 8              | 512             |   31129              | 14.76                          |   12273         |  14.42              |  2.54               |
+| 32    | 8           | 16  and 32  | 128 and 128                      | 8 and 4               | 512             |   60206              | 28.55                          |   24047         |  28.26              |  2.50               |
 
 
 
 
 ###### Fine-tuning NVIDIA DGX A100 (8x A100 80GB)
@@ -944,8 +951,8 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd
   
   
 | GPUs | Batch size / GPU (TF32 and FP16) | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1    | 32 and 32                        |          83                      |               120                           |               1.45                          | 1.00                | 1.00                           |
-| 8    | 32 and 32                        |         629                      |               876                           |               1.39                          | 7.59                | 7.30                           |
+| 1    | 32 and 32                        |          83                      |               123                           |               1.48                          | 1.00                | 1.00                           |
+| 8    | 32 and 32                        |         629                      |               929                           |               1.48                          | 7.59                | 7.55                           |
  
  
 #### Inference performance results
 
 
@@ -983,6 +990,13 @@ August 2022
 - SQuAD finetune support with AdamW optimizer.
 - Updated accuracy and performance tables tested on A100.
 - Initial release.
+
+March 2023
+- Pre-training using [Language Datasets and Data Loaders (LDDL)](https://github.com/NVIDIA/LDDL)
+- Binned pretraining for phase2 with LDDL using a bin size of 64
+
+July 2023
+- Optimize AMP training with cuDNN fused dot product attention kernel.
  
  
 ### Known issues
  
  

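The `--fuse-mha` switch documented in the README diff above behaves like an ordinary boolean flag that only takes effect alongside `--amp`. A minimal argparse sketch of that behavior (hypothetical; the real parser lives in `utils/config.py` and may differ in help text and grouping):

```python
import argparse

# Hypothetical sketch of the two switches; argparse maps "--fuse-mha"
# to the attribute name "fuse_mha".
parser = argparse.ArgumentParser()
parser.add_argument("--amp", action="store_true",
                    help="Enable automatic mixed precision training.")
parser.add_argument("--fuse-mha", action="store_true",
                    help="Enable multihead attention fusion (requires cuDNN >= 8.9.1).")

args = parser.parse_args(["--amp", "--fuse-mha"])
print(args.amp, args.fuse_mha)  # True True
```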
+ 2 - 1
PaddlePaddle/LanguageModeling/BERT/modeling.py

@@ -172,7 +172,8 @@ class BertModel(nn.Layer):
                 dropout=bert_config.hidden_dropout_prob,
                 activation=bert_config.hidden_act,
                 attn_dropout=bert_config.attention_probs_dropout_prob,
-                act_dropout=0)
+                act_dropout=0,
+                fuse_qkv=bert_config.fuse_mha)
             self.encoder = nn.TransformerEncoder(encoder_layer,
                                                  bert_config.num_hidden_layers)
 
 

+ 11 - 6
PaddlePaddle/LanguageModeling/BERT/program.py

@@ -44,12 +44,12 @@ def create_pretraining_data_holder():
     ]
 
 
 
 
-def create_strategy(use_amp, use_distributed_fused_lamb=False):
+def create_strategy(args, use_distributed_fused_lamb=False):
     """
     """
     Create paddle.static.BuildStrategy and paddle.static.ExecutionStrategy with arguments.
     Create paddle.static.BuildStrategy and paddle.static.ExecutionStrategy with arguments.
 
 
     Args:
     Args:
-        use_amp(bool): Whether to use amp.
+        args(Namespace): Arguments obtained from ArgumentParser.
         use_distributed_fused_lamb(bool, optional): Whether to use distributed fused lamb.
     Returns:
         build_strategy(paddle.static.BuildStrategy): An instance of BuildStrategy.
@@ -59,8 +59,9 @@ def create_strategy(use_amp, use_distributed_fused_lamb=False):
     exec_strategy = paddle.static.ExecutionStrategy()

     build_strategy.enable_addto = True
-    if use_amp:
+    if args.amp:
         build_strategy.fuse_gemm_epilogue = True
+        build_strategy.fuse_dot_product_attention = args.fuse_mha
 
 
     if use_distributed_fused_lamb:
         build_strategy.fuse_all_reduce_ops = False
@@ -86,7 +87,7 @@ def dist_optimizer(args, optimizer):
         optimizer(fleet.distributed_optimizer): A distributed optimizer.
     """
     use_distributed_fused_lamb = True if args.optimizer == 'DistributedFusedLamb' else False
-    build_strategy, exec_strategy = create_strategy(args.amp,
+    build_strategy, exec_strategy = create_strategy(args,
                                                     use_distributed_fused_lamb)
     dist_strategy = fleet.DistributedStrategy()
 
 
@@ -160,6 +161,7 @@ def build(args, main_prog, startup_prog, is_train=True):
             bert_config = BertConfig.from_json_file(args.config_file)
             if bert_config.vocab_size % 8 != 0:
                 bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
+            bert_config.fuse_mha = args.fuse_mha
             model = BertForPretraining(bert_config)
             criterion = BertPretrainingCriterion(bert_config.vocab_size)
             prediction_scores, seq_relationship_score = model(
@@ -224,6 +226,7 @@ def run(exe,
     logging.info(f"Training will start at the {last_step+1}th step")
 
 
     max_steps = args.max_steps
+    steps_this_run = max_steps
     if args.steps_this_run is not None:
         if args.steps_this_run + last_step > max_steps:
             logging.info(
@@ -231,12 +234,14 @@ def run(exe,
             )
         else:
             steps_this_run = args.steps_this_run
-            if args.benchmark:
-                steps_this_run = min(steps_this_run, args.benchmark_warmup_steps + args.benchmark_steps)
             max_steps = steps_this_run + last_step
             logging.warning(
                 f"{steps_this_run} steps will be performed in this run.")
 
 
+    if args.benchmark:
+        max_steps = args.benchmark_warmup_steps + args.benchmark_steps + last_step
+
+
     total_samples = 0
     raw_train_start = time.time()
     step_start = time.time()

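The step-accounting change in `program.py` above moves the benchmark clamp out of the `else` branch, so benchmark mode now overrides `max_steps` unconditionally. A standalone sketch of the resulting logic (hypothetical helper name; the real code mutates locals inline inside `run`):

```python
def resolve_max_steps(max_steps, last_step, steps_this_run=None,
                      benchmark=False, benchmark_steps=0,
                      benchmark_warmup_steps=0):
    # An explicit --steps-this-run shortens the run if it fits within
    # max_steps; benchmark mode then overrides max_steps unconditionally,
    # mirroring the moved "if args.benchmark" block in the diff.
    if steps_this_run is not None and steps_this_run + last_step <= max_steps:
        max_steps = steps_this_run + last_step
    if benchmark:
        max_steps = benchmark_warmup_steps + benchmark_steps + last_step
    return max_steps

# A benchmark run of 100 warmup + 100 measured steps stops at step 200.
print(resolve_max_steps(1000, 0, benchmark=True,
                        benchmark_steps=100, benchmark_warmup_steps=100))  # 200
```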
+ 2 - 1
PaddlePaddle/LanguageModeling/BERT/run_pretraining.py

@@ -81,7 +81,8 @@ def main():
         log_dir=None if args.output_dir is None else
         os.path.join(args.output_dir, 'lddl_log'),
         log_level=logging.WARNING,
-        start_epoch=0 if progress is None else progress.get("epoch", 0), )
+        start_epoch=0 if progress is None else progress.get("epoch", 0),
+        sequence_length_alignment=64)
 
 
     if args.amp:
         optimizer.amp_init(device)

+ 2 - 0
PaddlePaddle/LanguageModeling/BERT/run_squad.py

@@ -186,9 +186,11 @@ def main(args):
 
 
     with paddle.static.program_guard(main_program, startup_program):
         bert_config = BertConfig.from_json_file(args.config_file)
+        bert_config.fuse_mha = args.fuse_mha
         if bert_config.vocab_size % 8 != 0:
             bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
 
 
+
         model = BertForQuestionAnswering(bert_config)
         criterion = CrossEntropyLossForSQuAD()
         logits = model(input_ids=input_ids, token_type_ids=segment_ids)

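Both `program.py` and `run_squad.py` pad `vocab_size` up to a multiple of 8 before building the model (a Tensor Core friendly shape). The padding arithmetic in isolation:

```python
def pad_vocab_size(vocab_size, multiple=8):
    # Round vocab_size up to the next multiple of 8, exactly as the
    # "vocab_size += 8 - (vocab_size % 8)" lines in the diff do.
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

print(pad_vocab_size(30522))  # 30528
```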
+ 6 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining.sh

@@ -54,6 +54,7 @@ BERT_CONFIG=${28:-"None"}
 enable_benchmark=${29:-"false"}
 benchmark_steps=${30:-"10"}
 benchmark_warmup_steps=${31:-"10"}
+fuse_mha=${32:-"true"}
 
 
 # Calculate the total number of shards.
 readonly num_blocks=$((num_shards_per_worker * $(( num_workers > 0 ? num_workers : 1 )) * num_nodes * num_gpus))
@@ -130,8 +131,12 @@ if [ "$BERT_CONFIG" != "None" ] ; then
 fi

 PREC=""
+FUSE_MHA=""
 if [ "$precision" = "amp" ] ; then
    PREC="--amp --use-dynamic-loss-scaling --scale-loss=1048576"
+   if [ "$fuse_mha" = "true" ] ; then
+      FUSE_MHA="--fuse-mha"
+   fi
 elif [ "$precision" = "fp32" ] ; then
    PREC=""
 elif [ "$precision" = "tf32" ] ; then
@@ -197,6 +202,7 @@ CMD+=" --log-freq=1"
 CMD+=" --optimizer=Lamb"
 CMD+=" --phase1"
 CMD+=" $PREC"
+CMD+=" $FUSE_MHA"
 CMD+=" $ACCUMULATE_GRADIENTS"
 CMD+=" $INIT_CHECKPOINT"
 CMD+=" $BENCH"

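The flag plumbing added to `scripts/run_pretraining.sh` can be exercised on its own: the 32nd positional argument defaults to `"true"`, and `--fuse-mha` is only emitted for AMP runs. A sketch with the inputs hard-coded for illustration:

```shell
# Values hard-coded here for illustration; in the real script,
# precision is a positional argument and fuse_mha=${32:-"true"}.
precision="amp"
fuse_mha="true"

FUSE_MHA=""
if [ "$precision" = "amp" ] ; then
   if [ "$fuse_mha" = "true" ] ; then
      FUSE_MHA="--fuse-mha"
   fi
fi
# Appended to the launch command as CMD+=" $FUSE_MHA"; empty for
# fp32/tf32 runs, so the flag never reaches non-AMP training.
echo "FUSE_MHA=$FUSE_MHA"
```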
+ 1 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining_p1.sh

@@ -31,6 +31,7 @@ python3 -m paddle.distributed.launch \
 --amp \
 --use-dynamic-loss-scaling \
 --optimizer=Lamb \
+--fuse-mha \
 --phase1 \
 --scale-loss=1048576 \
 --learning-rate=6e-3 \

+ 1 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining_p2.sh

@@ -32,6 +32,7 @@ python3 -m paddle.distributed.launch \
 --amp \
 --use-dynamic-loss-scaling \
 --optimizer=Lamb \
+--fuse-mha \
 --phase2 \
 --scale-loss=1048576 \
 --learning-rate=4e-3 \

+ 6 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_squad.sh

@@ -31,6 +31,7 @@ max_steps=${14:-"-1"}
 enable_benchmark=${15:-"false"}
 benchmark_steps=${16:-"100"}
 benchmark_warmup_steps=${17:-"100"}
+fuse_mha=${18:-"true"}
 
 
 
 
 echo "out dir is $OUT_DIR"
@@ -41,9 +42,13 @@ if [ ! -d "$OUT_DIR" ]; then
 fi

 amp=""
+FUSE_MHA=""
 if [ "$precision" = "amp" ] ; then
   echo "amp activated!"
   amp=" --amp --use-dynamic-loss-scaling --scale-loss=128.0"
+  if [ "$fuse_mha" = "true" ] ; then
+    FUSE_MHA="--fuse-mha"
+  fi
 fi

 CONFIG=""
@@ -119,6 +124,7 @@ CMD+=" --max-steps=$max_steps "
 CMD+=" --optimizer=AdamW "
 CMD+=" --log-freq=100 "
 CMD+=" $amp "
+CMD+=" $FUSE_MHA "
 CMD+=" $BENCH "
 CMD+=" $BENCH "
 CMD+=" --report-file $OUT_DIR/dllogger_${num_gpus}_${precision}.json "
 CMD+=" --report-file $OUT_DIR/dllogger_${num_gpus}_${precision}.json "
 
 

+ 196 - 101
PaddlePaddle/LanguageModeling/BERT/utils/config.py

@@ -18,6 +18,7 @@ import argparse
 import distutils.util
 import logging
 import dllogger
+import paddle
 from utils.task import Task
 from utils.save_load import _PDOPT_SUFFIX, _PDPARAMS_SUFFIX, _PROGRESS_SUFFIX
 
@@ -27,7 +28,7 @@ _DEFAULT_BERT_CONFIG = {
     'bert-large-uncased': './bert_configs/bert-large-uncased.json',
     'bert-large-cased': './bert_configs/bert-large-cased.json',
     'bert-base-uncased': './bert_configs/bert-base-uncased.json',
-    'bert-base-cased': './bert_configs/bert-base-cased.json'
+    'bert-base-cased': './bert_configs/bert-base-cased.json',
 }
 
 
@@ -41,28 +42,34 @@ def _get_full_path_of_ckpt(args):
         pdparams_path = path_with_prefix + _PDPARAMS_SUFFIX
         progress_path = path_with_prefix + _PROGRESS_SUFFIX
         found = False
-        if os.path.exists(pdopt_path) and os.path.exists(
-                pdparams_path) and os.path.exists(progress_path):
+        if (
+            os.path.exists(pdopt_path)
+            and os.path.exists(pdparams_path)
+            and os.path.exists(progress_path)
+        ):
             found = True
         return found, pdopt_path, pdparams_path, progress_path
 
     if not os.path.exists(args.from_checkpoint):
         logging.warning(
-            f"Start training from scratch since no checkpoint is found.")
+            f"Start training from scratch since no checkpoint is found."
+        )
         args.from_checkpoint = None
         args.last_step_of_checkpoint = 0
         return
 
-    target_from_checkpoint = os.path.join(args.from_checkpoint,
-                                          args.model_prefix)
+    target_from_checkpoint = os.path.join(
+        args.from_checkpoint, args.model_prefix
+    )
     if args.last_step_of_checkpoint is None:
         args.last_step_of_checkpoint = 0
     elif args.last_step_of_checkpoint == _AUTO_LAST_EPOCH:
         folders = os.listdir(args.from_checkpoint)
         args.last_step_of_checkpoint = 0
         for folder in folders:
-            tmp_ckpt_path = os.path.join(args.from_checkpoint, folder,
-                                         args.model_prefix)
+            tmp_ckpt_path = os.path.join(
+                args.from_checkpoint, folder, args.model_prefix
+            )
 
             try:
                 folder = int(folder)
@@ -72,23 +79,32 @@ def _get_full_path_of_ckpt(args):
                 )
                 continue
 
-            if folder > args.last_step_of_checkpoint and \
-                _check_file_exist(tmp_ckpt_path)[0]:
+            if (
+                folder > args.last_step_of_checkpoint
+                and _check_file_exist(tmp_ckpt_path)[0]
+            ):
                 args.last_step_of_checkpoint = folder
-        step_with_prefix = os.path.join(str(args.last_step_of_checkpoint), args.model_prefix) \
-                            if args.last_step_of_checkpoint > 0 else args.model_prefix
-        target_from_checkpoint = os.path.join(args.from_checkpoint,
-                                              step_with_prefix)
+        step_with_prefix = (
+            os.path.join(str(args.last_step_of_checkpoint), args.model_prefix)
+            if args.last_step_of_checkpoint > 0
+            else args.model_prefix
+        )
+        target_from_checkpoint = os.path.join(
+            args.from_checkpoint, step_with_prefix
+        )
     else:
         try:
             args.last_step_of_checkpoint = int(args.last_step_of_checkpoint)
         except ValueError:
-            raise ValueError(f"The value of --last-step-of-checkpoint should be None, {_AUTO_LAST_EPOCH}"  \
-                            f" or integer >= 0, but receive {args.last_step_of_checkpoint}")
+            raise ValueError(
+                f"The value of --last-step-of-checkpoint should be None, {_AUTO_LAST_EPOCH}"
+                f" or integer >= 0, but receive {args.last_step_of_checkpoint}"
+            )
 
     args.from_checkpoint = target_from_checkpoint
     found, pdopt_path, pdparams_path, progress_path = _check_file_exist(
-        args.from_checkpoint)
+        args.from_checkpoint
+    )
     if not found:
         args.from_checkpoint = None
         args.last_step_of_checkpoint = 0
@@ -98,19 +114,28 @@ def _get_full_path_of_ckpt(args):
 
 
 def _get_full_path_of_pretrained_params(args, task=Task.pretrain):
-    if args.from_pretrained_params is None and args.from_phase1_final_params is None:
+    if (
+        args.from_pretrained_params is None
+        and args.from_phase1_final_params is None
+    ):
         args.last_step_of_checkpoint = 0
         return
-    if task == Task.pretrain and args.from_phase1_final_params is not None and args.last_step_of_checkpoint == 0:
+    if (
+        task == Task.pretrain
+        and args.from_phase1_final_params is not None
+        and args.last_step_of_checkpoint == 0
+    ):
         args.from_pretrained_params = args.from_phase1_final_params
 
-    args.from_pretrained_params = os.path.join(args.from_pretrained_params,
-                                               args.model_prefix)
+    args.from_pretrained_params = os.path.join(
+        args.from_pretrained_params, args.model_prefix
+    )
     pdparams_path = args.from_pretrained_params + _PDPARAMS_SUFFIX
     if not os.path.exists(pdparams_path):
         args.from_pretrained_params = None
         logging.warning(
-            f"Cannot find {pdparams_path}, disable --from-pretrained-params.")
+            f"Cannot find {pdparams_path}, disable --from-pretrained-params."
+        )
     args.last_step_of_checkpoint = 0
 
 
@@ -121,20 +146,28 @@ def print_args(args):
 
 
 def check_and_process_args(args, task=Task.pretrain):
     if task == Task.pretrain:
-        assert not (args.from_checkpoint is not None and \
-            args.from_pretrained_params is not None), \
-           "--from-pretrained-params and --from-checkpoint should " \
-           "not be set simultaneously."
-        assert not (args.phase1 and args.phase2), \
-            "--phase1 and --phase2 should not be set simultaneously in bert pretraining."
+        assert not (
+            args.from_checkpoint is not None
+            and args.from_pretrained_params is not None
+        ), (
+            "--from-pretrained-params and --from-checkpoint should "
+            "not be set simultaneously."
+        )
+        assert not (
+            args.phase1 and args.phase2
+        ), "--phase1 and --phase2 should not be set simultaneously in bert pretraining."
         if args.from_phase1_final_params is not None:
-            assert args.phase2, "--from-phase1-final-params should only be used in phase2"
+            assert (
+                args.phase2
+            ), "--from-phase1-final-params should only be used in phase2"
 
         # SQuAD finetuning does not support suspend-resume yet.(TODO)
         _get_full_path_of_ckpt(args)
 
     if args.bert_model == 'custom':
-        assert args.config_file is not None, "--config-file must be specified if --bert-model=custom"
+        assert (
+            args.config_file is not None
+        ), "--config-file must be specified if --bert-model=custom"
     elif args.config_file is None:
         args.config_file = _DEFAULT_BERT_CONFIG[args.bert_model]
         logging.info(
@@ -144,7 +177,19 @@ def check_and_process_args(args, task=Task.pretrain):
         _get_full_path_of_pretrained_params(args, task)
 
     assert os.path.isfile(
-        args.config_file), f"Cannot find config file in {args.config_file}"
+        args.config_file
+    ), f"Cannot find config file in {args.config_file}"
+
+    # cudnn mha fusion is only supported after v8.9.1 on Ampere and Hopper GPU
+    device_capability = paddle.device.cuda.get_device_capability()
+    cudnn_mha_supported = paddle.get_cudnn_version() >= 8901 and (
+        device_capability == (8, 0) or device_capability == (9, 0)
+    )
+    if (not cudnn_mha_supported or args.amp is False) and args.fuse_mha is True:
+        logging.info(
+            f"cudnn mha fusion is not supported, fall back to unfused mha"
+        )
+        args.fuse_mha = False
 
 
 def add_global_args(parser, task=Task.pretrain):
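The new gate in `check_and_process_args` can be factored into a pure predicate for illustration (a sketch with hypothetical helper names; the actual code queries `paddle.get_cudnn_version()` and `paddle.device.cuda.get_device_capability()` at runtime):

```python
def cudnn_mha_supported(cudnn_version, capability):
    # cuDNN encodes version 8.9.1 as the integer 8901; fused MHA also
    # requires an Ampere (compute capability 8.0) or Hopper (9.0) GPU.
    return cudnn_version >= 8901 and capability in ((8, 0), (9, 0))

def resolve_fuse_mha(fuse_mha, amp, cudnn_version, capability):
    # --fuse-mha silently falls back to the unfused path unless AMP is
    # enabled and the cuDNN/GPU combination qualifies.
    if fuse_mha and (not amp or not cudnn_mha_supported(cudnn_version, capability)):
        return False
    return fuse_mha
```

The key consequence for users: passing `--fuse-mha` without `--amp`, or on a pre-Ampere GPU or cuDNN older than 8.9.1, does not error out; it logs a message and disables the fusion.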
@@ -155,145 +200,165 @@ def add_global_args(parser, task=Task.pretrain):
             type=str,
             default=None,
             required=True,
-            help='The input data directory. Should be specified by users and contain .hdf5 files for the task.'
+            help='The input data directory. Should be specified by users and contain .hdf5 files for the task.',
         )
-        group.add_argument('--num-workers', default=4, type=int)
+        group.add_argument('--num-workers', default=1, type=int)
     if task == Task.squad:
         group.add_argument(
             '--train-file',
             type=str,
             default=None,
-            help='SQuAD json for training. E.g., train-v1.1.json')
+            help='SQuAD json for training. E.g., train-v1.1.json',
+        )
         group.add_argument(
             '--predict-file',
             type=str,
             default=None,
-            help='SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json'
+            help='SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json',
         )
         group.add_argument(
             "--eval-script",
             help="Script to evaluate squad predictions",
             default="evaluate.py",
-            type=str)
+            type=str,
+        )
         group.add_argument(
             '--epochs',
             type=int,
             default=3,
-            help='The number of epochs for training.')
+            help='The number of epochs for training.',
+        )
 
     group.add_argument(
         '--vocab-file',
         type=str,
         default=None,
         required=True,
-        help="Vocabulary mapping/file BERT was pretrainined on")
+        help="Vocabulary mapping/file BERT was pretrainined on",
+    )
     group.add_argument(
         '--output-dir',
         type=str,
         default=None,
         required=True,
-        help='The output directory where the model checkpoints will be written. Should be specified by users.'
+        help='The output directory where the model checkpoints will be written. Should be specified by users.',
     )
     group.add_argument(
         '--bert-model',
         type=str,
         default='bert-large-uncased',
-        choices=('bert-base-uncased', 'bert-base-cased', 'bert-large-uncased',
-                 'bert-large-cased', 'custom'),
+        choices=(
+            'bert-base-uncased',
+            'bert-base-cased',
+            'bert-large-uncased',
+            'bert-large-cased',
+            'custom',
+        ),
         help='Specifies the type of BERT model to use. If it is set as custom, '
-        'the path to the config file must be given by specifying --config-file')
+        'the path to the config file must be given by specifying --config-file',
+    )
     group.add_argument(
         '--config-file',
         type=str,
         default=None,
-        help='The BERT model config. If set to None, `<--bert-model>.json` in folder `bert_configs` will be used.'
+        help='The BERT model config. If set to None, `<--bert-model>.json` in folder `bert_configs` will be used.',
     )
     group.add_argument(
         '--max-steps',
         type=int,
         default=None,
         required=True if task == Task.pretrain else False,
-        help='Total number of training steps to perform.')
+        help='Total number of training steps to perform.',
+    )
     group.add_argument(
-        '--log-freq', type=int, default=10, help='Frequency of logging loss.')
+        '--log-freq', type=int, default=10, help='Frequency of logging loss.'
+    )
     group.add_argument(
         '--num-steps-per-checkpoint',
         type=int,
         default=100,
-        help='Number of update steps until a model checkpoint is saved to disk.'
+        help='Number of update steps until a model checkpoint is saved to disk.',
     )
     # Init model
     group.add_argument(
         '--from-pretrained-params',
         type=str,
         default=None,
-        help='Path to pretrained parameters. If set to None, no pretrained params will be used.'
+        help='Path to pretrained parameters. If set to None, no pretrained params will be used.',
     )
     group.add_argument(
         '--from-checkpoint',
         type=str,
         default=None,
-        help='A checkpoint path to resume training. If set to None, no checkpoint will be used. ' \
-             'If not None, --from-pretrained-params will be ignored.')
+        help='A checkpoint path to resume training. If set to None, no checkpoint will be used. '
+        'If not None, --from-pretrained-params will be ignored.',
+    )
     group.add_argument(
         '--last-step-of-checkpoint',
         type=str,
         default=None,
-        help='The step id of the checkpoint given by --from-checkpoint. ' \
-             'It should be None, auto, or integer > 0. If it is set as ' \
-             'None, then training will start from the 1-th epoch. If it is set as ' \
-             'auto, then it will search largest integer-convertable folder ' \
-             ' --from-checkpoint, which contains required checkpoint. '
+        help='The step id of the checkpoint given by --from-checkpoint. '
+        'It should be None, auto, or integer > 0. If it is set as '
+        'None, then training will start from the 1-th epoch. If it is set as '
+        'auto, then it will search largest integer-convertable folder '
+        ' --from-checkpoint, which contains required checkpoint. ',
     )
     if task == Task.pretrain:
         group.add_argument(
             '--from-phase1-final-params',
             type=str,
             default=None,
-            help='Path to final checkpoint of phase1, which will be used to ' \
-   'initialize the parameter in the first step of phase2, and ' \
-                 'ignored in the rest steps of phase2.'
+            help='Path to final checkpoint of phase1, which will be used to '
+            'initialize the parameter in the first step of phase2, and '
+            'ignored in the rest steps of phase2.',
         )
         group.add_argument(
             '--steps-this-run',
             type=int,
             default=None,
-            help='If provided, only run this many steps before exiting.' \
+            help='If provided, only run this many steps before exiting.',
         )
     group.add_argument(
-        '--seed', type=int, default=42, help="random seed for initialization")
+        '--seed', type=int, default=42, help="random seed for initialization"
+    )
     group.add_argument(
         '--report-file',
         type=str,
         default='./report.json',
-        help='A file in which to store JSON experiment report.')
+        help='A file in which to store JSON experiment report.',
+    )
     group.add_argument(
         '--model-prefix',
         type=str,
         default='bert_paddle',
-        help='The prefix name of model files to save/load.')
+        help='The prefix name of model files to save/load.',
+    )
     group.add_argument(
         '--show-config',
         type=distutils.util.strtobool,
         default=True,
-        help='To show arguments.')
+        help='To show arguments.',
+    )
     group.add_argument(
         '--enable-cpu-affinity',
         type=distutils.util.strtobool,
         default=True,
-        help='To enable in-built GPU-CPU affinity.')
+        help='To enable in-built GPU-CPU affinity.',
+    )
     group.add_argument(
-        '--benchmark', action='store_true', help='To enable benchmark mode.')
+        '--benchmark', action='store_true', help='To enable benchmark mode.'
+    )
     group.add_argument(
         '--benchmark-steps',
         type=int,
         default=20,
-        help='Steps for a benchmark run, only applied when --benchmark is set.')
+        help='Steps for a benchmark run, only applied when --benchmark is set.',
+    )
     group.add_argument(
         '--benchmark-warmup-steps',
         type=int,
         default=20,
-        help='Warmup steps for a benchmark run, only applied when --benchmark is set.'
+        help='Warmup steps for a benchmark run, only applied when --benchmark is set.',
     )
     return parser
 
@@ -305,145 +370,166 @@ def add_training_args(parser, task=Task.pretrain):
         default='Lamb',
         default='Lamb',
         metavar="OPTIMIZER",
         metavar="OPTIMIZER",
         choices=('Lamb', 'AdamW'),
         choices=('Lamb', 'AdamW'),
-        help='The name of optimizer. It should be one of {Lamb, AdamW}.')
+        help='The name of optimizer. It should be one of {Lamb, AdamW}.',
+    )
     group.add_argument(
     group.add_argument(
         '--gradient-merge-steps',
         '--gradient-merge-steps',
         type=int,
         type=int,
         default=1,
         default=1,
-        help="Number of update steps to accumualte before performing a backward/update pass."
+        help="Number of update steps to accumualte before performing a backward/update pass.",
     )
     )
     group.add_argument(
     group.add_argument(
         '--learning-rate',
         '--learning-rate',
         type=float,
         type=float,
         default=1e-4,
         default=1e-4,
-        help='The initial learning rate.')
+        help='The initial learning rate.',
+    )
     group.add_argument(
     group.add_argument(
         '--warmup-start-lr',
         '--warmup-start-lr',
         type=float,
         type=float,
         default=0.0,
         default=0.0,
-        help='The initial learning rate for warmup.')
+        help='The initial learning rate for warmup.',
+    )
     group.add_argument(
     group.add_argument(
         '--warmup-proportion',
         '--warmup-proportion',
         type=float,
         type=float,
         default=0.01,
         default=0.01,
         help='Proportion of training to perform linear learning rate warmup for. '
         help='Proportion of training to perform linear learning rate warmup for. '
-        'For example, 0.1 = 10%% of training.')
+        'For example, 0.1 = 10%% of training.',
+    )
     group.add_argument(
     group.add_argument(
         '--beta1',
         '--beta1',
         type=float,
         type=float,
         default=0.9,
         default=0.9,
-        help='The exponential decay rate for the 1st moment estimates.')
+        help='The exponential decay rate for the 1st moment estimates.',
+    )
     group.add_argument(
     group.add_argument(
         '--beta2',
         '--beta2',
         type=float,
         type=float,
         default=0.999,
         default=0.999,
-        help='The exponential decay rate for the 2st moment estimates.')
+        help='The exponential decay rate for the 2st moment estimates.',
+    )
     group.add_argument(
     group.add_argument(
         '--epsilon',
         '--epsilon',
         type=float,
         type=float,
         default=1e-6,
         default=1e-6,
-        help='A small float value for numerical stability.')
+        help='A small float value for numerical stability.',
+    )
     group.add_argument(
     group.add_argument(
         '--weight-decay',
         '--weight-decay',
         type=float,
         type=float,
         default=0.01,
         default=0.01,
-        help='The weight decay coefficient.')
+        help='The weight decay coefficient.',
+    )
     group.add_argument(
     group.add_argument(
         '--max-seq-length',
         '--max-seq-length',
         default=512,
         default=512,
         type=int,
         type=int,
         help='The maximum total input sequence length after WordPiece tokenization. \n'
         help='The maximum total input sequence length after WordPiece tokenization. \n'
         'Sequences longer than this will be truncated, and sequences shorter \n'
         'Sequences longer than this will be truncated, and sequences shorter \n'
-        'than this will be padded.')
+        'than this will be padded.',
+    )
     if task == Task.pretrain:
     if task == Task.pretrain:
         group.add_argument(
         group.add_argument(
             '--batch-size',
             '--batch-size',
             type=int,
             type=int,
             default=32,
             default=32,
-            help='The batch size for training')
+            help='The batch size for training',
+        )
         group.add_argument(
         group.add_argument(
             '--phase1',
             '--phase1',
             action='store_true',
             action='store_true',
-            help='The phase of BERT pretraining. It should not be set ' \
-                'with --phase2 at the same time.'
+            help='The phase of BERT pretraining. It should not be set '
+            'with --phase2 at the same time.',
         )
         )
         group.add_argument(
         group.add_argument(
             '--phase2',
             '--phase2',
             action='store_true',
             action='store_true',
-            help='The phase of BERT pretraining. It should not be set ' \
-                'with --phase1 at the same time.'
+            help='The phase of BERT pretraining. It should not be set '
+            'with --phase1 at the same time.',
         )
         )
         group.add_argument(
         group.add_argument(
             '--max-predictions-per-seq',
             '--max-predictions-per-seq',
             default=80,
             default=80,
             type=int,
             type=int,
-            help='The maximum total of masked tokens in the input sequence')
+            help='The maximum total of masked tokens in the input sequence',
+        )
 
 
     if task == Task.squad:
     if task == Task.squad:
         group.add_argument(
         group.add_argument(
-            "--do-train", action='store_true', help="Whether to run training.")
+            "--do-train", action='store_true', help="Whether to run training."
+        )
         group.add_argument(
         group.add_argument(
             "--do-predict",
             "--do-predict",
             action='store_true',
             action='store_true',
-            help="Whether to run eval on the dev set.")
+            help="Whether to run eval on the dev set.",
+        )
         group.add_argument(
         group.add_argument(
             "--do-eval",
             "--do-eval",
             action='store_true',
             action='store_true',
-            help="Whether to use evaluate accuracy of predictions")
+            help="Whether to use evaluate accuracy of predictions",
+        )
         group.add_argument(
         group.add_argument(
             "--train-batch-size",
             "--train-batch-size",
             default=32,
             default=32,
             type=int,
             type=int,
-            help="Total batch size for training.")
+            help="Total batch size for training.",
+        )
         group.add_argument(
         group.add_argument(
             "--predict-batch-size",
             "--predict-batch-size",
             default=8,
             default=8,
             type=int,
             type=int,
-            help="Total batch size for predictions.")
+            help="Total batch size for predictions.",
+        )
         group.add_argument(
         group.add_argument(
             "--verbose-logging",
             "--verbose-logging",
             action='store_true',
             action='store_true',
             help="If true, all of the warnings related to data processing will be printed. "
             help="If true, all of the warnings related to data processing will be printed. "
-            "A number of warnings are expected for a normal SQuAD evaluation.")
+            "A number of warnings are expected for a normal SQuAD evaluation.",
+        )
         group.add_argument(
         group.add_argument(
             "--doc-stride",
             "--doc-stride",
             default=128,
             default=128,
             type=int,
             type=int,
             help="When splitting up a long document into chunks, how much stride to take "
             help="When splitting up a long document into chunks, how much stride to take "
-            "between chunks.")
+            "between chunks.",
+        )
         group.add_argument(
         group.add_argument(
             "--max-query-length",
             "--max-query-length",
             default=64,
             default=64,
             type=int,
             type=int,
             help="The maximum number of tokens for the question. Questions longer than this "
             help="The maximum number of tokens for the question. Questions longer than this "
-            "will be truncated to this length.")
+            "will be truncated to this length.",
+        )
         group.add_argument(
         group.add_argument(
             "--n-best-size",
             "--n-best-size",
             default=20,
             default=20,
             type=int,
             type=int,
             help="The total number of n-best predictions to generate in the nbest_predictions.json "
             help="The total number of n-best predictions to generate in the nbest_predictions.json "
-            "output file.")
+            "output file.",
+        )
         group.add_argument(
         group.add_argument(
             "--max-answer-length",
             "--max-answer-length",
             default=30,
             default=30,
             type=int,
             type=int,
             help="The maximum length of an answer that can be generated. This is needed because the start "
             help="The maximum length of an answer that can be generated. This is needed because the start "
-            "and end predictions are not conditioned on one another.")
+            "and end predictions are not conditioned on one another.",
+        )
         group.add_argument(
         group.add_argument(
             "--do-lower-case",
             "--do-lower-case",
             action='store_true',
             action='store_true',
-            help="Whether to lower case the input text. True for uncased models, False for cased models."
+            help="Whether to lower case the input text. True for uncased models, False for cased models.",
         )
         group.add_argument(
             '--version-2-with-negative',
             action='store_true',
-            help='If true, the SQuAD examples contain some that do not have an answer.'
+            help='If true, the SQuAD examples contain some that do not have an answer.',
         )
         group.add_argument(
             '--null-score-diff-threshold',
             type=float,
             default=0.0,
-            help="If null_score - best_non_null is greater than the threshold predict null."
+            help="If null_score - best_non_null is greater than the threshold predict null.",
         )
     return parser

@@ -453,22 +539,29 @@ def add_advance_args(parser):
     group.add_argument(
         '--amp',
         action='store_true',
-        help='Enable automatic mixed precision training (AMP).')
+        help='Enable automatic mixed precision training (AMP).',
+    )
     group.add_argument(
         '--scale-loss',
         type=float,
         default=1.0,
-        help='The loss scalar for AMP training, only applied when --amp is set.'
+        help='The loss scalar for AMP training, only applied when --amp is set.',
     )
     group.add_argument(
         '--use-dynamic-loss-scaling',
         action='store_true',
-        help='Enable dynamic loss scaling in AMP training, only applied when --amp is set.'
+        help='Enable dynamic loss scaling in AMP training, only applied when --amp is set.',
     )
     group.add_argument(
         '--use-pure-fp16',
         action='store_true',
-        help='Enable pure FP16 training, only applied when --amp is set.')
+        help='Enable pure FP16 training, only applied when --amp is set.',
+    )
+    group.add_argument(
+        '--fuse-mha',
+        action='store_true',
+        help='Enable multihead attention fusion. Require cudnn version >= 8.9.1',
+    )

     return parser

@@ -476,8 +569,10 @@ def add_advance_args(parser):
 def parse_args(task=Task.pretrain):
     parser = argparse.ArgumentParser(
         description="PaddlePaddle BERT pretraining script"
-        if task == Task.pretrain else "PaddlePaddle SQuAD finetuning script",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        if task == Task.pretrain
+        else "PaddlePaddle SQuAD finetuning script",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )

     parser = add_global_args(parser, task)
     parser = add_training_args(parser, task)
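The second hunk above adds a `--fuse-mha` store-true flag to the advanced-training option group. A minimal standalone sketch of that group's behavior after the change (a hypothetical reconstruction for illustration; the repository's real `parse_args` assembles several more groups):

```python
import argparse

def build_parser():
    # Mirrors the advanced-training options shown in the diff; names and
    # defaults are taken from the diff, everything else is simplified.
    parser = argparse.ArgumentParser(
        description="PaddlePaddle BERT pretraining script",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    group = parser.add_argument_group("Advanced Training")
    group.add_argument(
        '--amp',
        action='store_true',
        help='Enable automatic mixed precision training (AMP).',
    )
    group.add_argument(
        '--scale-loss',
        type=float,
        default=1.0,
        help='The loss scalar for AMP training, only applied when --amp is set.',
    )
    group.add_argument(
        '--fuse-mha',
        action='store_true',
        help='Enable multihead attention fusion. Require cudnn version >= 8.9.1',
    )
    return parser

# argparse maps '--fuse-mha' to the attribute 'fuse_mha'.
args = build_parser().parse_args(['--amp', '--fuse-mha'])
```

Because the new option is `store_true`, fused MHA stays off unless the flag is passed explicitly, matching the `run_squad.sh` default of `true` being set by the script rather than by the parser.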