train.py

#!/usr/bin/env python
""" EfficientDet Training Script
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
import yaml
import os
from datetime import datetime
import ctypes
import numpy as np
import random
import copy

import torch
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as DDP

import dllogger

from effdet.factory import create_model
from effdet.evaluator import COCOEvaluator
from effdet.bench import unwrap_bench
from data import create_loader, CocoDetection
from utils.gpu_affinity import set_affinity
from utils.utils import *
from utils.optimizers import create_optimizer, clip_grad_norm_2
from utils.scheduler import create_scheduler
from utils.model_ema import ModelEma

torch.backends.cudnn.benchmark = True
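# Raw handle to the CUDA runtime, used in main() to raise cudaLimitMaxL2FetchGranularity
# via cudaDeviceSetLimit (requires libcudart.so to be resolvable on the loader path).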
_libcudart = ctypes.CDLL('libcudart.so')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')


def add_bool_arg(parser, name, default=False, help=''):  # FIXME move to utils
    dest_name = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
    parser.set_defaults(**{dest_name: default})

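# Main argument parser. Values from the YAML file passed via --config override the
# defaults below; add_bool_arg() above registers paired --<name>/--no-<name> flags
# that write to a single destination.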
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='tf_efficientdet_d1', type=str, metavar='MODEL',
                    help='Name of model to train (default: "tf_efficientdet_d1")')
add_bool_arg(parser, 'redundant-bias', default=None,
             help='override model config for redundant bias')
parser.set_defaults(redundant_bias=None)
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--pretrained-backbone-path', default='', type=str, metavar='PATH',
                    help='Start from pretrained backbone weights.')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', action='store_true', default=False,
                    help='Resume full model and optimizer state from checkpoint (default: False)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--fill-color', default='0', type=str, metavar='NAME',
                    help='Image augmentation fill (background) color ("mean" or int)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
                    help='ratio of validation batch size to training batch size (default: 1)')
parser.add_argument('--input_size', type=int, default=None, metavar='PCT',
                    help='Image size (default: None) if this is not set default model image size is taken')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--clip-grad', type=float, default=10.0, metavar='NORM',
                    help='Clip gradient norm (default: 10.0)')
# Optimizer parameters
parser.add_argument('--opt', default='momentum', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "momentum")')
parser.add_argument('--opt-eps', default=1e-3, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-3)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=4e-5,
                    help='weight decay (default: 0.00004)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train (default: 300)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--mixup', type=float, default=0.0,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.0,
                    help='label smoothing (default: 0.0)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic; default: "random")')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--dist-group-size', type=int, default=0,
                    help='Group size for sync-bn')
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--eval-after', type=int, default=0, metavar='N',
                    help='Start evaluating after eval-after epochs')
parser.add_argument('--benchmark', action='store_true', default=False,
                    help='Turn this on when measuring performance')
parser.add_argument('--benchmark-steps', type=int, default=0, metavar='N',
                    help='Run training for this number of steps for performance measurement')
parser.add_argument('--dllogger-file', default='log.json', type=str, metavar='PATH',
                    help='File name of dllogger json file (default: log.json, current dir)')
parser.add_argument('--save-checkpoint-interval', type=int, default=10, metavar='N',
                    help='Save checkpoints after so many epochs')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many data loading workers to use (default: 4)')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA amp for mixed precision training')
parser.add_argument('--no-pin-mem', dest='pin_mem', action='store_false',
                    help='Disable pin CPU memory in DataLoader.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='map', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "map")')
parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
parser.add_argument("--memory-format", type=str, default="nchw", choices=["nchw", "nhwc"],
                    help="memory layout, nchw or nhwc")
parser.add_argument("--fused-focal-loss", action='store_true',
                    help="Use fused focal loss for better performance.")
# Waymo
parser.add_argument('--waymo', action='store_true', default=False,
                    help='Train on Waymo dataset or COCO dataset. Default: False (COCO dataset)')
parser.add_argument('--num_classes', type=int, default=None, metavar='N',
                    help='Number of classes the model needs to be trained for (default: None)')
parser.add_argument('--remove-weights', nargs='*', default=[],
                    help='Remove these weights from the state dict before loading checkpoint (use case can be not loading heads)')
parser.add_argument('--freeze-layers', nargs='*', default=[],
                    help='Freeze these layers')
parser.add_argument('--waymo-train-annotation', default=None, type=str,
                    help='Absolute Path to waymo training annotation (default: "None")')
parser.add_argument('--waymo-val-annotation', default=None, type=str,
                    help='Absolute Path to waymo validation annotation (default: "None")')
parser.add_argument('--waymo-train', default=None, type=str,
                    help='Path to waymo training relative to waymo data (default: "None")')
parser.add_argument('--waymo-val', default=None, type=str,
                    help='Path to waymo validation relative to waymo data (default: "None")')

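# Example invocation (illustrative only; the data path, GPU count, and hyperparameter
# values below are placeholders, not recommended settings):
#   python -m torch.distributed.launch --nproc_per_node=8 train.py /data/coco \
#       --model tf_efficientdet_d1 --batch-size 32 --amp --sync-bn --epochs 300
# Defaults can also be supplied through a YAML file with --config.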
def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def get_outdirectory(path, *paths):
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir, exist_ok=True)
    return outdir

def main():
    setup_default_logging()  ## TODO(sugh) replace
    args, args_text = _parse_args()
    set_affinity(args.local_rank)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        torch.cuda.manual_seed_all(args.seed)
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()

    # Set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128

    assert args.rank >= 0

    setup_dllogger(args.rank, filename=args.dllogger_file)
    dllogger.metadata('eval_batch_time', {'unit': 's'})
    dllogger.metadata('train_batch_time', {'unit': 's'})
    dllogger.metadata('eval_throughput', {'unit': 'images/s'})
    dllogger.metadata('train_throughout', {'unit': 'images/s'})
    dllogger.metadata('eval_loss', {'unit': None})
    dllogger.metadata('train_loss', {'unit': None})
    dllogger.metadata('map', {'unit': None})

    if args.distributed:
        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on 1 GPU.')

    if args.waymo:
        if (args.waymo_train is not None and args.waymo_val is None) or (args.waymo_train is None and args.waymo_val is not None):
            raise Exception("waymo_train or waymo_val is not set")

    memory_format = (
        torch.channels_last if args.memory_format == "nhwc" else torch.contiguous_format
    )
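    # bench_task='train' returns the detector wrapped in a training bench: its forward
    # pass takes (input, target) and returns a dict containing the total 'loss' (plus
    # 'detections' when run in eval mode), which train_epoch() and validate() consume below.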
    model = create_model(
        args.model,
        input_size=args.input_size,
        num_classes=args.num_classes,
        bench_task='train',
        pretrained=args.pretrained,
        pretrained_backbone_path=args.pretrained_backbone_path,
        redundant_bias=args.redundant_bias,
        checkpoint_path=args.initial_checkpoint,
        label_smoothing=args.smoothing,
        fused_focal_loss=args.fused_focal_loss,
        remove_params=args.remove_weights,
        freeze_layers=args.freeze_layers,
        strict_load=False
    )
    # FIXME decide which args to keep and overlay on config / pass to backbone
    #  num_classes=args.num_classes,
    input_size = model.config.image_size
    data_config = model.config
    print("Input size to be passed to dataloaders: {}".format(input_size))
    print("Image size used in model: {}".format(model.config.image_size))

    if args.rank == 0:
        dllogger.log(step='PARAMETER',
                     data={'model_name': args.model, 'param_count': sum([m.numel() for m in model.parameters()])})

    model = model.cuda().to(memory_format=memory_format)

    # optionally resume from a checkpoint
    if args.distributed:
        if args.sync_bn:
            try:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
            except Exception as e:
                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
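    # Optimizer and AMP grad scaler. With enabled=False (no --amp) GradScaler's
    # scale/unscale_/step/update calls are no-ops, so the same code path serves FP32 training.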
    optimizer = create_optimizer(args, model)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    resume_state = {}
    resume_epoch = None
    output_base = args.output if args.output else './output'
    resume_checkpoint_path = get_latest_checkpoint(os.path.join(output_base, 'train'))
    if args.resume and resume_checkpoint_path is not None:
        print("Trying to load checkpoint from {}".format(resume_checkpoint_path))
        resume_state, resume_epoch = resume_checkpoint(unwrap_bench(model), resume_checkpoint_path)
    if resume_epoch is not None:
        print("Resume training from {} epoch".format(resume_epoch))
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if args.amp and 'scaler' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            scaler.load_state_dict(resume_state['scaler'])
    del resume_state

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        if args.resume and resume_checkpoint_path is not None:
            resume_path = resume_checkpoint_path
        else:
            resume_path = ''
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            resume=resume_path)

    if args.distributed:
        if args.local_rank == 0:
            logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        model = DDP(model, device_ids=[args.device])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        dllogger.log(step="PARAMETER", data={'Scheduled_epochs': num_epochs}, verbosity=0)

    # Benchmark will always override every other setting.
    if args.benchmark:
        start_epoch = 0
        num_epochs = args.epochs

    if args.waymo:
        train_annotation_path = args.waymo_train_annotation
        train_image_dir = args.waymo_train
    else:
        train_anno_set = 'train2017'
        train_annotation_path = os.path.join(args.data, 'annotations', f'instances_{train_anno_set}.json')
        train_image_dir = train_anno_set
    dataset_train = CocoDetection(os.path.join(args.data, train_image_dir), train_annotation_path, data_config)

    loader_train = create_loader(
        dataset_train,
        input_size=input_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        interpolation=args.train_interpolation,
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        memory_format=memory_format
    )
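    # The training loader is consumed through a single persistent iterator; steps_per_epoch
    # is derived from the global batch size (per-GPU batch size * world size), and each
    # "epoch" in train_epoch() draws that many batches from this iterator.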
    loader_train_iter = iter(loader_train)
    steps_per_epoch = int(np.ceil(len(dataset_train) / (args.world_size * args.batch_size)))
    if args.waymo:
        val_annotation_path = args.waymo_val_annotation
        val_image_dir = args.waymo_val
    else:
        val_anno_set = 'val2017'
        val_annotation_path = os.path.join(args.data, 'annotations', f'instances_{val_anno_set}.json')
        val_image_dir = val_anno_set
    dataset_eval = CocoDetection(os.path.join(args.data, val_image_dir), val_annotation_path, data_config)

    loader_eval = create_loader(
        dataset_eval,
        input_size=input_size,
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=args.interpolation,
        num_workers=args.workers,
        distributed=args.distributed,
        pin_mem=args.pin_mem,
        memory_format=memory_format
    )
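    # The evaluator accumulates per-batch detections in validate() and computes COCO-style
    # mAP in evaluator.evaluate(); the distributed flag tells it whether results need to be
    # merged across ranks first.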
    evaluator = COCOEvaluator(dataset_eval.coco, distributed=args.distributed, waymo=args.waymo)
    eval_metric = args.eval_metric
    eval_metrics = None
    train_metrics = {}
    best_metric = -1
    is_best = False
    best_epoch = None
    saver = None
    output_dir = ''
    if args.rank == 0:
        output_base = args.output if args.output else './output'
        output_dir = get_outdirectory(output_base, 'train')
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(checkpoint_dir=output_dir)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, steps_per_epoch, model, loader_train_iter, optimizer, args,
                lr_scheduler=lr_scheduler, output_dir=output_dir, use_amp=args.amp, scaler=scaler, model_ema=model_ema)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    logging.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            # the overhead of evaluating with coco style datasets is fairly high, so just ema or non, not both
            if model_ema is not None:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                if epoch >= args.eval_after:
                    eval_metrics = validate(model_ema.ema, loader_eval, args, evaluator, epoch, log_suffix=' (EMA)')
            else:
                eval_metrics = validate(model, loader_eval, args, evaluator, epoch)

            lr_scheduler.step(epoch + 1)

            if saver is not None and args.rank == 0 and epoch % args.save_checkpoint_interval == 0:
                if eval_metrics is not None:
                    # save proper checkpoint with eval metric
                    is_best = eval_metrics[eval_metric] > best_metric
                    best_metric = max(eval_metrics[eval_metric], best_metric)
                    if is_best:
                        best_epoch = epoch
                else:
                    is_best = False
                    best_metric = 0
                saver.save_checkpoint(model, optimizer, epoch, model_ema=model_ema, metric=best_metric, is_best=is_best)

    except KeyboardInterrupt:
        dllogger.flush()
        torch.cuda.empty_cache()

    if best_metric > 0:
        train_metrics.update({'best_map': best_metric, 'best_epoch': best_epoch})
    if eval_metrics is not None:
        train_metrics.update(eval_metrics)
    dllogger.log(step=(), data=train_metrics, verbosity=0)

def train_epoch(
        epoch, steps_per_epoch, model, loader_iter, optimizer, args,
        lr_scheduler=None, output_dir='', use_amp=False, scaler=None, model_ema=None):
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()
    throughput_m = AverageMeter()

    model.train()

    end = time.time()
    last_idx = steps_per_epoch - 1
    num_updates = epoch * steps_per_epoch
    for batch_idx in range(steps_per_epoch):
        input, target = next(loader_iter)
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)

        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(input, target)
            loss = output['loss']

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        scaler.scale(loss).backward()
        if args.clip_grad > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad)
        scaler.step(optimizer)
        scaler.update()
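        # Clear gradients by setting them to None rather than calling optimizer.zero_grad();
        # this skips the memset of the grad buffers before the next backward pass.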
        for p in model.parameters():
            p.grad = None

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)
        num_updates += 1
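        # Reset the timing meters after the first 10 iterations so startup overhead is
        # excluded from the reported batch-time and throughput averages.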
        if batch_idx == 10:
            batch_time_m.reset()
            throughput_m.reset()

        batch_time_m.update(time.time() - end)
        throughput_m.update(float(input.size(0) * args.world_size / batch_time_m.val))
        if last_batch or (batch_idx + 1) % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.rank == 0:
                dllogger_data = {
                    'train_batch_time': batch_time_m.avg,
                    'train_loss': losses_m.avg,
                    'throughput': throughput_m.avg,
                    'lr': lr,
                    'train_data_time': data_time_m.avg}
                dllogger.log(step=(epoch, steps_per_epoch, batch_idx), data=dllogger_data, verbosity=0)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        if args.benchmark:
            if batch_idx >= args.benchmark_steps:
                break
    # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    metrics = {'train_loss': losses_m.avg, 'train_batch_time': batch_time_m.avg, 'train_throughout': throughput_m.avg}
    dllogger.log(step=(epoch,), data=metrics, verbosity=0)
    return metrics

def validate(model, loader, args, evaluator=None, epoch=0, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    throughput_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            with torch.cuda.amp.autocast(enabled=args.amp):
                output = model(input, target)
                loss = output['loss']

            if evaluator is not None:
                evaluator.add_predictions(output['detections'], target)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            batch_time_m.update(time.time() - end)
            throughput_m.update(float(input.size(0) * args.world_size / batch_time_m.val))
            end = time.time()

            if args.rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                dllogger_data = {'eval_batch_time': batch_time_m.val, 'eval_loss': losses_m.val}
                dllogger.log(step=(epoch, last_idx, batch_idx), data=dllogger_data, verbosity=0)

    metrics = {'eval_batch_time': batch_time_m.avg, 'eval_throughput': throughput_m.avg, 'eval_loss': losses_m.avg}
    if evaluator is not None:
        metrics['map'] = evaluator.evaluate()
    if args.rank == 0:
        dllogger.log(step=(epoch,), data=metrics, verbosity=0)
    return metrics

if __name__ == '__main__':
    torch.cuda.empty_cache()
    main()