# config.py
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import argparse
# NOTE(review): distutils was removed in Python 3.12; `strtobool` (used below
# for --show-config / --enable-cpu-affinity) needs a replacement when upgrading.
import distutils.util
import logging
import dllogger
import paddle
from utils.task import Task
from utils.save_load import _PDOPT_SUFFIX, _PDPARAMS_SUFFIX, _PROGRESS_SUFFIX

# Sentinel value for --last-step-of-checkpoint: auto-detect the newest
# integer-named step folder under --from-checkpoint.
_AUTO_LAST_EPOCH = 'auto'

# Default config JSON shipped with the repo for each known BERT variant;
# used when --config-file is not given and --bert-model != 'custom'.
_DEFAULT_BERT_CONFIG = {
    'bert-large-uncased': './bert_configs/bert-large-uncased.json',
    'bert-large-cased': './bert_configs/bert-large-cased.json',
    'bert-base-uncased': './bert_configs/bert-base-uncased.json',
    'bert-base-cased': './bert_configs/bert-base-cased.json',
}
def _get_full_path_of_ckpt(args):
    """Resolve ``args.from_checkpoint`` to a full checkpoint path prefix.

    Mutates ``args`` in place:
      * ``args.from_checkpoint`` becomes ``<dir>[/<step>]/<model_prefix>``
        when a complete checkpoint is found, otherwise ``None``.
      * ``args.last_step_of_checkpoint`` becomes the resolved step id
        (0 when training starts from scratch).

    Raises:
        ValueError: if ``--last-step-of-checkpoint`` is neither None,
            'auto', nor an integer-convertible string.
    """
    if args.from_checkpoint is None:
        # No resume requested: train from scratch.
        args.last_step_of_checkpoint = 0
        return

    def _check_file_exist(path_with_prefix):
        # A checkpoint is usable only when optimizer state (.pdopt),
        # parameters (.pdparams) and the progress file are all present.
        pdopt_path = path_with_prefix + _PDOPT_SUFFIX
        pdparams_path = path_with_prefix + _PDPARAMS_SUFFIX
        progress_path = path_with_prefix + _PROGRESS_SUFFIX
        found = False
        if (
            os.path.exists(pdopt_path)
            and os.path.exists(pdparams_path)
            and os.path.exists(progress_path)
        ):
            found = True
        return found, pdopt_path, pdparams_path, progress_path

    if not os.path.exists(args.from_checkpoint):
        # Checkpoint directory missing: fall back to scratch training.
        logging.warning(
            f"Start training from scratch since no checkpoint is found."
        )
        args.from_checkpoint = None
        args.last_step_of_checkpoint = 0
        return

    target_from_checkpoint = os.path.join(
        args.from_checkpoint, args.model_prefix
    )
    if args.last_step_of_checkpoint is None:
        args.last_step_of_checkpoint = 0
    elif args.last_step_of_checkpoint == _AUTO_LAST_EPOCH:
        # 'auto' mode: scan <from_checkpoint>/<step>/ folders and pick the
        # largest integer-named step that holds a complete checkpoint.
        folders = os.listdir(args.from_checkpoint)
        args.last_step_of_checkpoint = 0
        for folder in folders:
            tmp_ckpt_path = os.path.join(
                args.from_checkpoint, folder, args.model_prefix
            )
            try:
                folder = int(folder)
            except ValueError:
                logging.warning(
                    f"Skip folder '{folder}' since its name is not integer-convertable."
                )
                continue
            if (
                folder > args.last_step_of_checkpoint
                and _check_file_exist(tmp_ckpt_path)[0]
            ):
                args.last_step_of_checkpoint = folder
        # Step 0 means no step folder was found; use the bare prefix then.
        step_with_prefix = (
            os.path.join(str(args.last_step_of_checkpoint), args.model_prefix)
            if args.last_step_of_checkpoint > 0
            else args.model_prefix
        )
        target_from_checkpoint = os.path.join(
            args.from_checkpoint, step_with_prefix
        )
    else:
        # Explicit step id given on the command line; must be an integer.
        try:
            args.last_step_of_checkpoint = int(args.last_step_of_checkpoint)
        except ValueError:
            raise ValueError(
                f"The value of --last-step-of-checkpoint should be None, {_AUTO_LAST_EPOCH}"
                f" or integer >= 0, but receive {args.last_step_of_checkpoint}"
            )

    args.from_checkpoint = target_from_checkpoint
    found, pdopt_path, pdparams_path, progress_path = _check_file_exist(
        args.from_checkpoint
    )
    if not found:
        # Resolved checkpoint is incomplete: disable resume entirely.
        args.from_checkpoint = None
        args.last_step_of_checkpoint = 0
        logging.warning(
            f"Cannot find {pdopt_path} and {pdparams_path} and {progress_path}, disable --from-checkpoint."
        )
  103. def _get_full_path_of_pretrained_params(args, task=Task.pretrain):
  104. if (
  105. args.from_pretrained_params is None
  106. and args.from_phase1_final_params is None
  107. ):
  108. args.last_step_of_checkpoint = 0
  109. return
  110. if (
  111. task == Task.pretrain
  112. and args.from_phase1_final_params is not None
  113. and args.last_step_of_checkpoint == 0
  114. ):
  115. args.from_pretrained_params = args.from_phase1_final_params
  116. args.from_pretrained_params = os.path.join(
  117. args.from_pretrained_params, args.model_prefix
  118. )
  119. pdparams_path = args.from_pretrained_params + _PDPARAMS_SUFFIX
  120. if not os.path.exists(pdparams_path):
  121. args.from_pretrained_params = None
  122. logging.warning(
  123. f"Cannot find {pdparams_path}, disable --from-pretrained-params."
  124. )
  125. args.last_step_of_checkpoint = 0
  126. def print_args(args):
  127. args_for_log = copy.deepcopy(args)
  128. dllogger.log(step='PARAMETER', data=vars(args_for_log))
