# program.py
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import logging
import shutil
import paddle
import paddle.distributed.fleet as fleet
from modeling import BertForPretraining, BertConfig
from loss import BertPretrainingCriterion
from utils.save_load import save_model
from utils.utility import get_trainer_id
from lr_scheduler import build_lr_scheduler
from optimizer import build_optimizer
import dllogger
  27. def create_pretraining_data_holder():
  28. input_ids = paddle.static.data(
  29. name="input_ids", shape=[-1, -1], dtype="int64")
  30. token_type_ids = paddle.static.data(
  31. name="token_type_ids", shape=[-1, -1], dtype="int64")
  32. attention_mask = paddle.static.data(
  33. name="attention_mask", shape=[-1, 1, 1, -1], dtype="int64")
  34. next_sentence_labels = paddle.static.data(
  35. name="next_sentence_labels", shape=[-1, 1], dtype="int64")
  36. masked_lm_labels = paddle.static.data(
  37. name="masked_lm_labels", shape=[-1, -1], dtype="int64")
  38. return [
  39. input_ids, token_type_ids, attention_mask, next_sentence_labels,
  40. masked_lm_labels
  41. ]
  42. def create_strategy(args, use_distributed_fused_lamb=False):
  43. """
  44. Create paddle.static.BuildStrategy and paddle.static.ExecutionStrategy with arguments.
  45. Args:
  46. args(Namespace): Arguments obtained from ArgumentParser.
  47. use_distributed_fused_lamb(bool, optional): Whether to use distributed fused lamb.
  48. Returns:
  49. build_strategy(paddle.static.BuildStrategy): A instance of BuildStrategy.
  50. exec_strategy(paddle.static.ExecutionStrategy): A instance of ExecutionStrategy.
  51. """
  52. build_strategy = paddle.static.BuildStrategy()
  53. exec_strategy = paddle.static.ExecutionStrategy()
  54. build_strategy.enable_addto = True
  55. if args.amp:
  56. build_strategy.fuse_gemm_epilogue = True
  57. build_strategy.fuse_dot_product_attention = args.fuse_mha
  58. if use_distributed_fused_lamb:
  59. build_strategy.fuse_all_reduce_ops = False
  60. build_strategy.reduce_strategy = paddle.static.BuildStrategy.ReduceStrategy._NoReduce
  61. else:
  62. build_strategy.fuse_all_reduce_ops = True
  63. build_strategy.reduce_strategy = paddle.static.BuildStrategy.ReduceStrategy.AllReduce
  64. exec_strategy.num_threads = 1
  65. exec_strategy.num_iteration_per_drop_scope = 10000
  66. return build_strategy, exec_strategy
  67. def dist_optimizer(args, optimizer):
  68. """
  69. Create a distributed optimizer based on a given optimizer.
  70. Args:
  71. args(Namespace): Arguments obtained from ArgumentParser.
  72. optimizer(paddle.optimizer): A normal optimizer.
  73. Returns:
  74. optimizer(fleet.distributed_optimizer): A distributed optimizer.
  75. """
  76. use_distributed_fused_lamb = True if args.optimizer == 'DistributedFusedLamb' else False
  77. build_strategy, exec_strategy = create_strategy(args,
  78. use_distributed_fused_lamb)
  79. dist_strategy = fleet.DistributedStrategy()
  80. if use_distributed_fused_lamb:
  81. dist_strategy.gradient_scale_configs = {'scale_strategy': 'sum'}
  82. dist_strategy.execution_strategy = exec_strategy
  83. dist_strategy.build_strategy = build_strategy
  84. if use_distributed_fused_lamb:
  85. dist_strategy.fuse_all_reduce_ops = False
  86. else:
  87. dist_strategy.fuse_all_reduce_ops = True
  88. dist_strategy.fuse_grad_size_in_MB = 0
  89. if args.amp:
  90. dist_strategy.amp = True
  91. custom_white_list = ['softmax', 'layer_norm', 'gelu']
  92. custom_black_list = ['lookup_table',
  93. 'lookup_table_v2'] if args.use_pure_fp16 else None
  94. dist_strategy.amp_configs = {
  95. 'custom_white_list': custom_white_list,
  96. 'custom_black_list': custom_black_list,
  97. 'init_loss_scaling': args.scale_loss,
  98. 'use_dynamic_loss_scaling': True,
  99. 'incr_every_n_steps': 2000,
  100. 'decr_every_n_nan_or_inf': 1,
  101. 'incr_ratio': 2.0,
  102. 'decr_ratio': 0.5,
  103. 'use_pure_fp16': args.use_pure_fp16,
  104. 'use_fp16_guard': args.use_pure_fp16
  105. }
  106. if not use_distributed_fused_lamb and args.gradient_merge_steps > 1:
  107. dist_strategy.gradient_merge = True
  108. dist_strategy.gradient_merge_configs = {
  109. 'k_steps': args.gradient_merge_steps
  110. }
  111. optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
  112. return optimizer
