# config.py
  1. # Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import copy
  16. import argparse
  17. import logging
  18. import distutils.util
  19. import dllogger
  20. from utils.mode import RunScope
  21. from utils.utility import get_num_trainers
  22. from utils.save_load import _PDOPT_SUFFIX, _PDPARAMS_SUFFIX
  23. _AUTO_LAST_EPOCH = 'auto'
  24. def _get_full_path_of_ckpt(args):
  25. if args.from_checkpoint is None:
  26. args.last_epoch_of_checkpoint = -1
  27. return
  28. def _check_file_exist(path_with_prefix):
  29. pdopt_path = path_with_prefix + _PDOPT_SUFFIX
  30. pdparams_path = path_with_prefix + _PDPARAMS_SUFFIX
  31. found = False
  32. if os.path.exists(pdopt_path) and os.path.exists(pdparams_path):
  33. found = True
  34. return found, pdopt_path, pdparams_path
  35. target_from_checkpoint = os.path.join(args.from_checkpoint,
  36. args.model_prefix)
  37. if args.last_epoch_of_checkpoint is None:
  38. args.last_epoch_of_checkpoint = -1
  39. elif args.last_epoch_of_checkpoint == _AUTO_LAST_EPOCH:
  40. folders = os.listdir(args.from_checkpoint)
  41. args.last_epoch_of_checkpoint = -1
  42. for folder in folders:
  43. tmp_ckpt_path = os.path.join(args.from_checkpoint, folder,
  44. args.model_prefix)
  45. try:
  46. folder = int(folder)
  47. except ValueError:
  48. logging.warning(
  49. f"Skip folder '{folder}' since its name is not integer-convertable."
  50. )
  51. continue
  52. if folder > args.last_epoch_of_checkpoint and \
  53. _check_file_exist(tmp_ckpt_path)[0]:
  54. args.last_epoch_of_checkpoint = folder
  55. epoch_with_prefix = os.path.join(str(args.last_epoch_of_checkpoint), args.model_prefix) \
  56. if args.last_epoch_of_checkpoint > -1 else args.model_prefix
  57. target_from_checkpoint = os.path.join(args.from_checkpoint,
  58. epoch_with_prefix)
  59. else:
  60. try:
  61. args.last_epoch_of_checkpoint = int(args.last_epoch_of_checkpoint)
  62. except ValueError:
  63. raise ValueError(f"The value of --last-epoch-of-checkpoint should be None, {_AUTO_LAST_EPOCH}" \
  64. f" or integer >= 0, but receive {args.last_epoch_of_checkpoint}")
  65. args.from_checkpoint = target_from_checkpoint
  66. found, pdopt_path, pdparams_path = _check_file_exist(args.from_checkpoint)
  67. if not found:
  68. args.from_checkpoint = None
  69. args.last_epoch_of_checkpoint = -1
  70. logging.warning(
  71. f"Cannot find {pdopt_path} and {pdparams_path}, disable --from-checkpoint."
  72. )
  73. def _get_full_path_of_pretrained_params(args):
  74. if args.from_pretrained_params is None:
  75. args.last_epoch_of_checkpoint = -1
  76. return
  77. args.from_pretrained_params = os.path.join(args.from_pretrained_params,
  78. args.model_prefix)
  79. pdparams_path = args.from_pretrained_params + _PDPARAMS_SUFFIX
  80. if not os.path.exists(pdparams_path):
  81. args.from_pretrained_params = None
  82. logging.warning(
  83. f"Cannot find {pdparams_path}, disable --from-pretrained-params.")
  84. args.last_epoch_of_checkpoint = -1
  85. def print_args(args):
  86. args_for_log = copy.deepcopy(args)
  87. # Due to dllogger cannot serialize Enum into JSON.
  88. if hasattr(args_for_log, 'run_scope'):
  89. args_for_log.run_scope = args_for_log.run_scope.value
  90. dllogger.log(step='PARAMETER', data=vars(args_for_log))
  91. def check_and_process_args(args):
  92. # Precess the scope of run
  93. run_scope = None
  94. for scope in RunScope:
  95. if args.run_scope == scope.value:
  96. run_scope = scope
  97. break
  98. assert run_scope is not None, \
  99. f"only support {[scope.value for scope in RunScope]} as run_scope"
  100. args.run_scope = run_scope
  101. # Precess image layout and channel
  102. args.image_channel = args.image_shape[0]
  103. if args.data_layout == "NHWC":
  104. args.image_shape = [
  105. args.image_shape[1], args.image_shape[2], args.image_shape[0]
  106. ]
  107. # Precess learning rate
  108. args.lr = get_num_trainers() * args.lr
  109. # Precess model loading
  110. assert not (args.from_checkpoint is not None and \
  111. args.from_pretrained_params is not None), \
  112. "--from-pretrained-params and --from-checkpoint should " \
  113. "not be set simultaneously."
  114. _get_full_path_of_pretrained_params(args)
  115. _get_full_path_of_ckpt(args)
  116. args.start_epoch = args.last_epoch_of_checkpoint + 1
  117. # Precess benchmark
  118. if args.benchmark:
  119. assert args.run_scope in [
  120. RunScope.TRAIN_ONLY, RunScope.EVAL_ONLY
  121. ], "If benchmark enabled, run_scope must be `train_only` or `eval_only`"
  122. # Only run one epoch when benchmark or eval_only.
  123. if args.benchmark or \
  124. (args.run_scope == RunScope.EVAL_ONLY):
  125. args.epochs = args.start_epoch + 1
  126. if args.run_scope == RunScope.EVAL_ONLY:
  127. args.eval_interval = 1
  128. def add_general_args(parser):
  129. group = parser.add_argument_group('General')
  130. group.add_argument(
  131. '--checkpoint-dir',
  132. type=str,
  133. default='./checkpoint/',
  134. help='A path to store trained models.')
  135. group.add_argument(
  136. '--inference-dir',
  137. type=str,
  138. default='./inference/',
  139. help='A path to store inference model once the training is finished.'
  140. )
  141. group.add_argument(
  142. '--run-scope',
  143. default='train_eval',
  144. choices=('train_eval', 'train_only', 'eval_only'),
  145. help='Running scope. It should be one of {train_eval, train_only, eval_only}.'
  146. )
  147. group.add_argument(
  148. '--epochs',
  149. type=int,
  150. default=90,
  151. help='The number of epochs for training.')
  152. group.add_argument(
  153. '--save-interval',
  154. type=int,
  155. default=1,
  156. help='The iteration interval to save checkpoints.')
  157. group.add_argument(
  158. '--eval-interval',
  159. type=int,
  160. default=1,
  161. help='The iteration interval to test trained models on a given validation dataset. ' \
  162. 'Ignored when --run-scope is train_only.'
  163. )
  164. group.add_argument(
  165. '--print-interval',
  166. type=int,
  167. default=10,
  168. help='The iteration interval to show training/evaluation message.')
  169. group.add_argument(
  170. '--report-file',
  171. type=str,
  172. default='./train.json',
  173. help='A file in which to store JSON experiment report.')
  174. group.add_argument(
  175. '--benchmark', action='store_true', help='To enable benchmark mode.')
  176. group.add_argument(
  177. '--benchmark-steps',
  178. type=int,
  179. default=100,
  180. help='Steps for benchmark run, only be applied when --benchmark is set.'
  181. )
  182. group.add_argument(
  183. '--benchmark-warmup-steps',
  184. type=int,
  185. default=100,
  186. help='Warmup steps for benchmark run, only be applied when --benchmark is set.'
  187. )
  188. group.add_argument(
  189. '--model-prefix',
  190. type=str,
  191. default="resnet_50_paddle",
  192. help='The prefix name of model files to save/load.')
  193. group.add_argument(
  194. '--from-pretrained-params',
  195. type=str,
  196. default=None,
  197. help='A folder path which contains pretrained parameters, that is a file in name' \
  198. ' --model-prefix + .pdparams. It should not be set with --from-checkpoint' \
  199. ' at the same time.'
  200. )
  201. group.add_argument(
  202. '--from-checkpoint',
  203. type=str,
  204. default=None,
  205. help='A checkpoint path to resume training. It should not be set ' \
  206. 'with --from-pretrained-params at the same time. The path provided ' \
  207. 'could be a folder contains < epoch_id/ckpt_files > or < ckpt_files >.'
  208. )
  209. group.add_argument(
  210. '--last-epoch-of-checkpoint',
  211. type=str,
  212. default=None,
  213. help='The epoch id of the checkpoint given by --from-checkpoint. ' \
  214. 'It should be None, auto or integer >= 0. If it is set as ' \
  215. 'None, then training will start from 0-th epoch. If it is set as ' \
  216. 'auto, then it will search largest integer-convertable folder ' \
  217. ' --from-checkpoint, which contains required checkpoint. ' \
  218. 'Default is None.'
  219. )
  220. group.add_argument(
  221. '--show-config',
  222. type=distutils.util.strtobool,
  223. default=True,
  224. help='To show arguments.')
  225. group.add_argument(
  226. '--enable-cpu-affinity',
  227. type=distutils.util.strtobool,
  228. default=True,
  229. help='To enable in-built GPU-CPU affinity.')
  230. return parser
  231. def add_advance_args(parser):
  232. group = parser.add_argument_group('Advanced Training')
  233. # AMP
  234. group.add_argument(
  235. '--amp',
  236. action='store_true',
  237. help='Enable automatic mixed precision training (AMP).')
  238. group.add_argument(
  239. '--scale-loss',
  240. type=float,
  241. default=1.0,
  242. help='The loss scalar for AMP training, only be applied when --amp is set.'
  243. )
  244. group.add_argument(
  245. '--use-dynamic-loss-scaling',
  246. action='store_true',
  247. help='Enable dynamic loss scaling in AMP training, only be applied when --amp is set.'
  248. )
  249. group.add_argument(
  250. '--use-pure-fp16',
  251. action='store_true',
  252. help='Enable pure FP16 training, only be applied when --amp is set.')
  253. group.add_argument(
  254. '--fuse-resunit',
  255. action='store_true',
  256. help='Enable CUDNNv8 ResUnit fusion, only be applied when --amp is set.')
  257. # ASP
  258. group.add_argument(
  259. '--asp',
  260. action='store_true',
  261. help='Enable automatic sparse training (ASP).')
  262. group.add_argument(
  263. '--prune-model',
  264. action='store_true',
  265. help='Prune model to 2:4 sparse pattern, only be applied when --asp is set.'
  266. )
  267. group.add_argument(
  268. '--mask-algo',
  269. default='mask_1d',
  270. choices=('mask_1d', 'mask_2d_greedy', 'mask_2d_best'),
  271. help='The algorithm to generate sparse masks. It should be one of ' \
  272. '{mask_1d, mask_2d_greedy, mask_2d_best}. This only be applied ' \
  273. 'when --asp and --prune-model is set.'
  274. )
  275. # QAT
  276. group.add_argument(
  277. '--qat',
  278. action='store_true',
  279. help='Enable quantization aware training (QAT).')
  280. return parser
  281. def add_dataset_args(parser):
  282. def float_list(x):
  283. return list(map(float, x.split(',')))
  284. def int_list(x):
  285. return list(map(int, x.split(',')))
  286. dataset_group = parser.add_argument_group('Dataset')
  287. dataset_group.add_argument(
  288. '--image-root',
  289. type=str,
  290. default='/imagenet',
  291. help='A root folder of train/val images. It should contain train and val folders, ' \
  292. 'which store corresponding images.'
  293. )
  294. dataset_group.add_argument(
  295. '--image-shape',
  296. type=int_list,
  297. default=[4, 224, 224],
  298. help='The image shape. Its shape should be [channel, height, width].')
  299. # Data Loader
  300. dataset_group.add_argument(
  301. '--batch-size',
  302. type=int,
  303. default=256,
  304. help='The batch size for both training and evaluation.')
  305. dataset_group.add_argument(
  306. '--dali-random-seed',
  307. type=int,
  308. default=42,
  309. help='The random seed for DALI data loader.')
  310. dataset_group.add_argument(
  311. '--dali-num-threads',
  312. type=int,
  313. default=4,
  314. help='The number of threads applied to DALI data loader.')
  315. dataset_group.add_argument(
  316. '--dali-output-fp16',
  317. action='store_true',
  318. help='Output FP16 data from DALI data loader.')
  319. # Augmentation
  320. augmentation_group = parser.add_argument_group('Data Augmentation')
  321. augmentation_group.add_argument(
  322. '--crop-size',
  323. type=int,
  324. default=224,
  325. help='The size to crop input images.')
  326. augmentation_group.add_argument(
  327. '--rand-crop-scale',
  328. type=float_list,
  329. default=[0.08, 1.],
  330. help='Range from which to choose a random area fraction.')
  331. augmentation_group.add_argument(
  332. '--rand-crop-ratio',
  333. type=float_list,
  334. default=[3.0 / 4, 4.0 / 3],
  335. help='Range from which to choose a random aspect ratio (width/height).')
  336. augmentation_group.add_argument(
  337. '--normalize-scale',
  338. type=float,
  339. default=1.0 / 255.0,
  340. help='A scalar to normalize images.')
  341. augmentation_group.add_argument(
  342. '--normalize-mean',
  343. type=float_list,
  344. default=[0.485, 0.456, 0.406],
  345. help='The mean values to normalize RGB images.')
  346. augmentation_group.add_argument(
  347. '--normalize-std',
  348. type=float_list,
  349. default=[0.229, 0.224, 0.225],
  350. help='The std values to normalize RGB images.')
  351. augmentation_group.add_argument(
  352. '--resize-short',
  353. type=int,
  354. default=256,
  355. help='The length of the shorter dimension of the resized image.')
  356. return parser
  357. def add_model_args(parser):
  358. group = parser.add_argument_group('Model')
  359. group.add_argument(
  360. '--model-arch-name',
  361. type=str,
  362. default='ResNet50',
  363. help='The model architecture name. It should be one of {ResNet50}.')
  364. group.add_argument(
  365. '--num-of-class',
  366. type=int,
  367. default=1000,
  368. help='The number classes of images.')
  369. group.add_argument(
  370. '--data-layout',
  371. default='NCHW',
  372. choices=('NCHW', 'NHWC'),
  373. help='Data format. It should be one of {NCHW, NHWC}.')
  374. group.add_argument(
  375. '--bn-weight-decay',
  376. action='store_true',
  377. help='Apply weight decay to BatchNorm shift and scale.')
  378. return parser
  379. def add_training_args(parser):
  380. group = parser.add_argument_group('Training')
  381. group.add_argument(
  382. '--label-smoothing',
  383. type=float,
  384. default=0.1,
  385. help='The ratio of label smoothing.')
  386. group.add_argument(
  387. '--optimizer',
  388. default='Momentum',
  389. metavar="OPTIMIZER",
  390. choices=('Momentum'),
  391. help='The name of optimizer. It should be one of {Momentum}.')
  392. group.add_argument(
  393. '--momentum',
  394. type=float,
  395. default=0.875,
  396. help='The momentum value of optimizer.')
  397. group.add_argument(
  398. '--weight-decay',
  399. type=float,
  400. default=3.0517578125e-05,
  401. help='The coefficient of weight decay.')
  402. group.add_argument(
  403. '--lr-scheduler',
  404. default='Cosine',
  405. metavar="LR_SCHEDULER",
  406. choices=('Cosine'),
  407. help='The name of learning rate scheduler. It should be one of {Cosine}.'
  408. )
  409. group.add_argument(
  410. '--lr', type=float, default=0.256, help='The initial learning rate.')
  411. group.add_argument(
  412. '--warmup-epochs',
  413. type=int,
  414. default=5,
  415. help='The number of epochs for learning rate warmup.')
  416. group.add_argument(
  417. '--warmup-start-lr',
  418. type=float,
  419. default=0.0,
  420. help='The initial learning rate for warmup.')
  421. return parser
  422. def add_trt_args(parser):
  423. def int_list(x):
  424. return list(map(int, x.split(',')))
  425. group = parser.add_argument_group('Paddle-TRT')
  426. group.add_argument(
  427. '--device',
  428. type=int,
  429. default='0',
  430. help='The GPU device id for Paddle-TRT inference.'
  431. )
  432. group.add_argument(
  433. '--inference-dir',
  434. type=str,
  435. default='./inference',
  436. help='A path to load inference models.'
  437. )
  438. group.add_argument(
  439. '--data-layout',
  440. default='NCHW',
  441. choices=('NCHW', 'NHWC'),
  442. help='Data format. It should be one of {NCHW, NHWC}.')
  443. group.add_argument(
  444. '--precision',
  445. default='FP32',
  446. choices=('FP32', 'FP16', 'INT8'),
  447. help='The precision of TensorRT. It should be one of {FP32, FP16, INT8}.'
  448. )
  449. group.add_argument(
  450. '--workspace-size',
  451. type=int,
  452. default=(1 << 30),
  453. help='The memory workspace of TensorRT in MB.')
  454. group.add_argument(
  455. '--min-subgraph-size',
  456. type=int,
  457. default=3,
  458. help='The minimal subgraph size to enable PaddleTRT.')
  459. group.add_argument(
  460. '--use-static',
  461. type=distutils.util.strtobool,
  462. default=False,
  463. help='Fix TensorRT engine at first running.')
  464. group.add_argument(
  465. '--use-calib-mode',
  466. type=distutils.util.strtobool,
  467. default=False,
  468. help='Use the PTQ calibration of PaddleTRT int8.')
  469. group.add_argument(
  470. '--report-file',
  471. type=str,
  472. default='./inference.json',
  473. help='A file in which to store JSON inference report.')
  474. group.add_argument(
  475. '--use-synthetic',
  476. type=distutils.util.strtobool,
  477. default=False,
  478. help='Apply synthetic data for benchmark.')
  479. group.add_argument(
  480. '--benchmark-steps',
  481. type=int,
  482. default=100,
  483. help='Steps for benchmark run, only be applied when --benchmark is set.'
  484. )
  485. group.add_argument(
  486. '--benchmark-warmup-steps',
  487. type=int,
  488. default=100,
  489. help='Warmup steps for benchmark run, only be applied when --benchmark is set.'
  490. )
  491. group.add_argument(
  492. '--show-config',
  493. type=distutils.util.strtobool,
  494. default=True,
  495. help='To show arguments.')
  496. return parser
  497. def parse_args(script='train'):
  498. assert script in ['train', 'inference']
  499. parser = argparse.ArgumentParser(
  500. description=f'PaddlePaddle RN50v1.5 {script} script',
  501. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  502. if script == 'train':
  503. parser = add_general_args(parser)
  504. parser = add_dataset_args(parser)
  505. parser = add_model_args(parser)
  506. parser = add_training_args(parser)
  507. parser = add_advance_args(parser)
  508. args = parser.parse_args()
  509. check_and_process_args(args)
  510. else:
  511. parser = add_trt_args(parser)
  512. parser = add_dataset_args(parser)
  513. args = parser.parse_args()
  514. # Precess image layout and channel
  515. args.image_channel = args.image_shape[0]
  516. if args.data_layout == "NHWC":
  517. args.image_shape = [
  518. args.image_shape[1], args.image_shape[2], args.image_shape[0]
  519. ]
  520. return args