def check_and_process_args(args, task=Task.pretrain):
    """Validate cross-option constraints and normalize ``args`` in place.

    Resolves checkpoint/pretrained-parameter paths, fills in the default
    config file for known BERT variants, and disables ``--fuse-mha`` when
    the runtime does not support cudnn MHA fusion.

    Raises:
        AssertionError: on mutually-exclusive or missing options.
    """
    # NOTE(review): validation below uses `assert`, which is stripped under
    # `python -O`; keep that in mind if the scripts are ever run optimized.
    if task == Task.pretrain:
        assert not (
            args.from_checkpoint is not None
            and args.from_pretrained_params is not None
        ), (
            "--from-pretrained-params and --from-checkpoint should "
            "not be set simultaneously."
        )
        assert not (
            args.phase1 and args.phase2
        ), "--phase1 and --phase2 should not be set simultaneously in bert pretraining."
        if args.from_phase1_final_params is not None:
            assert (
                args.phase2
            ), "--from-phase1-final-params should only be used in phase2"

        # SQuAD finetuning does not support suspend-resume yet.(TODO)
        _get_full_path_of_ckpt(args)

    if args.bert_model == 'custom':
        assert (
            args.config_file is not None
        ), "--config-file must be specified if --bert-model=custom"
    elif args.config_file is None:
        # Fall back to the bundled config that matches the model name.
        args.config_file = _DEFAULT_BERT_CONFIG[args.bert_model]
        logging.info(
            f"According to the name of bert_model, the default config_file: {args.config_file} will be used."
        )

    # Pretrained params are only consulted when no checkpoint resume is active.
    if args.from_checkpoint is None:
        _get_full_path_of_pretrained_params(args, task)

    assert os.path.isfile(
        args.config_file
    ), f"Cannot find config file in {args.config_file}"

    # cudnn mha fusion is only supported after v8.9.1 on Ampere and Hopper GPU
    device_capability = paddle.device.cuda.get_device_capability()
    cudnn_mha_supported = paddle.get_cudnn_version() >= 8901 and (
        device_capability == (8, 0) or device_capability == (9, 0)
    )
    # Fusion also requires AMP; silently downgrade instead of failing.
    if (not cudnn_mha_supported or args.amp is False) and args.fuse_mha is True:
        logging.info(
            f"cudnn mha fusion is not supported, fall back to unfused mha"
        )
        args.fuse_mha = False
