
[BERT/Paddle] Update base image and integrate cuDNN fused MHA

Shijie Wang, 2 years ago
commit a5388a45f7

+ 1 - 1
PaddlePaddle/LanguageModeling/BERT/Dockerfile

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:22.12-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.06-py3
 FROM ${FROM_IMAGE_NAME}
 RUN apt-get update && apt-get install -y pbzip2 pv bzip2 cabextract
 

+ 39 - 25
PaddlePaddle/LanguageModeling/BERT/README.md

@@ -437,6 +437,7 @@ Advanced Training:
   --use-dynamic-loss-scaling
                         Enable dynamic loss scaling in AMP training, only applied when --amp is set. (default: False)
   --use-pure-fp16       Enable pure FP16 training, only applied when --amp is set. (default: False)
+  --fuse-mha            Enable multi-head attention fusion. Requires cuDNN version >= 8.9.1. (default: False)
 ```
 
 
@@ -463,6 +464,7 @@ Default arguments are listed below in the order `scripts/run_squad.sh` expects:
 -   Enable benchmark - The default is `false`.
 -   Benchmark steps - The default is `100`.
 -   Benchmark warmup steps - The default is `100`.
+-   Enable MHA fusion - The default is `true`.
 
 The script saves the final checkpoint to the `/results/bert-large-uncased/squad` folder.
 
@@ -593,7 +595,8 @@ bash run_pretraining.sh \
     <bert_config_file> \
     <enable_benchmark> \
     <benchmark_steps> \
-    <benchmark_warmup_steps>
+    <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
 
 Where:
@@ -627,6 +630,7 @@ Where:
 -   `masking` LDDL supports both static and dynamic masking. Refer to [LDDL's README](https://github.com/NVIDIA/LDDL/blob/main/README.md) for more information.
 -   `<bert_config_file>` is the path to the bert config file.
 -   `<enable_benchmark>` a flag to enable benchmarking. The training process will warm up for `<benchmark_warmup_steps>` and then measure the throughput of the following `<benchmark_steps>`.
+-   `<fuse_mha>` a flag to enable cuDNN MHA fusion.
 
 Note that: 
 - If users follow [Quick Start Guide](#quick-start-guide) to set up container and dataset, there is no need to set any parameters. For example:
@@ -670,6 +674,7 @@ python3 -m paddle.distributed.launch \
     --max-predictions-per-seq=20 \
     --gradient-merge-steps=32 \
     --amp \
+    --fuse-mha \
     --use-dynamic-loss-scaling \
     --optimizer=Lamb \
     --phase1 \
@@ -769,7 +774,8 @@ bash scripts/run_squad.sh \
     <max_steps> \
     <enable_benchmark> \
     <benchmark_steps> \
-    <benchmark_warmup_steps>
+    <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
  
 By default, the `mode` argument is set to `train eval`. Refer to the [Quick Start Guide](#quick-start-guide) for explanations of each positional argument.
@@ -812,7 +818,7 @@ bash scripts/run_pretraining.sh \
     None \
     /path/to/wikipedia/source \
     32 128 4 0.9 64 static \
-    None true 10 10
+    None true 10 10 true
 ```
 
 To benchmark the training performance on a specific batch size for SQuAD, refer to [Fine-tuning](#fine-tuning) and turn on the `<benchmark>` flags. An example call to run training for 200 steps (100 steps for warmup and 100 steps to measure), and generate throughput numbers:
@@ -825,7 +831,7 @@ bash scripts/run_squad.sh \
     results/checkpoints \
     train \
     bert_configs/bert-large-uncased.json \
-    -1 true 100 100
+    -1 true 100 100 true
 ```
  
 #### Inference performance benchmark
@@ -841,7 +847,8 @@ bash scripts/run_squad.sh \
     <results directory> \
     eval \
     <BERT config path> \
-    <max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps>
+    <max steps> <benchmark> <benchmark_steps> <benchmark_warmup_steps> \
+    <fuse_mha>
 ```
  
 An example call to run inference and generate throughput numbers:
@@ -854,7 +861,7 @@ bash scripts/run_squad.sh \
     results/checkpoints \
     eval \
     bert_configs/bert-large-uncased.json \
-    -1 true 100 100
+    -1 true 100 100 true
 ```
 
  
@@ -870,7 +877,7 @@ Our results were obtained by running the `scripts/run_squad.sh` and `scripts/run
 
 | DGX System         | GPUs / Node | Precision | Accumulated Batch size / GPU (Phase 1 and Phase 2) | Accumulation steps (Phase 1 and Phase 2) |     Final Loss    | Time to train(hours) | Time to train speedup (TF32 to mixed precision) |
 |--------------------|-------------|-----------|----------------------------------------------------|------------------------------------------|-------------------|----------------------|-------------------------------------------------|
-| 32 x DGX A100 80GB | 8           | AMP       | 256 and 128                                        | 1 and 4                                  |       1.409       |    ~ 1.2 hours       | 1.72                                            |
+| 32 x DGX A100 80GB | 8           | AMP       | 256 and 128                                        | 1 and 4                                  |       1.409       |    ~ 1.1 hours       | 2.27                                            |
 | 32 x DGX A100 80GB | 8           | TF32      | 128 and 16                                         | 2 and 8                                  |       1.421       |    ~ 2.5 hours       | 1                                               |
 
 
@@ -914,28 +921,28 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd
 
 | GPUs | Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|------------------------------------|-----------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1    | 8192 and 8192                    | 64 and 32                          | 128             |  307                             |   633                                       | 2.06                                        | 1.00                | 1.00                           |
-| 8    | 8192 and 8192                    | 64 and 32                          | 128             | 2428                             |  4990                                       | 2.06                                        | 7.91                | 7.88                           |
-| 1    | 4096 and 4096                    | 256 and 128                        | 512             |  107                             |   219                                       | 2.05                                        | 1.00                | 1.00                           |
-| 8    | 4096 and 4096                    | 256 and 128                        | 512             |  851                             |  1724                                       | 2.26                                        | 7.95                | 7.87                           |
+| 1    | 8192 and 8192                    | 64 and 32                          | 128             |  307                             |   694                                       | 2.26                                        | 1.00                | 1.00                           |
+| 8    | 8192 and 8192                    | 64 and 32                          | 128             | 2428                             |  5541                                       | 2.28                                        | 7.91                | 7.98                           |
+| 1    | 4096 and 4096                    | 256 and 128                        | 512             |  107                             |   264                                       | 2.47                                        | 1.00                | 1.00                           |
+| 8    | 4096 and 4096                    | 256 and 128                        | 512             |  851                             |  2109                                       | 2.48                                        | 7.95                | 7.99                           |
 
 
 ###### Pre-training NVIDIA DGX A100 (8x A100 80GB) Multi-node Scaling
 
 | Nodes | GPUs / node | Batch size / GPU (TF32 and FP16) | Accumulated Batch size / GPU (TF32 and FP16) | Accumulation steps (TF32 and FP16) | Sequence length | Mixed Precision Throughput | Mixed Precision Strong Scaling | TF32 Throughput | TF32 Strong Scaling | Speedup (Mixed Precision to TF32) |
 |-------|-------------|----------------------------------|------------------------------------|-----------------|----------------------------|--------------------------------|-----------------|---------------------|-----------------------------------|-----|
-| 1     | 8           | 126 and 256 | 8192 and 8192                    | 64 and 32             | 128             |   4990               | 1                              |   2428          |  1                  |  2.06               |
-| 2     | 8           | 126 and 256 | 4096 and 4096                    | 32 and 16             | 128             |   9581               | 1.92                           |   4638          |  1.91               |  2.07               |
-| 4     | 8           | 126 and 256 | 2048 and 2048                    | 16 and 8              | 128             |   19262              | 3.86                           |   9445          |  3.89               |  2.04               |
-| 8     | 8           | 126 and 256 | 1024 and 1024                    | 8 and 4               | 128             |   37526              | 7.52                           |   18335         |  7.55               |  2.05               |
-| 16    | 8           | 126 and 256 | 512 and 512                      | 4 and 2               | 128             |   71156              | 14.26                          |   35526         |  14.63              |  2.00               |
-| 32    | 8           | 126 and 256 | 256 and 256                      | 2 and 1               | 128             |   142087             | 28.47                          |   69701         |  28.71              |  2.04               |
-| 1     | 8           | 16  and 32  | 4096 and 4096                    | 256 and 128           | 512             |   1724               | 1                              |   851           |  1                  |  2.03               |
-| 2     | 8           | 16  and 32  | 2048 and 2048                    | 128 and 64            | 512             |   3305               | 1.92                           |   1601          |  1.88               |  2.06               |
-| 4     | 8           | 16  and 32  | 1024 and 1024                    | 64 and 32             | 512             |   6492               | 3.77                           |   3240          |  3.81               |  2.00               |
-| 8     | 8           | 16  and 32  | 512 and 512                      | 32 and 16             | 512             |   12884              | 7.47                           |   6329          |  7.44               |  2.04               |
-| 16    | 8           | 16  and 32  | 256 and 256                      | 16 and 8              | 512             |   25493              | 14.79                          |   12273         |  14.42              |  2.08               |
-| 32    | 8           | 16  and 32  | 128 and 128                      | 8 and 4               | 512             |   49307              | 28.60                          |   24047         |  28.26              |  2.05               |
+| 1     | 8           | 128 and 256 | 8192 and 8192                    | 64 and 32             | 128             |   5541               | 1                              |   2428          |  1                  |  2.28               |
+| 2     | 8           | 128 and 256 | 4096 and 4096                    | 32 and 16             | 128             |   10646              | 1.92                           |   4638          |  1.91               |  2.29               |
+| 4     | 8           | 128 and 256 | 2048 and 2048                    | 16 and 8              | 128             |   21389              | 3.86                           |   9445          |  3.89               |  2.26               |
+| 8     | 8           | 128 and 256 | 1024 and 1024                    | 8 and 4               | 128             |   41681              | 7.52                           |   18335         |  7.55               |  2.27               |
+| 16    | 8           | 128 and 256 | 512 and 512                      | 4 and 2               | 128             |   79023              | 14.26                          |   35526         |  14.63              |  2.22               |
+| 32    | 8           | 128 and 256 | 256 and 256                      | 2 and 1               | 128             |   157952             | 28.51                          |   69701         |  28.71              |  2.27               |
+| 1     | 8           | 16  and 32  | 4096 and 4096                    | 256 and 128           | 512             |   2109               | 1                              |   851           |  1                  |  2.48               |
+| 2     | 8           | 16  and 32  | 2048 and 2048                    | 128 and 64            | 512             |   4051               | 1.92                           |   1601          |  1.88               |  2.53               |
+| 4     | 8           | 16  and 32  | 1024 and 1024                    | 64 and 32             | 512             |   7972               | 3.78                           |   3240          |  3.81               |  2.46               |
+| 8     | 8           | 16  and 32  | 512 and 512                      | 32 and 16             | 512             |   15760              | 7.47                           |   6329          |  7.44               |  2.49               |
+| 16    | 8           | 16  and 32  | 256 and 256                      | 16 and 8              | 512             |   31129              | 14.76                          |   12273         |  14.42              |  2.54               |
+| 32    | 8           | 16  and 32  | 128 and 128                      | 8 and 4               | 512             |   60206              | 28.55                          |   24047         |  28.26              |  2.50               |
 
 
 ###### Fine-tuning NVIDIA DGX A100 (8x A100 80GB)
@@ -944,8 +951,8 @@ Our results were obtained by running the script `run_pretraining.sh` in the Padd
   
 | GPUs | Batch size / GPU (TF32 and FP16) | Throughput - TF32(sequences/sec) | Throughput - mixed precision(sequences/sec) | Throughput speedup (TF32 - mixed precision) | Weak scaling - TF32 | Weak scaling - mixed precision |
 |------|----------------------------------|----------------------------------|---------------------------------------------|---------------------------------------------|---------------------|--------------------------------|
-| 1    | 32 and 32                        |          83                      |               120                           |               1.45                          | 1.00                | 1.00                           |
-| 8    | 32 and 32                        |         629                      |               876                           |               1.39                          | 7.59                | 7.30                           |
+| 1    | 32 and 32                        |          83                      |               123                           |               1.48                          | 1.00                | 1.00                           |
+| 8    | 32 and 32                        |         629                      |               929                           |               1.48                          | 7.59                | 7.55                           |
  
 #### Inference performance results
 
@@ -983,6 +990,13 @@ August 2022
 - SQuAD finetune support with AdamW optimizer.
 - Updated accuracy and performance tables tested on A100.
 - Initial release.
+
+March 2023
+- Pre-training using [Language Datasets and Data Loaders (LDDL)](https://github.com/NVIDIA/LDDL)
+- Binned pretraining for phase2 with LDDL using a bin size of 64
+
+July 2023
+- Optimized AMP training with the cuDNN fused dot-product attention kernel.
  
 ### Known issues
  

+ 2 - 1
PaddlePaddle/LanguageModeling/BERT/modeling.py

@@ -172,7 +172,8 @@ class BertModel(nn.Layer):
                 dropout=bert_config.hidden_dropout_prob,
                 activation=bert_config.hidden_act,
                 attn_dropout=bert_config.attention_probs_dropout_prob,
-                act_dropout=0)
+                act_dropout=0,
+                fuse_qkv=bert_config.fuse_mha)
             self.encoder = nn.TransformerEncoder(encoder_layer,
                                                  bert_config.num_hidden_layers)
 

+ 11 - 6
PaddlePaddle/LanguageModeling/BERT/program.py

@@ -44,12 +44,12 @@ def create_pretraining_data_holder():
     ]
 
 
-def create_strategy(use_amp, use_distributed_fused_lamb=False):
+def create_strategy(args, use_distributed_fused_lamb=False):
     """
     Create paddle.static.BuildStrategy and paddle.static.ExecutionStrategy with arguments.
 
     Args:
-        use_amp(bool): Whether to use amp.
+        args(Namespace): Arguments obtained from ArgumentParser.
         use_distributed_fused_lamb(bool, optional): Whether to use distributed fused lamb.
     Returns:
         build_strategy(paddle.static.BuildStrategy): An instance of BuildStrategy.
@@ -59,8 +59,9 @@ def create_strategy(use_amp, use_distributed_fused_lamb=False):
     exec_strategy = paddle.static.ExecutionStrategy()
 
     build_strategy.enable_addto = True
-    if use_amp:
+    if args.amp:
         build_strategy.fuse_gemm_epilogue = True
+        build_strategy.fuse_dot_product_attention = args.fuse_mha
 
     if use_distributed_fused_lamb:
         build_strategy.fuse_all_reduce_ops = False
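
The gating added to `create_strategy` enables GEMM-epilogue fusion whenever AMP is on, and dot-product-attention fusion only when `--fuse-mha` is also set. As a pure-Python sketch (the `select_fusion_flags` helper and its dict of flags are illustrative stand-ins for `BuildStrategy`, not part of the repo):

```python
from argparse import Namespace

def select_fusion_flags(args):
    """Mirror the create_strategy gating: both fusions are off outside
    AMP; under AMP, GEMM-epilogue fusion is always on and MHA fusion
    additionally requires --fuse-mha."""
    flags = {"fuse_gemm_epilogue": False, "fuse_dot_product_attention": False}
    if args.amp:
        flags["fuse_gemm_epilogue"] = True
        flags["fuse_dot_product_attention"] = args.fuse_mha
    return flags

# AMP + --fuse-mha: both fusions enabled
print(select_fusion_flags(Namespace(amp=True, fuse_mha=True)))
# FP32/TF32 run: --fuse-mha is ignored, nothing fused
print(select_fusion_flags(Namespace(amp=False, fuse_mha=True)))
```

This is why the README notes that `--fuse-mha` is only applied together with `--amp`.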
@@ -86,7 +87,7 @@ def dist_optimizer(args, optimizer):
         optimizer(fleet.distributed_optimizer): A distributed optimizer.
     """
     use_distributed_fused_lamb = True if args.optimizer == 'DistributedFusedLamb' else False
-    build_strategy, exec_strategy = create_strategy(args.amp,
+    build_strategy, exec_strategy = create_strategy(args,
                                                     use_distributed_fused_lamb)
     dist_strategy = fleet.DistributedStrategy()
 
@@ -160,6 +161,7 @@ def build(args, main_prog, startup_prog, is_train=True):
             bert_config = BertConfig.from_json_file(args.config_file)
             if bert_config.vocab_size % 8 != 0:
                 bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
+            bert_config.fuse_mha = args.fuse_mha
             model = BertForPretraining(bert_config)
             criterion = BertPretrainingCriterion(bert_config.vocab_size)
             prediction_scores, seq_relationship_score = model(
@@ -224,6 +226,7 @@ def run(exe,
     logging.info(f"Training will start at the {last_step+1}th step")
 
     max_steps = args.max_steps
+    steps_this_run = max_steps
     if args.steps_this_run is not None:
         if args.steps_this_run + last_step > max_steps:
             logging.info(
@@ -231,12 +234,14 @@ def run(exe,
             )
         else:
             steps_this_run = args.steps_this_run
-            if args.benchmark:
-                steps_this_run = min(steps_this_run, args.benchmark_warmup_steps + args.benchmark_steps)
             max_steps = steps_this_run + last_step
             logging.warning(
                 f"{steps_this_run} steps will be performed in this run.")
 
+    if args.benchmark:
+        max_steps = args.benchmark_warmup_steps + args.benchmark_steps + last_step
+
+
     total_samples = 0
     raw_train_start = time.time()
     step_start = time.time()
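
The hunk above moves the benchmark cap out of the `--steps-this-run` branch, so benchmark mode now always runs exactly warmup + measured steps from the last checkpoint, whether or not `--steps-this-run` is given. A testable sketch of the resulting bookkeeping (function and parameter names are illustrative):

```python
def compute_max_steps(max_steps, last_step, steps_this_run=None,
                      benchmark=False, warmup_steps=100, bench_steps=100):
    """Mirror run()'s step arithmetic after this change."""
    # --steps-this-run shortens the run only if it still fits under max_steps
    if steps_this_run is not None and steps_this_run + last_step <= max_steps:
        max_steps = steps_this_run + last_step
    # Benchmark mode overrides: exactly warmup + measured steps,
    # counted from the last checkpointed step
    if benchmark:
        max_steps = warmup_steps + bench_steps + last_step
    return max_steps

print(compute_max_steps(7038, 0, steps_this_run=1000))
print(compute_max_steps(7038, 500, benchmark=True, warmup_steps=10, bench_steps=10))
```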

+ 2 - 1
PaddlePaddle/LanguageModeling/BERT/run_pretraining.py

@@ -81,7 +81,8 @@ def main():
         log_dir=None if args.output_dir is None else
         os.path.join(args.output_dir, 'lddl_log'),
         log_level=logging.WARNING,
-        start_epoch=0 if progress is None else progress.get("epoch", 0), )
+        start_epoch=0 if progress is None else progress.get("epoch", 0),
+        sequence_length_alignment=64)
 
     if args.amp:
         optimizer.amp_init(device)
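
The new `sequence_length_alignment=64` argument asks the LDDL data loader to round batch sequence lengths up to a multiple of 64, keeping shapes friendly to the fused attention kernel. The rounding itself amounts to (helper name is illustrative, not LDDL's API):

```python
def align_seq_len(seq_len: int, alignment: int = 64) -> int:
    """Round a sequence length up to the next multiple of `alignment`."""
    return ((seq_len + alignment - 1) // alignment) * alignment

print(align_seq_len(100))  # -> 128
print(align_seq_len(128))  # -> 128
```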

+ 2 - 0
PaddlePaddle/LanguageModeling/BERT/run_squad.py

@@ -186,9 +186,11 @@ def main(args):
 
     with paddle.static.program_guard(main_program, startup_program):
         bert_config = BertConfig.from_json_file(args.config_file)
+        bert_config.fuse_mha = args.fuse_mha
         if bert_config.vocab_size % 8 != 0:
             bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
 
+
         model = BertForQuestionAnswering(bert_config)
         criterion = CrossEntropyLossForSQuAD()
         logits = model(input_ids=input_ids, token_type_ids=segment_ids)
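
This hunk copies `args.fuse_mha` onto the config before model construction and keeps the existing vocabulary padding, which rounds `vocab_size` up to a multiple of 8 so the embedding and projection GEMMs stay Tensor Core friendly. The padding step in isolation (helper name is illustrative):

```python
def pad_vocab_size(vocab_size: int, multiple: int = 8) -> int:
    """Mirror run_squad.py: pad the vocabulary to a multiple of 8."""
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size

print(pad_vocab_size(30522))  # BERT uncased vocab -> 30528
```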

+ 6 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining.sh

@@ -54,6 +54,7 @@ BERT_CONFIG=${28:-"None"}
 enable_benchmark=${29:-"false"}
 benchmark_steps=${30:-"10"}
 benchmark_warmup_steps=${31:-"10"}
+fuse_mha=${32:-"true"}
 
 # Calculate the total number of shards.
 readonly num_blocks=$((num_shards_per_worker * $(( num_workers > 0 ? num_workers : 1 )) * num_nodes * num_gpus))
@@ -130,8 +131,12 @@ if [ "$BERT_CONFIG" != "None" ] ; then
 fi
 
 PREC=""
+FUSE_MHA=""
 if [ "$precision" = "amp" ] ; then
    PREC="--amp --use-dynamic-loss-scaling --scale-loss=1048576"
+   if [ "$fuse_mha" = "true" ] ; then
+      FUSE_MHA="--fuse-mha"
+   fi
 elif [ "$precision" = "fp32" ] ; then
    PREC=""
 elif [ "$precision" = "tf32" ] ; then
@@ -197,6 +202,7 @@ CMD+=" --log-freq=1"
 CMD+=" --optimizer=Lamb"
 CMD+=" --phase1"
 CMD+=" $PREC"
+CMD+=" $FUSE_MHA"
 CMD+=" $ACCUMULATE_GRADIENTS"
 CMD+=" $INIT_CHECKPOINT"
 CMD+=" $BENCH"


+ 1 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining_p1.sh

@@ -31,6 +31,7 @@ python3 -m paddle.distributed.launch \
 --amp \
 --use-dynamic-loss-scaling \
 --optimizer=Lamb \
+--fuse-mha \
 --phase1 \
 --scale-loss=1048576 \
 --learning-rate=6e-3 \

+ 1 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_pretraining_p2.sh

@@ -32,6 +32,7 @@ python3 -m paddle.distributed.launch \
 --amp \
 --use-dynamic-loss-scaling \
 --optimizer=Lamb \
+--fuse-mha \
 --phase2 \
 --scale-loss=1048576 \
 --learning-rate=4e-3 \

+ 6 - 0
PaddlePaddle/LanguageModeling/BERT/scripts/run_squad.sh

@@ -31,6 +31,7 @@ max_steps=${14:-"-1"}
 enable_benchmark=${15:-"false"}
 benchmark_steps=${16:-"100"}
 benchmark_warmup_steps=${17:-"100"}
+fuse_mha=${18:-"true"}
 
 
 echo "out dir is $OUT_DIR"
@@ -41,9 +42,13 @@ if [ ! -d "$OUT_DIR" ]; then
 fi
 
 amp=""
+FUSE_MHA=""
 if [ "$precision" = "amp" ] ; then
   echo "amp activated!"
   amp=" --amp --use-dynamic-loss-scaling --scale-loss=128.0"
+  if [ "$fuse_mha" = "true" ] ; then
+    FUSE_MHA="--fuse-mha"
+  fi
 fi
 
 CONFIG=""
@@ -119,6 +124,7 @@ CMD+=" --max-steps=$max_steps "
 CMD+=" --optimizer=AdamW "
 CMD+=" --log-freq=100 "
 CMD+=" $amp "
+CMD+=" $FUSE_MHA "
 CMD+=" $BENCH "
 CMD+=" --report-file $OUT_DIR/dllogger_${num_gpus}_${precision}.json "
 

+ 196 - 101
PaddlePaddle/LanguageModeling/BERT/utils/config.py

@@ -18,6 +18,7 @@ import argparse
 import distutils.util
 import logging
 import dllogger
+import paddle
 from utils.task import Task
 from utils.save_load import _PDOPT_SUFFIX, _PDPARAMS_SUFFIX, _PROGRESS_SUFFIX
 
@@ -27,7 +28,7 @@ _DEFAULT_BERT_CONFIG = {
     'bert-large-uncased': './bert_configs/bert-large-uncased.json',
     'bert-large-cased': './bert_configs/bert-large-cased.json',
     'bert-base-uncased': './bert_configs/bert-base-uncased.json',
-    'bert-base-cased': './bert_configs/bert-base-cased.json'
+    'bert-base-cased': './bert_configs/bert-base-cased.json',
 }
 
 
@@ -41,28 +42,34 @@ def _get_full_path_of_ckpt(args):
         pdparams_path = path_with_prefix + _PDPARAMS_SUFFIX
         progress_path = path_with_prefix + _PROGRESS_SUFFIX
         found = False
-        if os.path.exists(pdopt_path) and os.path.exists(
-                pdparams_path) and os.path.exists(progress_path):
+        if (
+            os.path.exists(pdopt_path)
+            and os.path.exists(pdparams_path)
+            and os.path.exists(progress_path)
+        ):
             found = True
         return found, pdopt_path, pdparams_path, progress_path
 
     if not os.path.exists(args.from_checkpoint):
         logging.warning(
-            f"Start training from scratch since no checkpoint is found.")
+            f"Start training from scratch since no checkpoint is found."
+        )
         args.from_checkpoint = None
         args.last_step_of_checkpoint = 0
         return
 
-    target_from_checkpoint = os.path.join(args.from_checkpoint,
-                                          args.model_prefix)
+    target_from_checkpoint = os.path.join(
+        args.from_checkpoint, args.model_prefix
+    )
     if args.last_step_of_checkpoint is None:
         args.last_step_of_checkpoint = 0
     elif args.last_step_of_checkpoint == _AUTO_LAST_EPOCH:
         folders = os.listdir(args.from_checkpoint)
         args.last_step_of_checkpoint = 0
         for folder in folders:
-            tmp_ckpt_path = os.path.join(args.from_checkpoint, folder,
-                                         args.model_prefix)
+            tmp_ckpt_path = os.path.join(
+                args.from_checkpoint, folder, args.model_prefix
+            )
 
             try:
                 folder = int(folder)
@@ -72,23 +79,32 @@ def _get_full_path_of_ckpt(args):
                 )
                 continue
 
-            if folder > args.last_step_of_checkpoint and \
-                _check_file_exist(tmp_ckpt_path)[0]:
+            if (
+                folder > args.last_step_of_checkpoint
+                and _check_file_exist(tmp_ckpt_path)[0]
+            ):
                 args.last_step_of_checkpoint = folder
-        step_with_prefix = os.path.join(str(args.last_step_of_checkpoint), args.model_prefix) \
-                            if args.last_step_of_checkpoint > 0 else args.model_prefix
-        target_from_checkpoint = os.path.join(args.from_checkpoint,
-                                              step_with_prefix)
+        step_with_prefix = (
+            os.path.join(str(args.last_step_of_checkpoint), args.model_prefix)
+            if args.last_step_of_checkpoint > 0
+            else args.model_prefix
+        )
+        target_from_checkpoint = os.path.join(
+            args.from_checkpoint, step_with_prefix
+        )
     else:
         try:
             args.last_step_of_checkpoint = int(args.last_step_of_checkpoint)
         except ValueError:
-            raise ValueError(f"The value of --last-step-of-checkpoint should be None, {_AUTO_LAST_EPOCH}"  \
-                            f" or integer >= 0, but receive {args.last_step_of_checkpoint}")
+            raise ValueError(
+                f"The value of --last-step-of-checkpoint should be None, {_AUTO_LAST_EPOCH}"
+                f" or integer >= 0, but receive {args.last_step_of_checkpoint}"
+            )
 
     args.from_checkpoint = target_from_checkpoint
     found, pdopt_path, pdparams_path, progress_path = _check_file_exist(
-        args.from_checkpoint)
+        args.from_checkpoint
+    )
     if not found:
         args.from_checkpoint = None
         args.last_step_of_checkpoint = 0
@@ -98,19 +114,28 @@ def _get_full_path_of_ckpt(args):
 
 
 def _get_full_path_of_pretrained_params(args, task=Task.pretrain):
-    if args.from_pretrained_params is None and args.from_phase1_final_params is None:
+    if (
+        args.from_pretrained_params is None
+        and args.from_phase1_final_params is None
+    ):
         args.last_step_of_checkpoint = 0
         return
-    if task == Task.pretrain and args.from_phase1_final_params is not None and args.last_step_of_checkpoint == 0:
+    if (
+        task == Task.pretrain
+        and args.from_phase1_final_params is not None
+        and args.last_step_of_checkpoint == 0
+    ):
         args.from_pretrained_params = args.from_phase1_final_params
 
-    args.from_pretrained_params = os.path.join(args.from_pretrained_params,
-                                               args.model_prefix)
+    args.from_pretrained_params = os.path.join(
+        args.from_pretrained_params, args.model_prefix
+    )
     pdparams_path = args.from_pretrained_params + _PDPARAMS_SUFFIX
     if not os.path.exists(pdparams_path):
         args.from_pretrained_params = None
         logging.warning(
-            f"Cannot find {pdparams_path}, disable --from-pretrained-params.")
+            f"Cannot find {pdparams_path}, disable --from-pretrained-params."
+        )
     args.last_step_of_checkpoint = 0
 
 
@@ -121,20 +146,28 @@ def print_args(args):
 
 def check_and_process_args(args, task=Task.pretrain):
     if task == Task.pretrain:
-        assert not (args.from_checkpoint is not None and \
-            args.from_pretrained_params is not None), \
-           "--from-pretrained-params and --from-checkpoint should " \
-           "not be set simultaneously."
-        assert not (args.phase1 and args.phase2), \
-            "--phase1 and --phase2 should not be set simultaneously in bert pretraining."
+        assert not (
+            args.from_checkpoint is not None
+            and args.from_pretrained_params is not None
+        ), (
+            "--from-pretrained-params and --from-checkpoint should "
+            "not be set simultaneously."
+        )
+        assert not (
+            args.phase1 and args.phase2
+        ), "--phase1 and --phase2 should not be set simultaneously in bert pretraining."
         if args.from_phase1_final_params is not None:
-            assert args.phase2, "--from-phase1-final-params should only be used in phase2"
+            assert (
+                args.phase2
+            ), "--from-phase1-final-params should only be used in phase2"
 
         # SQuAD finetuning does not support suspend-resume yet.(TODO)
         _get_full_path_of_ckpt(args)
 
     if args.bert_model == 'custom':
-        assert args.config_file is not None, "--config-file must be specified if --bert-model=custom"
+        assert (
+            args.config_file is not None
+        ), "--config-file must be specified if --bert-model=custom"
     elif args.config_file is None:
         args.config_file = _DEFAULT_BERT_CONFIG[args.bert_model]
         logging.info(
@@ -144,7 +177,19 @@ def check_and_process_args(args, task=Task.pretrain):
         _get_full_path_of_pretrained_params(args, task)
 
     assert os.path.isfile(
-        args.config_file), f"Cannot find config file in {args.config_file}"
+        args.config_file
+    ), f"Cannot find config file in {args.config_file}"
+
+    # cuDNN MHA fusion requires cuDNN >= 8.9.1 and an Ampere or Hopper GPU
+    device_capability = paddle.device.cuda.get_device_capability()
+    cudnn_mha_supported = paddle.get_cudnn_version() >= 8901 and (
+        device_capability == (8, 0) or device_capability == (9, 0)
+    )
+    if (not cudnn_mha_supported or args.amp is False) and args.fuse_mha is True:
+        logging.info(
+            "cuDNN MHA fusion is not supported, falling back to unfused MHA"
+        )
+        args.fuse_mha = False
 
 
 def add_global_args(parser, task=Task.pretrain):
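
The fallback added in `check_and_process_args` can be exercised without a GPU by factoring the decision into a pure function. As in the Paddle calls above, cuDNN versions are encoded as integers (8901 for 8.9.1) and compute capabilities as `(major, minor)` tuples; the helper name below is illustrative:

```python
def fuse_mha_allowed(cudnn_version: int, capability: tuple, amp: bool) -> bool:
    """cuDNN fused MHA needs cuDNN >= 8.9.1 (encoded 8901), an
    Ampere (8, 0) or Hopper (9, 0) GPU, and AMP enabled."""
    supported = cudnn_version >= 8901 and capability in ((8, 0), (9, 0))
    return supported and amp

print(fuse_mha_allowed(8902, (8, 0), True))   # A100, new cuDNN -> True
print(fuse_mha_allowed(8600, (8, 0), True))   # cuDNN too old  -> False
print(fuse_mha_allowed(8902, (7, 0), True))   # V100           -> False
```

When the check fails, the scripts simply clear `args.fuse_mha` and continue with the unfused path rather than erroring out.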
@@ -155,145 +200,165 @@ def add_global_args(parser, task=Task.pretrain):
             type=str,
             default=None,
             required=True,
-            help='The input data directory. Should be specified by users and contain .hdf5 files for the task.'
+            help='The input data directory. Should be specified by users and contain .hdf5 files for the task.',
         )
-        group.add_argument('--num-workers', default=4, type=int)
+        group.add_argument('--num-workers', default=1, type=int)
     if task == Task.squad:
         group.add_argument(
             '--train-file',
             type=str,
             default=None,
-            help='SQuAD json for training. E.g., train-v1.1.json')
+            help='SQuAD json for training. E.g., train-v1.1.json',
+        )
         group.add_argument(
             '--predict-file',
             type=str,
             default=None,
-            help='SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json'
+            help='SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json',
         )
         group.add_argument(
             "--eval-script",
             help="Script to evaluate squad predictions",
             default="evaluate.py",
-            type=str)
+            type=str,
+        )
         group.add_argument(
             '--epochs',
             type=int,
             default=3,
-            help='The number of epochs for training.')
+            help='The number of epochs for training.',
+        )
 
     group.add_argument(
         '--vocab-file',
         type=str,
         default=None,
         required=True,
-        help="Vocabulary mapping/file BERT was pretrainined on")
+        help="Vocabulary mapping/file BERT was pretrained on",
+    )
     group.add_argument(
         '--output-dir',
         type=str,
         default=None,
         required=True,
-        help='The output directory where the model checkpoints will be written. Should be specified by users.'
+        help='The output directory where the model checkpoints will be written. Should be specified by users.',
     )
     group.add_argument(
         '--bert-model',
         type=str,
         default='bert-large-uncased',
-        choices=('bert-base-uncased', 'bert-base-cased', 'bert-large-uncased',
-                 'bert-large-cased', 'custom'),
+        choices=(
+            'bert-base-uncased',
+            'bert-base-cased',
+            'bert-large-uncased',
+            'bert-large-cased',
+            'custom',
+        ),
         help='Specifies the type of BERT model to use. If it is set as custom, '
-        'the path to the config file must be given by specifying --config-file')
+        'the path to the config file must be given by specifying --config-file',
+    )
     group.add_argument(
         '--config-file',
         type=str,
         default=None,
-        help='The BERT model config. If set to None, `<--bert-model>.json` in folder `bert_configs` will be used.'
+        help='The BERT model config. If set to None, `<--bert-model>.json` in folder `bert_configs` will be used.',
     )
     group.add_argument(
         '--max-steps',
         type=int,
         default=None,
         required=True if task == Task.pretrain else False,
-        help='Total number of training steps to perform.')
+        help='Total number of training steps to perform.',
+    )
     group.add_argument(
-        '--log-freq', type=int, default=10, help='Frequency of logging loss.')
+        '--log-freq', type=int, default=10, help='Frequency of logging loss.'
+    )
     group.add_argument(
         '--num-steps-per-checkpoint',
         type=int,
         default=100,
-        help='Number of update steps until a model checkpoint is saved to disk.'
+        help='Number of update steps until a model checkpoint is saved to disk.',
     )
     # Init model
     group.add_argument(
         '--from-pretrained-params',
         type=str,
         default=None,
-        help='Path to pretrained parameters. If set to None, no pretrained params will be used.'
+        help='Path to pretrained parameters. If set to None, no pretrained params will be used.',
     )
     group.add_argument(
         '--from-checkpoint',
         type=str,
         default=None,
-        help='A checkpoint path to resume training. If set to None, no checkpoint will be used. ' \
-             'If not None, --from-pretrained-params will be ignored.')
+        help='A checkpoint path to resume training. If set to None, no checkpoint will be used. '
+        'If not None, --from-pretrained-params will be ignored.',
+    )
     group.add_argument(
         '--last-step-of-checkpoint',
         type=str,
         default=None,
-        help='The step id of the checkpoint given by --from-checkpoint. ' \
-             'It should be None, auto, or integer > 0. If it is set as ' \
-             'None, then training will start from the 1-th epoch. If it is set as ' \
-             'auto, then it will search largest integer-convertable folder ' \
-             ' --from-checkpoint, which contains required checkpoint. '
+        help='The step id of the checkpoint given by --from-checkpoint. '
+        'It should be None, auto, or an integer > 0. If set to None, '
+        'training will start from the first step. If set to auto, the '
+        'largest integer-named subfolder under --from-checkpoint that '
+        'contains the required checkpoint will be used.',
     )
     if task == Task.pretrain:
         group.add_argument(
             '--from-phase1-final-params',
             type=str,
             default=None,
-            help='Path to final checkpoint of phase1, which will be used to ' \
-   'initialize the parameter in the first step of phase2, and ' \
-                 'ignored in the rest steps of phase2.'
+            help='Path to the final checkpoint of phase1, which will be used '
+            'to initialize the parameters at the first step of phase2 and '
+            'ignored in the remaining steps of phase2.',
         )
         group.add_argument(
             '--steps-this-run',
             type=int,
             default=None,
-            help='If provided, only run this many steps before exiting.' \
+            help='If provided, only run this many steps before exiting.',
         )
     group.add_argument(
-        '--seed', type=int, default=42, help="random seed for initialization")
+        '--seed', type=int, default=42, help="random seed for initialization"
+    )
     group.add_argument(
         '--report-file',
         type=str,
         default='./report.json',
-        help='A file in which to store JSON experiment report.')
+        help='A file in which to store JSON experiment report.',
+    )
     group.add_argument(
         '--model-prefix',
         type=str,
         default='bert_paddle',
-        help='The prefix name of model files to save/load.')
+        help='The prefix name of model files to save/load.',
+    )
     group.add_argument(
         '--show-config',
         type=distutils.util.strtobool,
         default=True,
-        help='To show arguments.')
+        help='To show arguments.',
+    )
     group.add_argument(
         '--enable-cpu-affinity',
         type=distutils.util.strtobool,
         default=True,
-        help='To enable in-built GPU-CPU affinity.')
+        help='To enable in-built GPU-CPU affinity.',
+    )
     group.add_argument(
-        '--benchmark', action='store_true', help='To enable benchmark mode.')
+        '--benchmark', action='store_true', help='To enable benchmark mode.'
+    )
     group.add_argument(
         '--benchmark-steps',
         type=int,
         default=20,
-        help='Steps for a benchmark run, only applied when --benchmark is set.')
+        help='Steps for a benchmark run, only applied when --benchmark is set.',
+    )
     group.add_argument(
         '--benchmark-warmup-steps',
         type=int,
         default=20,
-        help='Warmup steps for a benchmark run, only applied when --benchmark is set.'
+        help='Warmup steps for a benchmark run, only applied when --benchmark is set.',
     )
     return parser
 
@@ -305,145 +370,166 @@ def add_training_args(parser, task=Task.pretrain):
         default='Lamb',
         metavar="OPTIMIZER",
         choices=('Lamb', 'AdamW'),
-        help='The name of optimizer. It should be one of {Lamb, AdamW}.')
+        help='The name of optimizer. It should be one of {Lamb, AdamW}.',
+    )
     group.add_argument(
         '--gradient-merge-steps',
         type=int,
         default=1,
-        help="Number of update steps to accumualte before performing a backward/update pass."
+        help="Number of update steps to accumulate before performing a backward/update pass.",
     )
     group.add_argument(
         '--learning-rate',
         type=float,
         default=1e-4,
-        help='The initial learning rate.')
+        help='The initial learning rate.',
+    )
     group.add_argument(
         '--warmup-start-lr',
         type=float,
         default=0.0,
-        help='The initial learning rate for warmup.')
+        help='The initial learning rate for warmup.',
+    )
     group.add_argument(
         '--warmup-proportion',
         type=float,
         default=0.01,
         help='Proportion of training to perform linear learning rate warmup for. '
-        'For example, 0.1 = 10%% of training.')
+        'For example, 0.1 = 10%% of training.',
+    )
     group.add_argument(
         '--beta1',
         type=float,
         default=0.9,
-        help='The exponential decay rate for the 1st moment estimates.')
+        help='The exponential decay rate for the 1st moment estimates.',
+    )
     group.add_argument(
         '--beta2',
         type=float,
         default=0.999,
-        help='The exponential decay rate for the 2st moment estimates.')
+        help='The exponential decay rate for the 2nd moment estimates.',
+    )
     group.add_argument(
         '--epsilon',
         type=float,
         default=1e-6,
-        help='A small float value for numerical stability.')
+        help='A small float value for numerical stability.',
+    )
     group.add_argument(
         '--weight-decay',
         type=float,
         default=0.01,
-        help='The weight decay coefficient.')
+        help='The weight decay coefficient.',
+    )
     group.add_argument(
         '--max-seq-length',
         default=512,
         type=int,
         help='The maximum total input sequence length after WordPiece tokenization. \n'
         'Sequences longer than this will be truncated, and sequences shorter \n'
-        'than this will be padded.')
+        'than this will be padded.',
+    )
     if task == Task.pretrain:
         group.add_argument(
             '--batch-size',
             type=int,
             default=32,
-            help='The batch size for training')
+            help='The batch size for training',
+        )
         group.add_argument(
             '--phase1',
             action='store_true',
-            help='The phase of BERT pretraining. It should not be set ' \
-                'with --phase2 at the same time.'
+            help='The phase of BERT pretraining. It should not be set '
+            'with --phase2 at the same time.',
         )
         group.add_argument(
             '--phase2',
             action='store_true',
-            help='The phase of BERT pretraining. It should not be set ' \
-                'with --phase1 at the same time.'
+            help='The phase of BERT pretraining. It should not be set '
+            'with --phase1 at the same time.',
         )
         group.add_argument(
             '--max-predictions-per-seq',
             default=80,
             type=int,
-            help='The maximum total of masked tokens in the input sequence')
+            help='The maximum number of masked tokens in the input sequence',
+        )
 
     if task == Task.squad:
         group.add_argument(
-            "--do-train", action='store_true', help="Whether to run training.")
+            "--do-train", action='store_true', help="Whether to run training."
+        )
         group.add_argument(
             "--do-predict",
             action='store_true',
-            help="Whether to run eval on the dev set.")
+            help="Whether to run eval on the dev set.",
+        )
         group.add_argument(
             "--do-eval",
             action='store_true',
-            help="Whether to use evaluate accuracy of predictions")
+            help="Whether to evaluate the accuracy of predictions",
+        )
         group.add_argument(
             "--train-batch-size",
             default=32,
             type=int,
-            help="Total batch size for training.")
+            help="Total batch size for training.",
+        )
         group.add_argument(
             "--predict-batch-size",
             default=8,
             type=int,
-            help="Total batch size for predictions.")
+            help="Total batch size for predictions.",
+        )
         group.add_argument(
             "--verbose-logging",
             action='store_true',
             help="If true, all of the warnings related to data processing will be printed. "
-            "A number of warnings are expected for a normal SQuAD evaluation.")
+            "A number of warnings are expected for a normal SQuAD evaluation.",
+        )
         group.add_argument(
             "--doc-stride",
             default=128,
             type=int,
             help="When splitting up a long document into chunks, how much stride to take "
-            "between chunks.")
+            "between chunks.",
+        )
         group.add_argument(
             "--max-query-length",
             default=64,
             type=int,
             help="The maximum number of tokens for the question. Questions longer than this "
-            "will be truncated to this length.")
+            "will be truncated to this length.",
+        )
         group.add_argument(
             "--n-best-size",
             default=20,
             type=int,
             help="The total number of n-best predictions to generate in the nbest_predictions.json "
-            "output file.")
+            "output file.",
+        )
         group.add_argument(
             "--max-answer-length",
             default=30,
             type=int,
             help="The maximum length of an answer that can be generated. This is needed because the start "
-            "and end predictions are not conditioned on one another.")
+            "and end predictions are not conditioned on one another.",
+        )
         group.add_argument(
             "--do-lower-case",
             action='store_true',
-            help="Whether to lower case the input text. True for uncased models, False for cased models."
+            help="Whether to lower case the input text. True for uncased models, False for cased models.",
         )
         group.add_argument(
             '--version-2-with-negative',
             action='store_true',
-            help='If true, the SQuAD examples contain some that do not have an answer.'
+            help='If true, the SQuAD examples contain some that do not have an answer.',
         )
         group.add_argument(
             '--null-score-diff-threshold',
             type=float,
             default=0.0,
-            help="If null_score - best_non_null is greater than the threshold predict null."
+            help="If null_score - best_non_null is greater than the threshold, predict null.",
         )
     return parser
 
@@ -453,22 +539,29 @@ def add_advance_args(parser):
     group.add_argument(
         '--amp',
         action='store_true',
-        help='Enable automatic mixed precision training (AMP).')
+        help='Enable automatic mixed precision training (AMP).',
+    )
     group.add_argument(
         '--scale-loss',
         type=float,
         default=1.0,
-        help='The loss scalar for AMP training, only applied when --amp is set.'
+        help='The loss scalar for AMP training, only applied when --amp is set.',
     )
     group.add_argument(
         '--use-dynamic-loss-scaling',
         action='store_true',
-        help='Enable dynamic loss scaling in AMP training, only applied when --amp is set.'
+        help='Enable dynamic loss scaling in AMP training, only applied when --amp is set.',
     )
     group.add_argument(
         '--use-pure-fp16',
         action='store_true',
-        help='Enable pure FP16 training, only applied when --amp is set.')
+        help='Enable pure FP16 training, only applied when --amp is set.',
+    )
+    group.add_argument(
+        '--fuse-mha',
+        action='store_true',
+        help='Enable multihead attention fusion. Requires cuDNN version >= 8.9.1.',
+    )
 
     return parser
 
@@ -476,8 +569,10 @@ def add_advance_args(parser):
 def parse_args(task=Task.pretrain):
     parser = argparse.ArgumentParser(
         description="PaddlePaddle BERT pretraining script"
-        if task == Task.pretrain else "PaddlePaddle SQuAD finetuning script",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        if task == Task.pretrain
+        else "PaddlePaddle SQuAD finetuning script",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
 
     parser = add_global_args(parser, task)
     parser = add_training_args(parser, task)
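The capability gate this patch adds in `check_and_process_args` can be sketched as a standalone helper for reference. All names below are illustrative, and the integer version encoding (`major * 1000 + minor * 100 + patch`, so 8.9.1 -> 8901) follows the cuDNN 8.x `CUDNN_VERSION` convention that `paddle.get_cudnn_version()` is assumed to return:

```python
def encode_cudnn_version(major, minor, patch):
    # cuDNN 8.x encodes its version as a single integer,
    # e.g. 8.9.1 -> 8 * 1000 + 9 * 100 + 1 = 8901.
    return major * 1000 + minor * 100 + patch


def fused_mha_supported(cudnn_version, device_capability, amp_enabled):
    # Mirrors the gate in this patch: cuDNN >= 8.9.1, an Ampere (8, 0)
    # or Hopper (9, 0) GPU, and AMP enabled; otherwise --fuse-mha is
    # silently reset to False and the unfused path is used.
    return (
        cudnn_version >= encode_cudnn_version(8, 9, 1)
        and device_capability in ((8, 0), (9, 0))
        and amp_enabled
    )
```

In the patch itself, the version and capability come from `paddle.get_cudnn_version()` and `paddle.device.cuda.get_device_capability()` at startup, so a container with an older cuDNN or a pre-Ampere GPU degrades gracefully instead of failing.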