# Copyright (c) 2018-2019, NVIDIA CORPORATION
# Copyright (c) 2017- Facebook, Inc
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

from . import logger as log
from . import resnet as models
from . import utils

import dllogger

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex to run this example."
    )

ACC_METADATA = {"unit": "%", "format": ":.2f"}
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}
LOSS_METADATA = {"format": ":.5f"}

class ModelAndLoss(nn.Module):
    """Bundles a model with its loss criterion; forward returns (loss, output)."""

    def __init__(
        self,
        arch,
        loss,
        pretrained_weights=None,
        cuda=True,
        fp16=False,
        memory_format=torch.contiguous_format,
    ):
        super(ModelAndLoss, self).__init__()
        self.arch = arch

        print("=> creating model '{}'".format(arch))
        model = models.build_resnet(arch[0], arch[1], arch[2])
        if pretrained_weights is not None:
            print("=> using pre-trained weights for '{}'".format(arch))
            model.load_state_dict(pretrained_weights)
        if cuda:
            model = model.cuda().to(memory_format=memory_format)
        if fp16:
            model = network_to_half(model)

        # define loss function (criterion) and optimizer
        criterion = loss()
        if cuda:
            criterion = criterion.cuda()

        self.model = model
        self.loss = criterion

    def forward(self, data, target):
        output = self.model(data)
        loss = self.loss(output, target)
        return loss, output

    def distributed(self):
        self.model = DDP(self.model)

    def load_model_state(self, state):
        if state is not None:
            self.model.load_state_dict(state)

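# Illustrative usage (a sketch; `arch`, `images`, and `labels` are assumed
# placeholders, and `arch` must match what models.build_resnet expects):
#
#   model_and_loss = ModelAndLoss(arch, nn.CrossEntropyLoss, cuda=True)
#   loss, output = model_and_loss(images, labels)
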
def get_optimizer(
    parameters,
    fp16,
    lr,
    momentum,
    weight_decay,
    nesterov=False,
    state=None,
    static_loss_scale=1.0,
    dynamic_loss_scale=False,
    bn_weight_decay=False,
):
    """Build an SGD optimizer, optionally excluding BN parameters from weight decay."""
    if bn_weight_decay:
        print(" ! Weight decay applied to BN parameters ")
        optimizer = torch.optim.SGD(
            [v for n, v in parameters],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    else:
        print(" ! Weight decay NOT applied to BN parameters ")
        bn_params = [v for n, v in parameters if "bn" in n]
        rest_params = [v for n, v in parameters if "bn" not in n]
        print("BN params: {}".format(len(bn_params)))
        print("Remaining params: {}".format(len(rest_params)))
        # The per-group weight_decay entries override the optimizer-level
        # default, so BN parameters are effectively unregularized.
        optimizer = torch.optim.SGD(
            [
                {"params": bn_params, "weight_decay": 0},
                {"params": rest_params, "weight_decay": weight_decay},
            ],
            lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

    if fp16:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=static_loss_scale,
            dynamic_loss_scale=dynamic_loss_scale,
            verbose=False,
        )

    if state is not None:
        optimizer.load_state_dict(state)

    return optimizer

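# Illustrative usage (a sketch; the hyperparameter values are assumed, not
# prescribed). `parameters` must be an iterable of (name, tensor) pairs so
# that BN weights can be filtered by name:
#
#   optimizer = get_optimizer(
#       list(model_and_loss.model.named_parameters()),
#       fp16=False, lr=0.256, momentum=0.875, weight_decay=1e-4,
#   )
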
def lr_policy(lr_fn, logger=None):
    if logger is not None:
        logger.register_metric(
            "lr", log.LR_METER(), verbosity=dllogger.Verbosity.VERBOSE
        )

    def _alr(optimizer, iteration, epoch):
        lr = lr_fn(iteration, epoch)
        if logger is not None:
            logger.log_metric("lr", lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    return _alr

def lr_step_policy(base_lr, steps, decay_factor, warmup_length, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            lr = base_lr
            for s in steps:
                if epoch >= s:
                    lr *= decay_factor
        return lr

    return lr_policy(_lr_fn, logger=logger)

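# Worked example: lr_step_policy(0.1, [30, 60, 80], 0.1, 5) warms up linearly
# over epochs 0-4 (0.02, 0.04, ..., 0.1), then yields 0.1 until epoch 30,
# 0.01 until epoch 60, 0.001 until epoch 80, and 0.0001 afterwards.
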
def lr_linear_policy(base_lr, warmup_length, epochs, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = base_lr * (1 - (e / es))
        return lr

    return lr_policy(_lr_fn, logger=logger)


def lr_cosine_policy(base_lr, warmup_length, epochs, logger=None):
    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            es = epochs - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
        return lr

    return lr_policy(_lr_fn, logger=logger)

def lr_exponential_policy(
    base_lr, warmup_length, epochs, final_multiplier=0.001, logger=None
):
    es = epochs - warmup_length
    epoch_decay = np.power(2, np.log2(final_multiplier) / es)

    def _lr_fn(iteration, epoch):
        if epoch < warmup_length:
            lr = base_lr * (epoch + 1) / warmup_length
        else:
            e = epoch - warmup_length
            lr = base_lr * (epoch_decay ** e)
        return lr

    return lr_policy(_lr_fn, logger=logger)

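# All schedules above are epoch-based; `iteration` is accepted but unused by
# each _lr_fn. Worked example for the cosine schedule:
# lr_cosine_policy(1.0, 0, 100) yields lr = 0.5 * (1 + cos(pi * epoch / 100)),
# i.e. 1.0 at epoch 0, 0.5 at epoch 50, approaching 0 near epoch 100.
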
def get_train_step(
    model_and_loss, optimizer, fp16, use_amp=False, batch_size_multiplier=1
):
    def _step(input, target, optimizer_step=True):
        input_var = Variable(input)
        target_var = Variable(target)
        loss, output = model_and_loss(input_var, target_var)
        if torch.distributed.is_initialized():
            reduced_loss = utils.reduce_tensor(loss.data)
        else:
            reduced_loss = loss.data

        if fp16:
            optimizer.backward(loss)
        elif use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if optimizer_step:
            opt = (
                optimizer.optimizer
                if isinstance(optimizer, FP16_Optimizer)
                else optimizer
            )
            for param_group in opt.param_groups:
                for param in param_group["params"]:
                    param.grad /= batch_size_multiplier

            optimizer.step()
            optimizer.zero_grad()

        torch.cuda.synchronize()

        return reduced_loss

    return _step

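# Gradient accumulation: with batch_size_multiplier > 1 the caller invokes
# _step several times with optimizer_step=False so gradients accumulate, and
# the division by batch_size_multiplier above averages them so the update
# matches one large batch. Illustrative call (names are assumed placeholders):
#
#   step = get_train_step(model_and_loss, optimizer, fp16=False)
#   reduced_loss = step(images, labels, optimizer_step=True)
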
def train(
    train_loader,
    model_and_loss,
    optimizer,
    lr_scheduler,
    fp16,
    logger,
    epoch,
    use_amp=False,
    prof=-1,
    batch_size_multiplier=1,
    register_metrics=True,
):
    if register_metrics and logger is not None:
        logger.register_metric(
            "train.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "train.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "train.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "train.compute_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_train_step(
        model_and_loss,
        optimizer,
        fp16,
        use_amp=use_amp,
        batch_size_multiplier=batch_size_multiplier,
    )

    model_and_loss.train()
    end = time.time()

    optimizer.zero_grad()

    data_iter = enumerate(train_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter)
    if prof > 0:
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        lr_scheduler(optimizer, i, epoch)
        data_time = time.time() - end

        optimizer_step = ((i + 1) % batch_size_multiplier) == 0
        loss = step(input, target, optimizer_step=optimizer_step)

        it_time = time.time() - end

        if logger is not None:
            logger.log_metric("train.loss", to_python_float(loss), bs)
            logger.log_metric("train.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("train.total_ips", calc_ips(bs, it_time))
            logger.log_metric("train.data_time", data_time)
            logger.log_metric("train.compute_time", it_time - data_time)

        end = time.time()

def get_val_step(model_and_loss):
    def _step(input, target):
        input_var = Variable(input)
        target_var = Variable(target)

        with torch.no_grad():
            loss, output = model_and_loss(input_var, target_var)

            prec1, prec5 = utils.accuracy(output.data, target, topk=(1, 5))

            if torch.distributed.is_initialized():
                reduced_loss = utils.reduce_tensor(loss.data)
                prec1 = utils.reduce_tensor(prec1)
                prec5 = utils.reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

        torch.cuda.synchronize()

        return reduced_loss, prec1, prec5

    return _step

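# Illustrative call (a sketch; `images` and `labels` are assumed placeholders).
# The step runs under torch.no_grad() and, in a distributed run, reduces the
# loss and top-1/top-5 accuracy across ranks via utils.reduce_tensor:
#
#   val_step = get_val_step(model_and_loss)
#   loss, prec1, prec5 = val_step(images, labels)
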
def validate(
    val_loader, model_and_loss, fp16, logger, epoch, prof=-1, register_metrics=True
):
    if register_metrics and logger is not None:
        logger.register_metric(
            "val.top1",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.top5",
            log.ACC_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=ACC_METADATA,
        )
        logger.register_metric(
            "val.loss",
            log.LOSS_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=LOSS_METADATA,
        )
        logger.register_metric(
            "val.compute_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.total_ips",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.DEFAULT,
            metadata=IPS_METADATA,
        )
        logger.register_metric(
            "val.data_time",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency",
            log.PERF_METER(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at100",
            log.LAT_100(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at99",
            log.LAT_99(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )
        logger.register_metric(
            "val.compute_latency_at95",
            log.LAT_95(),
            verbosity=dllogger.Verbosity.VERBOSE,
            metadata=TIME_METADATA,
        )

    step = get_val_step(model_and_loss)

    top1 = log.AverageMeter()
    # switch to evaluate mode
    model_and_loss.eval()

    end = time.time()

    data_iter = enumerate(val_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter, val=True)
    if prof > 0:
        data_iter = utils.first_n(prof, data_iter)

    for i, (input, target) in data_iter:
        bs = input.size(0)
        data_time = time.time() - end

        loss, prec1, prec5 = step(input, target)

        it_time = time.time() - end

        top1.record(to_python_float(prec1), bs)
        if logger is not None:
            logger.log_metric("val.top1", to_python_float(prec1), bs)
            logger.log_metric("val.top5", to_python_float(prec5), bs)
            logger.log_metric("val.loss", to_python_float(loss), bs)
            logger.log_metric("val.compute_ips", calc_ips(bs, it_time - data_time))
            logger.log_metric("val.total_ips", calc_ips(bs, it_time))
            logger.log_metric("val.data_time", data_time)
            logger.log_metric("val.compute_latency", it_time - data_time)
            logger.log_metric("val.compute_latency_at95", it_time - data_time)
            logger.log_metric("val.compute_latency_at99", it_time - data_time)
            logger.log_metric("val.compute_latency_at100", it_time - data_time)

        end = time.time()

    return top1.get_val()

# Train loop {{{
def calc_ips(batch_size, elapsed):
    world_size = (
        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    )
    tbs = world_size * batch_size
    return tbs / elapsed

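# Worked example: with 8 distributed processes and a per-process batch of 256,
# an iteration taking 0.5 s reports calc_ips(256, 0.5) = 8 * 256 / 0.5
# = 4096 img/s.
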
def train_loop(
    model_and_loss,
    optimizer,
    lr_scheduler,
    train_loader,
    val_loader,
    fp16,
    logger,
    should_backup_checkpoint,
    use_amp=False,
    batch_size_multiplier=1,
    best_prec1=0,
    start_epoch=0,
    end_epoch=0,
    prof=-1,
    skip_training=False,
    skip_validation=False,
    save_checkpoints=True,
    checkpoint_dir="./",
    checkpoint_filename="checkpoint.pth.tar",
):
    prec1 = -1

    print(f"RUNNING EPOCHS FROM {start_epoch} TO {end_epoch}")
    for epoch in range(start_epoch, end_epoch):
        if logger is not None:
            logger.start_epoch()
        if not skip_training:
            train(
                train_loader,
                model_and_loss,
                optimizer,
                lr_scheduler,
                fp16,
                logger,
                epoch,
                use_amp=use_amp,
                prof=prof,
                register_metrics=epoch == start_epoch,
                batch_size_multiplier=batch_size_multiplier,
            )

        if not skip_validation:
            prec1, nimg = validate(
                val_loader,
                model_and_loss,
                fp16,
                logger,
                epoch,
                prof=prof,
                register_metrics=epoch == start_epoch,
            )
        if logger is not None:
            logger.end_epoch()

        if save_checkpoints and (
            not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
        ):
            if not skip_validation:
                is_best = logger.metrics["val.top1"]["meter"].get_epoch() > best_prec1
                best_prec1 = max(
                    logger.metrics["val.top1"]["meter"].get_epoch(), best_prec1
                )
            else:
                is_best = False
                best_prec1 = 0

            if should_backup_checkpoint(epoch):
                backup_filename = "checkpoint-{}.pth.tar".format(epoch + 1)
            else:
                backup_filename = None
            utils.save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": model_and_loss.arch,
                    "state_dict": model_and_loss.model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                checkpoint_dir=checkpoint_dir,
                backup_filename=backup_filename,
                filename=checkpoint_filename,
            )
# }}}
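
# Illustrative end-to-end wiring (a sketch with assumed hyperparameters; the
# data loaders, logger, and distributed setup come from elsewhere in this
# repository):
#
#   model_and_loss = ModelAndLoss(arch, nn.CrossEntropyLoss)
#   optimizer = get_optimizer(
#       list(model_and_loss.model.named_parameters()),
#       fp16=False, lr=0.256, momentum=0.875, weight_decay=1e-4,
#   )
#   lr_scheduler = lr_cosine_policy(0.256, warmup_length=5, epochs=90)
#   train_loop(
#       model_and_loss, optimizer, lr_scheduler, train_loader, val_loader,
#       fp16=False, logger=logger, should_backup_checkpoint=lambda epoch: False,
#       end_epoch=90,
#   )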