def build(args, main_prog, startup_prog, is_train=True):
    """
    Build a executable paddle.static.Program via following 4 steps:
        1. Create feeds.
        2. Create model.
        3. Create loss.
        4. Create optimizer if is_train==True.

    Args:
        args(Namespace): Arguments obtained from ArgumentParser.
        main_prog(paddle.static.Program): The main program.
        startup_prog(paddle.static.Program): The startup program.
        is_train(bool, optional): Whether the main programe created is for training. Default: True.

    Returns:
        model(paddle.nn.Layer): An instance of BERT Model defined in modeling.py.
        lr_scheduler(paddle.optimizer.lr.LRScheduler): A learning rate scheduler,
            or None when is_train is False.
        optimizer(Optimizer): An optimizer with distributed/AMP strategy,
            or None when is_train is False.
        loss(variable): The output variable of loss function.
        feeds(dict): A dict of mapping variables' names to their values.
    """
    # All ops below are recorded into main_prog/startup_prog; unique_name.guard
    # keeps parameter names deterministic across repeated builds.
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.utils.unique_name.guard():
            feeds = create_pretraining_data_holder()
            [
                input_ids, token_type_ids, attention_mask,
                next_sentence_labels, masked_lm_labels
            ] = feeds

            bert_config = BertConfig.from_json_file(args.config_file)
            # Pad vocab size up to a multiple of 8 (helps Tensor Core kernels;
            # TODO confirm rationale against modeling.py).
            if bert_config.vocab_size % 8 != 0:
                bert_config.vocab_size += 8 - (bert_config.vocab_size % 8)
            bert_config.fuse_mha = args.fuse_mha

            model = BertForPretraining(bert_config)
            criterion = BertPretrainingCriterion(bert_config.vocab_size)
            prediction_scores, seq_relationship_score = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                masked_lm_labels=masked_lm_labels)
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels, next_sentence_labels)

            lr_scheduler = None
            optimizer = None
            if is_train:
                # Wrap the base optimizer with fleet's distributed strategy
                # before minimize() so AMP/gradient-merge passes are applied.
                lr_scheduler = build_lr_scheduler(args)
                optimizer = build_optimizer(args, lr_scheduler)
                optimizer = dist_optimizer(args, optimizer)
                optimizer.minimize(loss)
    return model, lr_scheduler, optimizer, loss, feeds