def add_global_args(parser, task=Task.pretrain):
    """Register global CLI options (I/O paths, model selection, resume).

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.
        task: current ``Task``; gates the pretrain-only and SQuAD-only
            options registered below.

    Returns:
        The same ``parser``, for call chaining.
    """
    group = parser.add_argument_group('Global')
    # Pretraining-only input options.
    if task == Task.pretrain:
        group.add_argument(
            '--input-dir',
            type=str,
            default=None,
            required=True,
            help='The input data directory. Should be specified by users and contain .hdf5 files for the task.',
        )
        group.add_argument('--num-workers', default=1, type=int)
    # SQuAD-only data/eval options.
    if task == Task.squad:
        group.add_argument(
            '--train-file',
            type=str,
            default=None,
            help='SQuAD json for training. E.g., train-v1.1.json',
        )
        group.add_argument(
            '--predict-file',
            type=str,
            default=None,
            help='SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json',
        )
        group.add_argument(
            "--eval-script",
            help="Script to evaluate squad predictions",
            default="evaluate.py",
            type=str,
        )
        group.add_argument(
            '--epochs',
            type=int,
            default=3,
            help='The number of epochs for training.',
        )
    group.add_argument(
        '--vocab-file',
        type=str,
        default=None,
        required=True,
        help="Vocabulary mapping/file BERT was pretrainined on",
    )
    group.add_argument(
        '--output-dir',
        type=str,
        default=None,
        required=True,
        help='The output directory where the model checkpoints will be written. Should be specified by users.',
    )
    group.add_argument(
        '--bert-model',
        type=str,
        default='bert-large-uncased',
        choices=(
            'bert-base-uncased',
            'bert-base-cased',
            'bert-large-uncased',
            'bert-large-cased',
            'custom',
        ),
        help='Specifies the type of BERT model to use. If it is set as custom, '
        'the path to the config file must be given by specifying --config-file',
    )
    group.add_argument(
        '--config-file',
        type=str,
        default=None,
        help='The BERT model config. If set to None, `<--bert-model>.json` in folder `bert_configs` will be used.',
    )
    group.add_argument(
        '--max-steps',
        type=int,
        default=None,
        # Pretraining is step-driven, so the step budget is mandatory there.
        required=True if task == Task.pretrain else False,
        help='Total number of training steps to perform.',
    )
    group.add_argument(
        '--log-freq', type=int, default=10, help='Frequency of logging loss.'
    )
    group.add_argument(
        '--num-steps-per-checkpoint',
        type=int,
        default=100,
        help='Number of update steps until a model checkpoint is saved to disk.',
    )
    # Init model
    group.add_argument(
        '--from-pretrained-params',
        type=str,
        default=None,
        help='Path to pretrained parameters. If set to None, no pretrained params will be used.',
    )
    group.add_argument(
        '--from-checkpoint',
        type=str,
        default=None,
        help='A checkpoint path to resume training. If set to None, no checkpoint will be used. '
        'If not None, --from-pretrained-params will be ignored.',
    )
    group.add_argument(
        '--last-step-of-checkpoint',
        type=str,
        default=None,
        help='The step id of the checkpoint given by --from-checkpoint. '
        'It should be None, auto, or integer > 0. If it is set as '
        'None, then training will start from the 1-th epoch. If it is set as '
        'auto, then it will search largest integer-convertable folder '
        ' --from-checkpoint, which contains required checkpoint. ',
    )
    # Phase2 bootstrap and step-limit options only make sense for pretraining.
    if task == Task.pretrain:
        group.add_argument(
            '--from-phase1-final-params',
            type=str,
            default=None,
            help='Path to final checkpoint of phase1, which will be used to '
            'initialize the parameter in the first step of phase2, and '
            'ignored in the rest steps of phase2.',
        )
        group.add_argument(
            '--steps-this-run',
            type=int,
            default=None,
            help='If provided, only run this many steps before exiting.',
        )
    group.add_argument(
        '--seed', type=int, default=42, help="random seed for initialization"
    )
    group.add_argument(
        '--report-file',
        type=str,
        default='./report.json',
        help='A file in which to store JSON experiment report.',
    )
    group.add_argument(
        '--model-prefix',
        type=str,
        default='bert_paddle',
        help='The prefix name of model files to save/load.',
    )
    # NOTE(review): distutils.util.strtobool is gone in Python 3.12; these two
    # options need a stdlib-free bool parser when the project upgrades.
    group.add_argument(
        '--show-config',
        type=distutils.util.strtobool,
        default=True,
        help='To show arguments.',
    )
    group.add_argument(
        '--enable-cpu-affinity',
        type=distutils.util.strtobool,
        default=True,
        help='To enable in-built GPU-CPU affinity.',
    )
    group.add_argument(
        '--benchmark', action='store_true', help='To enable benchmark mode.'
    )
    group.add_argument(
        '--benchmark-steps',
        type=int,
        default=20,
        help='Steps for a benchmark run, only applied when --benchmark is set.',
    )
    group.add_argument(
        '--benchmark-warmup-steps',
        type=int,
        default=20,
        help='Warmup steps for a benchmark run, only applied when --benchmark is set.',
    )
    return parser
