# main.py
# Copyright (c) 2018-2019, NVIDIA CORPORATION
# Copyright (c) 2017- Facebook, Inc
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. import os
  31. os.environ[
  32. "KMP_AFFINITY"
  33. ] = "disabled" # We need to do this before importing anything else as a workaround for this bug: https://github.com/pytorch/pytorch/issues/28389
  34. import argparse
  35. import random
  36. from copy import deepcopy
  37. import torch.backends.cudnn as cudnn
  38. import torch.distributed as dist
  39. import torch.nn.parallel
  40. import torch.optim
  41. import torch.utils.data
  42. import torch.utils.data.distributed
  43. import image_classification.logger as log
  44. from image_classification.smoothing import LabelSmoothing
  45. from image_classification.mixup import NLLMultiLabelSmooth, MixUpWrapper
  46. from image_classification.dataloaders import *
  47. from image_classification.training import *
  48. from image_classification.utils import *
  49. from image_classification.models import (
  50. resnet50,
  51. resnext101_32x4d,
  52. se_resnext101_32x4d,
  53. efficientnet_b0,
  54. efficientnet_b4,
  55. efficientnet_widese_b0,
  56. efficientnet_widese_b4,
  57. )
  58. from image_classification.optimizers import (
  59. get_optimizer,
  60. lr_cosine_policy,
  61. lr_linear_policy,
  62. lr_step_policy,
  63. )
  64. from image_classification.gpu_affinity import set_affinity, AffinityMode
  65. import dllogger
  66. def available_models():
  67. models = {
  68. m.name: m
  69. for m in [
  70. resnet50,
  71. resnext101_32x4d,
  72. se_resnext101_32x4d,
  73. efficientnet_b0,
  74. efficientnet_b4,
  75. efficientnet_widese_b0,
  76. efficientnet_widese_b4,
  77. ]
  78. }
  79. return models
  80. def add_parser_arguments(parser, skip_arch=False):
  81. parser.add_argument("data", metavar="DIR", help="path to dataset")
  82. parser.add_argument(
  83. "--data-backend",
  84. metavar="BACKEND",
  85. default="dali-cpu",
  86. choices=DATA_BACKEND_CHOICES,
  87. help="data backend: "
  88. + " | ".join(DATA_BACKEND_CHOICES)
  89. + " (default: dali-cpu)",
  90. )
  91. parser.add_argument(
  92. "--interpolation",
  93. metavar="INTERPOLATION",
  94. default="bilinear",
  95. help="interpolation type for resizing images: bilinear, bicubic or triangular(DALI only)",
  96. )
  97. if not skip_arch:
  98. model_names = available_models().keys()
  99. parser.add_argument(
  100. "--arch",
  101. "-a",
  102. metavar="ARCH",
  103. default="resnet50",
  104. choices=model_names,
  105. help="model architecture: "
  106. + " | ".join(model_names)
  107. + " (default: resnet50)",
  108. )
  109. parser.add_argument(
  110. "-j",
  111. "--workers",
  112. default=5,
  113. type=int,
  114. metavar="N",
  115. help="number of data loading workers (default: 5)",
  116. )
  117. parser.add_argument(
  118. "--prefetch",
  119. default=2,
  120. type=int,
  121. metavar="N",
  122. help="number of samples prefetched by each loader",
  123. )
  124. parser.add_argument(
  125. "--epochs",
  126. default=90,
  127. type=int,
  128. metavar="N",
  129. help="number of total epochs to run",
  130. )
  131. parser.add_argument(
  132. "--run-epochs",
  133. default=-1,
  134. type=int,
  135. metavar="N",
  136. help="run only N epochs, used for checkpointing runs",
  137. )
  138. parser.add_argument(
  139. "--early-stopping-patience",
  140. default=-1,
  141. type=int,
  142. metavar="N",
  143. help="early stopping after N epochs without validation accuracy improving",
  144. )
  145. parser.add_argument(
  146. "--image-size", default=None, type=int, help="resolution of image"
  147. )
  148. parser.add_argument(
  149. "-b",
  150. "--batch-size",
  151. default=256,
  152. type=int,
  153. metavar="N",
  154. help="mini-batch size (default: 256) per gpu",
  155. )
  156. parser.add_argument(
  157. "--optimizer-batch-size",
  158. default=-1,
  159. type=int,
  160. metavar="N",
  161. help="size of a total batch size, for simulating bigger batches using gradient accumulation",
  162. )
  163. parser.add_argument(
  164. "--lr",
  165. "--learning-rate",
  166. default=0.1,
  167. type=float,
  168. metavar="LR",
  169. help="initial learning rate",
  170. )
  171. parser.add_argument(
  172. "--lr-schedule",
  173. default="step",
  174. type=str,
  175. metavar="SCHEDULE",
  176. choices=["step", "linear", "cosine"],
  177. help="Type of LR schedule: {}, {}, {}".format("step", "linear", "cosine"),
  178. )
  179. parser.add_argument("--end-lr", default=0, type=float)
  180. parser.add_argument(
  181. "--warmup", default=0, type=int, metavar="E", help="number of warmup epochs"
  182. )
  183. parser.add_argument(
  184. "--label-smoothing",
  185. default=0.0,
  186. type=float,
  187. metavar="S",
  188. help="label smoothing",
  189. )
  190. parser.add_argument(
  191. "--mixup", default=0.0, type=float, metavar="ALPHA", help="mixup alpha"
  192. )
  193. parser.add_argument(
  194. "--optimizer", default="sgd", type=str, choices=("sgd", "rmsprop")
  195. )
  196. parser.add_argument(
  197. "--momentum", default=0.9, type=float, metavar="M", help="momentum"
  198. )
  199. parser.add_argument(
  200. "--weight-decay",
  201. "--wd",
  202. default=1e-4,
  203. type=float,
  204. metavar="W",
  205. help="weight decay (default: 1e-4)",
  206. )
  207. parser.add_argument(
  208. "--bn-weight-decay",
  209. action="store_true",
  210. help="use weight_decay on batch normalization learnable parameters, (default: false)",
  211. )
  212. parser.add_argument(
  213. "--rmsprop-alpha",
  214. default=0.9,
  215. type=float,
  216. help="value of alpha parameter in rmsprop optimizer (default: 0.9)",
  217. )
  218. parser.add_argument(
  219. "--rmsprop-eps",
  220. default=1e-3,
  221. type=float,
  222. help="value of eps parameter in rmsprop optimizer (default: 1e-3)",
  223. )
  224. parser.add_argument(
  225. "--nesterov",
  226. action="store_true",
  227. help="use nesterov momentum, (default: false)",
  228. )
  229. parser.add_argument(
  230. "--print-freq",
  231. "-p",
  232. default=10,
  233. type=int,
  234. metavar="N",
  235. help="print frequency (default: 10)",
  236. )
  237. parser.add_argument(
  238. "--resume",
  239. default=None,
  240. type=str,
  241. metavar="PATH",
  242. help="path to latest checkpoint (default: none)",
  243. )
  244. parser.add_argument(
  245. "--static-loss-scale",
  246. type=float,
  247. default=1,
  248. help="Static loss scale, positive power of 2 values can improve amp convergence.",
  249. )
  250. parser.add_argument(
  251. "--prof", type=int, default=-1, metavar="N", help="Run only N iterations"
  252. )
  253. parser.add_argument(
  254. "--amp",
  255. action="store_true",
  256. help="Run model AMP (automatic mixed precision) mode.",
  257. )
  258. parser.add_argument(
  259. "--seed", default=None, type=int, help="random seed used for numpy and pytorch"
  260. )
  261. parser.add_argument(
  262. "--gather-checkpoints",
  263. default="0",
  264. type=int,
  265. help=(
  266. "Gather N last checkpoints throughout the training,"
  267. " without this flag only best and last checkpoints will be stored. "
  268. "Use -1 for all checkpoints"
  269. ),
  270. )
  271. parser.add_argument(
  272. "--raport-file",
  273. default="experiment_raport.json",
  274. type=str,
  275. help="file in which to store JSON experiment raport",
  276. )
  277. parser.add_argument(
  278. "--evaluate", action="store_true", help="evaluate checkpoint/model"
  279. )
  280. parser.add_argument("--training-only", action="store_true", help="do not evaluate")
  281. parser.add_argument(
  282. "--no-checkpoints",
  283. action="store_false",
  284. dest="save_checkpoints",
  285. help="do not store any checkpoints, useful for benchmarking",
  286. )
  287. parser.add_argument(
  288. "--jit",
  289. type=str,
  290. default="no",
  291. choices=["no", "script"],
  292. help="no -> do not use torch.jit; script -> use torch.jit.script",
  293. )
  294. parser.add_argument("--checkpoint-filename", default="checkpoint.pth.tar", type=str)
  295. parser.add_argument(
  296. "--workspace",
  297. type=str,
  298. default="./",
  299. metavar="DIR",
  300. help="path to directory where checkpoints will be stored",
  301. )
  302. parser.add_argument(
  303. "--memory-format",
  304. type=str,
  305. default="nchw",
  306. choices=["nchw", "nhwc"],
  307. help="memory layout, nchw or nhwc",
  308. )
  309. parser.add_argument("--use-ema", default=None, type=float, help="use EMA")
  310. parser.add_argument(
  311. "--augmentation",
  312. type=str,
  313. default=None,
  314. choices=[None, "autoaugment"],
  315. help="augmentation method",
  316. )
  317. parser.add_argument(
  318. "--gpu-affinity",
  319. type=str,
  320. default="none",
  321. required=False,
  322. choices=[am.name for am in AffinityMode],
  323. )
  324. parser.add_argument(
  325. "--topk",
  326. type=int,
  327. default=5,
  328. required=False,
  329. )
  330. def prepare_for_training(args, model_args, model_arch):
  331. args.distributed = False
  332. if "WORLD_SIZE" in os.environ:
  333. args.distributed = int(os.environ["WORLD_SIZE"]) > 1
  334. args.local_rank = int(os.environ["LOCAL_RANK"])
  335. else:
  336. args.local_rank = 0
  337. args.gpu = 0
  338. args.world_size = 1
  339. if args.distributed:
  340. args.gpu = args.local_rank % torch.cuda.device_count()
  341. torch.cuda.set_device(args.gpu)
  342. dist.init_process_group(backend="nccl", init_method="env://")
  343. args.world_size = torch.distributed.get_world_size()
  344. affinity = set_affinity(args.gpu, mode=args.gpu_affinity)
  345. print(f"Training process {args.local_rank} affinity: {affinity}")
  346. if args.seed is not None:
  347. print("Using seed = {}".format(args.seed))
  348. torch.manual_seed(args.seed + args.local_rank)
  349. torch.cuda.manual_seed(args.seed + args.local_rank)
  350. np.random.seed(seed=args.seed + args.local_rank)
  351. random.seed(args.seed + args.local_rank)
  352. def _worker_init_fn(id):
  353. # Worker process should inherit its affinity from parent
  354. affinity = os.sched_getaffinity(0)
  355. print(f"Process {args.local_rank} Worker {id} set affinity to: {affinity}")
  356. np.random.seed(seed=args.seed + args.local_rank + id)
  357. random.seed(args.seed + args.local_rank + id)
  358. else:
  359. def _worker_init_fn(id):
  360. # Worker process should inherit its affinity from parent
  361. affinity = os.sched_getaffinity(0)
  362. print(f"Process {args.local_rank} Worker {id} set affinity to: {affinity}")
  363. if args.static_loss_scale != 1.0:
  364. if not args.amp:
  365. print("Warning: if --amp is not used, static_loss_scale will be ignored.")
  366. if args.optimizer_batch_size < 0:
  367. batch_size_multiplier = 1
  368. else:
  369. tbs = args.world_size * args.batch_size
  370. if args.optimizer_batch_size % tbs != 0:
  371. print(
  372. "Warning: simulated batch size {} is not divisible by actual batch size {}".format(
  373. args.optimizer_batch_size, tbs
  374. )
  375. )
  376. batch_size_multiplier = int(args.optimizer_batch_size / tbs)
  377. print("BSM: {}".format(batch_size_multiplier))
  378. start_epoch = 0
  379. # optionally resume from a checkpoint
  380. if args.resume is not None:
  381. if os.path.isfile(args.resume):
  382. print("=> loading checkpoint '{}'".format(args.resume))
  383. checkpoint = torch.load(
  384. args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu)
  385. )
  386. start_epoch = checkpoint["epoch"]
  387. best_prec1 = checkpoint["best_prec1"]
  388. model_state = checkpoint["state_dict"]
  389. optimizer_state = checkpoint["optimizer"]
  390. if "state_dict_ema" in checkpoint:
  391. model_state_ema = checkpoint["state_dict_ema"]
  392. print(
  393. "=> loaded checkpoint '{}' (epoch {})".format(
  394. args.resume, checkpoint["epoch"]
  395. )
  396. )
  397. if start_epoch >= args.epochs:
  398. print(
  399. f"Launched training for {args.epochs}, checkpoint already run {start_epoch}"
  400. )
  401. exit(1)
  402. else:
  403. print("=> no checkpoint found at '{}'".format(args.resume))
  404. model_state = None
  405. model_state_ema = None
  406. optimizer_state = None
  407. else:
  408. model_state = None
  409. model_state_ema = None
  410. optimizer_state = None
  411. loss = nn.CrossEntropyLoss
  412. if args.mixup > 0.0:
  413. loss = lambda: NLLMultiLabelSmooth(args.label_smoothing)
  414. elif args.label_smoothing > 0.0:
  415. loss = lambda: LabelSmoothing(args.label_smoothing)
  416. memory_format = (
  417. torch.channels_last if args.memory_format == "nhwc" else torch.contiguous_format
  418. )
  419. model = model_arch(
  420. **{
  421. k: v
  422. if k != "pretrained"
  423. else v and (not args.distributed or dist.get_rank() == 0)
  424. for k, v in model_args.__dict__.items()
  425. }
  426. )
  427. image_size = (
  428. args.image_size
  429. if args.image_size is not None
  430. else model.arch.default_image_size
  431. )
  432. scaler = torch.cuda.amp.GradScaler(
  433. init_scale=args.static_loss_scale,
  434. growth_factor=2,
  435. backoff_factor=0.5,
  436. growth_interval=100,
  437. enabled=args.amp,
  438. )
  439. executor = Executor(
  440. model,
  441. loss(),
  442. cuda=True,
  443. memory_format=memory_format,
  444. amp=args.amp,
  445. scaler=scaler,
  446. divide_loss=batch_size_multiplier,
  447. ts_script=args.jit == "script",
  448. )
  449. # Create data loaders and optimizers as needed
  450. if args.data_backend == "pytorch":
  451. get_train_loader = get_pytorch_train_loader
  452. get_val_loader = get_pytorch_val_loader
  453. elif args.data_backend == "dali-gpu":
  454. get_train_loader = get_dali_train_loader(dali_cpu=False)
  455. get_val_loader = get_dali_val_loader()
  456. elif args.data_backend == "dali-cpu":
  457. get_train_loader = get_dali_train_loader(dali_cpu=True)
  458. get_val_loader = get_dali_val_loader()
  459. elif args.data_backend == "synthetic":
  460. get_val_loader = get_synthetic_loader
  461. get_train_loader = get_synthetic_loader
  462. else:
  463. print("Bad databackend picked")
  464. exit(1)
  465. train_loader, train_loader_len = get_train_loader(
  466. args.data,
  467. image_size,
  468. args.batch_size,
  469. model_args.num_classes,
  470. args.mixup > 0.0,
  471. interpolation=args.interpolation,
  472. augmentation=args.augmentation,
  473. start_epoch=start_epoch,
  474. workers=args.workers,
  475. _worker_init_fn=_worker_init_fn,
  476. memory_format=memory_format,
  477. prefetch_factor=args.prefetch,
  478. )
  479. if args.mixup != 0.0:
  480. train_loader = MixUpWrapper(args.mixup, train_loader)
  481. val_loader, val_loader_len = get_val_loader(
  482. args.data,
  483. image_size,
  484. args.batch_size,
  485. model_args.num_classes,
  486. False,
  487. interpolation=args.interpolation,
  488. workers=args.workers,
  489. _worker_init_fn=_worker_init_fn,
  490. memory_format=memory_format,
  491. prefetch_factor=args.prefetch,
  492. )
  493. if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
  494. logger = log.Logger(
  495. args.print_freq,
  496. [
  497. dllogger.StdOutBackend(
  498. dllogger.Verbosity.DEFAULT, step_format=log.format_step
  499. ),
  500. dllogger.JSONStreamBackend(
  501. dllogger.Verbosity.VERBOSE,
  502. os.path.join(args.workspace, args.raport_file),
  503. ),
  504. ],
  505. start_epoch=start_epoch - 1,
  506. )
  507. else:
  508. logger = log.Logger(args.print_freq, [], start_epoch=start_epoch - 1)
  509. logger.log_parameter(args.__dict__, verbosity=dllogger.Verbosity.DEFAULT)
  510. logger.log_parameter(
  511. {f"model.{k}": v for k, v in model_args.__dict__.items()},
  512. verbosity=dllogger.Verbosity.DEFAULT,
  513. )
  514. optimizer = get_optimizer(
  515. list(executor.model.named_parameters()),
  516. args.lr,
  517. args=args,
  518. state=optimizer_state,
  519. )
  520. if args.lr_schedule == "step":
  521. lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup)
  522. elif args.lr_schedule == "cosine":
  523. lr_policy = lr_cosine_policy(
  524. args.lr, args.warmup, args.epochs, end_lr=args.end_lr
  525. )
  526. elif args.lr_schedule == "linear":
  527. lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs)
  528. if args.distributed:
  529. executor.distributed(args.gpu)
  530. if model_state is not None:
  531. executor.model.load_state_dict(model_state)
  532. trainer = Trainer(
  533. executor,
  534. optimizer,
  535. grad_acc_steps=batch_size_multiplier,
  536. ema=args.use_ema,
  537. )
  538. if (args.use_ema is not None) and (model_state_ema is not None):
  539. trainer.ema_executor.model.load_state_dict(model_state_ema)
  540. return (
  541. trainer,
  542. lr_policy,
  543. train_loader,
  544. train_loader_len,
  545. val_loader,
  546. logger,
  547. start_epoch,
  548. )
  549. def main(args, model_args, model_arch):
  550. exp_start_time = time.time()
  551. global best_prec1
  552. best_prec1 = 0
  553. (
  554. trainer,
  555. lr_policy,
  556. train_loader,
  557. train_loader_len,
  558. val_loader,
  559. logger,
  560. start_epoch,
  561. ) = prepare_for_training(args, model_args, model_arch)
  562. train_loop(
  563. trainer,
  564. lr_policy,
  565. train_loader,
  566. train_loader_len,
  567. val_loader,
  568. logger,
  569. start_epoch=start_epoch,
  570. end_epoch=min((start_epoch + args.run_epochs), args.epochs)
  571. if args.run_epochs != -1
  572. else args.epochs,
  573. early_stopping_patience=args.early_stopping_patience,
  574. best_prec1=best_prec1,
  575. prof=args.prof,
  576. skip_training=args.evaluate,
  577. skip_validation=args.training_only,
  578. save_checkpoints=args.save_checkpoints and not args.evaluate,
  579. checkpoint_dir=args.workspace,
  580. checkpoint_filename=args.checkpoint_filename,
  581. keep_last_n_checkpoints=args.gather_checkpoints,
  582. topk=args.topk,
  583. )
  584. exp_duration = time.time() - exp_start_time
  585. if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
  586. logger.end()
  587. print("Experiment ended")
  588. if __name__ == "__main__":
  589. epilog = [
  590. "Based on the architecture picked by --arch flag, you may use the following options:\n"
  591. ]
  592. for model, ep in available_models().items():
  593. model_help = "\n".join(ep.parser().format_help().split("\n")[2:])
  594. epilog.append(model_help)
  595. parser = argparse.ArgumentParser(
  596. description="PyTorch ImageNet Training",
  597. epilog="\n".join(epilog),
  598. formatter_class=argparse.RawDescriptionHelpFormatter,
  599. )
  600. add_parser_arguments(parser)
  601. args, rest = parser.parse_known_args()
  602. model_arch = available_models()[args.arch]
  603. model_args, rest = model_arch.parser().parse_known_args(rest)
  604. print(model_args)
  605. assert len(rest) == 0, f"Unknown args passed: {rest}"
  606. cudnn.benchmark = True
  607. main(args, model_args, model_arch)