def run(exe,
        program,
        args,
        lr_scheduler,
        loss,
        train_dataloader,
        progress=None):
    """
    Execute program.

    Args:
        exe(paddle.static.Executor): A executor to run program.
        program(paddle.static.Program): The program to be executed.
        args(Namespace): Arguments obtained from ArgumentParser.
        lr_scheduler(paddle.optimizer.lr.LRScheduler): A learning rate scheduler.
            Default: None.
        loss(variable): The output variable of loss function.
        train_dataloader: Iterable yielding feed dicts for exe.run.
        progress(dict, optional): A dict to record the training progress of checkpoint.

    Returns:
        global_step(int): Final step id of this run.
        actual_steps_this_run(int): Number of optimizer steps performed in this run.
        loss_return(float): Final loss of this run.
        train_time_raw(float): Time to train of this run.
    """
    trainer_id = get_trainer_id()
    batch_size_per_gpu = args.batch_size
    log_steps = args.log_freq
    save_steps = args.num_steps_per_checkpoint
    gradient_merge_steps = args.gradient_merge_steps
    # Rotating window of saved checkpoints; only the 3 newest are kept.
    most_recent_ckpts_paths = []
    last_step = args.last_step_of_checkpoint
    # train_iter counts executor iterations; global_step counts optimizer
    # steps (one per gradient_merge_steps iterations).
    train_iter = 0
    epoch = 0
    train_time_raw = 0
    if progress is None:
        progress = dict()
    else:
        epoch = progress.get('epoch', 0)
    global_step = 0 + last_step
    logging.info(f"Training will start at the {last_step+1}th step")
    max_steps = args.max_steps
    steps_this_run = max_steps
    if args.steps_this_run is not None:
        if args.steps_this_run + last_step > max_steps:
            # --steps-this-run would overshoot --max-steps; clamp silently to
            # the remaining budget.
            logging.info(
                f"Only {max_steps - last_step} steps will be performed in this run due to the limit of --max-steps."
            )
        else:
            steps_this_run = args.steps_this_run
            max_steps = steps_this_run + last_step
            logging.warning(
                f"{steps_this_run} steps will be performed in this run.")
    if args.benchmark:
        # Benchmark mode runs exactly warmup + measured steps, then stops.
        max_steps = args.benchmark_warmup_steps + args.benchmark_steps + last_step
    total_samples = 0
    raw_train_start = time.time()
    step_start = time.time()
    avg_loss = 0
    while True:
        for batch in train_dataloader:
            train_iter += 1
            loss_return = exe.run(program, feed=batch, fetch_list=[loss])
            total_samples += batch_size_per_gpu
            avg_loss += loss_return[0].item()
            lr = lr_scheduler.get_lr()
            # Periodic throughput/loss logging, measured in optimizer steps.
            if train_iter % (log_steps * gradient_merge_steps) == 0:
                step_cost = time.time() - step_start
                dllogger_it_data = {
                    'loss': avg_loss / gradient_merge_steps,
                    'learning_rate': lr,
                    'step_cost': step_cost,
                    'step_samples': total_samples,
                    'seqs_per_sec': total_samples / step_cost,
                }
                dllogger.log((epoch, global_step + 1), data=dllogger_it_data)
                total_samples = 0
                step_start = time.time()
            # One optimizer step completes every gradient_merge_steps iters.
            if train_iter % gradient_merge_steps == 0:
                global_step += 1
                lr_scheduler.step()
                avg_loss = 0
            if args.benchmark and train_iter == (args.benchmark_warmup_steps *
                                                 gradient_merge_steps):
                # Restart the timer after warmup so it covers measured steps only.
                raw_train_start = time.time()
            if train_iter % (save_steps * gradient_merge_steps
                             ) == 0 or global_step >= max_steps:
                train_time_raw = time.time() - raw_train_start
                # Only rank 0 writes checkpoints.
                if trainer_id == 0:
                    model_path = os.path.join(
                        args.output_dir, args.bert_model, "phase1"
                        if args.phase1 else "phase2", f"{global_step}")
                    progress = {
                        'epoch': epoch,
                        'global_step': global_step,
                        'phase': 1 if args.phase1 else 2,
                    }
                    save_model(program, model_path, args.model_prefix,
                               progress)
                    most_recent_ckpts_paths.append(model_path)
                    if len(most_recent_ckpts_paths) > 3:
                        # Evict the oldest checkpoint to bound disk usage.
                        ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                        shutil.rmtree(ckpt_to_be_removed)
                if global_step >= max_steps:
                    actual_steps_this_run = global_step - last_step
                    return global_step, actual_steps_this_run, loss_return[0].item(), train_time_raw
        # Dataloader exhausted: start the next epoch.
        epoch += 1