def add_training_args(parser, task=Task.pretrain):
    """Register optimizer/schedule/sequence options plus task-gated knobs.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.
        task: current ``Task``; gates the pretrain-only and SQuAD-only
            options registered below.

    Returns:
        The same ``parser``, for call chaining.
    """
    group = parser.add_argument_group('Training')
    group.add_argument(
        '--optimizer',
        default='Lamb',
        metavar="OPTIMIZER",
        choices=('Lamb', 'AdamW'),
        help='The name of optimizer. It should be one of {Lamb, AdamW}.',
    )
    group.add_argument(
        '--gradient-merge-steps',
        type=int,
        default=1,
        # NOTE(review): "accumualte" typo lives in user-visible help text;
        # left untouched here since it is a runtime string.
        help="Number of update steps to accumualte before performing a backward/update pass.",
    )
    group.add_argument(
        '--learning-rate',
        type=float,
        default=1e-4,
        help='The initial learning rate.',
    )
    group.add_argument(
        '--warmup-start-lr',
        type=float,
        default=0.0,
        help='The initial learning rate for warmup.',
    )
    group.add_argument(
        '--warmup-proportion',
        type=float,
        default=0.01,
        help='Proportion of training to perform linear learning rate warmup for. '
        'For example, 0.1 = 10%% of training.',
    )
    group.add_argument(
        '--beta1',
        type=float,
        default=0.9,
        help='The exponential decay rate for the 1st moment estimates.',
    )
    group.add_argument(
        '--beta2',
        type=float,
        default=0.999,
        help='The exponential decay rate for the 2st moment estimates.',
    )
    group.add_argument(
        '--epsilon',
        type=float,
        default=1e-6,
        help='A small float value for numerical stability.',
    )
    group.add_argument(
        '--weight-decay',
        type=float,
        default=0.01,
        help='The weight decay coefficient.',
    )
    group.add_argument(
        '--max-seq-length',
        default=512,
        type=int,
        help='The maximum total input sequence length after WordPiece tokenization. \n'
        'Sequences longer than this will be truncated, and sequences shorter \n'
        'than this will be padded.',
    )
    # Pretraining-only options (phases + masked-LM settings).
    if task == Task.pretrain:
        group.add_argument(
            '--batch-size',
            type=int,
            default=32,
            help='The batch size for training',
        )
        group.add_argument(
            '--phase1',
            action='store_true',
            help='The phase of BERT pretraining. It should not be set '
            'with --phase2 at the same time.',
        )
        group.add_argument(
            '--phase2',
            action='store_true',
            help='The phase of BERT pretraining. It should not be set '
            'with --phase1 at the same time.',
        )
        group.add_argument(
            '--max-predictions-per-seq',
            default=80,
            type=int,
            help='The maximum total of masked tokens in the input sequence',
        )
    # SQuAD-only options (train/predict modes and decoding limits).
    if task == Task.squad:
        group.add_argument(
            "--do-train", action='store_true', help="Whether to run training."
        )
        group.add_argument(
            "--do-predict",
            action='store_true',
            help="Whether to run eval on the dev set.",
        )
        group.add_argument(
            "--do-eval",
            action='store_true',
            help="Whether to use evaluate accuracy of predictions",
        )
        group.add_argument(
            "--train-batch-size",
            default=32,
            type=int,
            help="Total batch size for training.",
        )
        group.add_argument(
            "--predict-batch-size",
            default=8,
            type=int,
            help="Total batch size for predictions.",
        )
        group.add_argument(
            "--verbose-logging",
            action='store_true',
            help="If true, all of the warnings related to data processing will be printed. "
            "A number of warnings are expected for a normal SQuAD evaluation.",
        )
        group.add_argument(
            "--doc-stride",
            default=128,
            type=int,
            help="When splitting up a long document into chunks, how much stride to take "
            "between chunks.",
        )
        group.add_argument(
            "--max-query-length",
            default=64,
            type=int,
            help="The maximum number of tokens for the question. Questions longer than this "
            "will be truncated to this length.",
        )
        group.add_argument(
            "--n-best-size",
            default=20,
            type=int,
            help="The total number of n-best predictions to generate in the nbest_predictions.json "
            "output file.",
        )
        group.add_argument(
            "--max-answer-length",
            default=30,
            type=int,
            help="The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another.",
        )
        group.add_argument(
            "--do-lower-case",
            action='store_true',
            help="Whether to lower case the input text. True for uncased models, False for cased models.",
        )
        group.add_argument(
            '--version-2-with-negative',
            action='store_true',
            help='If true, the SQuAD examples contain some that do not have an answer.',
        )
        group.add_argument(
            '--null-score-diff-threshold',
            type=float,
            default=0.0,
            help="If null_score - best_non_null is greater than the threshold predict null.",
        )
    return parser
  507. def add_advance_args(parser):
  508. group = parser.add_argument_group('Advanced Training')
  509. group.add_argument(
  510. '--amp',
  511. action='store_true',
  512. help='Enable automatic mixed precision training (AMP).',
  513. )
  514. group.add_argument(
  515. '--scale-loss',
  516. type=float,
  517. default=1.0,
  518. help='The loss scalar for AMP training, only applied when --amp is set.',
  519. )
  520. group.add_argument(
  521. '--use-dynamic-loss-scaling',
  522. action='store_true',
  523. help='Enable dynamic loss scaling in AMP training, only applied when --amp is set.',
  524. )
  525. group.add_argument(
  526. '--use-pure-fp16',
  527. action='store_true',
  528. help='Enable pure FP16 training, only applied when --amp is set.',
  529. )
  530. group.add_argument(
  531. '--fuse-mha',
  532. action='store_true',
  533. help='Enable multihead attention fusion. Require cudnn version >= 8.9.1',
  534. )
  535. return parser
  536. def parse_args(task=Task.pretrain):
  537. parser = argparse.ArgumentParser(
  538. description="PaddlePaddle BERT pretraining script"
  539. if task == Task.pretrain
  540. else "PaddlePaddle SQuAD finetuning script",
  541. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  542. )
  543. parser = add_global_args(parser, task)
  544. parser = add_training_args(parser, task)
  545. parser = add_advance_args(parser)
  546. args = parser.parse_args()
  547. check_and_process_args(args, task)
  548. return args