#!/usr/bin/env python
# Copyright (c) 2017 Elad Hoffer
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
# must be set before torch (and its OpenMP runtime) is imported
os.environ['KMP_AFFINITY'] = 'disabled'

import argparse
import logging
import sys
import time
from ast import literal_eval

import dllogger
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data.distributed

import seq2seq.data.config as config
import seq2seq.gpu_affinity as gpu_affinity
import seq2seq.train.trainer as trainers
import seq2seq.utils as utils
from seq2seq.data.dataset import LazyParallelDataset
from seq2seq.data.dataset import ParallelDataset
from seq2seq.data.dataset import TextDataset
from seq2seq.data.tokenizer import Tokenizer
from seq2seq.inference.translator import Translator
from seq2seq.models.gnmt import GNMT
from seq2seq.train.smoothing import LabelSmoothing
from seq2seq.train.table import TrainingTable


def parse_args():
    """
    Parse commandline arguments.
    """
    def exclusive_group(group, name, default, help):
        destname = name.replace('-', '_')
        subgroup = group.add_mutually_exclusive_group(required=False)
        subgroup.add_argument(f'--{name}', dest=f'{destname}',
                              action='store_true',
                              help=f'{help} (use \'--no-{name}\' to disable)')
        subgroup.add_argument(f'--no-{name}', dest=f'{destname}',
                              action='store_false', help=argparse.SUPPRESS)
        subgroup.set_defaults(**{destname: default})
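
    # for example, exclusive_group(group=general, name='cuda', default=True,
    # help='enables cuda') below creates a paired '--cuda' / '--no-cuda'
    # switch; both flags write to args.cuda and `default` supplies the value
    # when neither flag is given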

    parser = argparse.ArgumentParser(
        description='GNMT training',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # dataset
    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--dataset-dir', default='data/wmt16_de_en',
                         help='path to the directory with training/test data')
    dataset.add_argument('--src-lang',
                         default='en',
                         help='source language')
    dataset.add_argument('--tgt-lang',
                         default='de',
                         help='target language')
    dataset.add_argument('--vocab',
                         default='vocab.bpe.32000',
                         help='path to the vocabulary file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('-bpe', '--bpe-codes', default='bpe.32000',
                         help='path to the file with bpe codes \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--train-src',
                         default='train.tok.clean.bpe.32000.en',
                         help='path to the training source data file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--train-tgt',
                         default='train.tok.clean.bpe.32000.de',
                         help='path to the training target data file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--val-src',
                         default='newstest_dev.tok.clean.bpe.32000.en',
                         help='path to the validation source data file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--val-tgt',
                         default='newstest_dev.tok.clean.bpe.32000.de',
                         help='path to the validation target data file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--test-src',
                         default='newstest2014.tok.bpe.32000.en',
                         help='path to the test source data file \
                         (relative to DATASET_DIR directory)')
    dataset.add_argument('--test-tgt',
                         default='newstest2014.de',
                         help='path to the test target data file \
                         (relative to DATASET_DIR directory)')

    # results
    results = parser.add_argument_group('results setup')
    results.add_argument('--save-dir', default='gnmt',
                         help='path to directory with results, it will be \
                         automatically created if it does not exist')
    results.add_argument('--print-freq', default=10, type=int,
                         help='print log every PRINT_FREQ batches')
    results.add_argument('--warmup', default=1, type=int,
                         help='number of warmup iterations for performance \
                         counters')

    # model
    model = parser.add_argument_group('model setup')
    model.add_argument('--hidden-size', default=1024, type=int,
                       help='hidden size of the model')
    model.add_argument('--num-layers', default=4, type=int,
                       help='number of RNN layers in encoder and in decoder')
    model.add_argument('--dropout', default=0.2, type=float,
                       help='dropout applied to input of RNN cells')
    exclusive_group(group=model, name='share-embedding', default=True,
                    help='use shared embeddings for encoder and decoder')
    model.add_argument('--smoothing', default=0.1, type=float,
                       help='label smoothing, if equal to zero model will use \
                       CrossEntropyLoss, if not zero model will be trained \
                       with label smoothing loss')

    # setup
    general = parser.add_argument_group('general setup')
    general.add_argument('--math', default='fp16',
                         choices=['fp16', 'fp32', 'tf32', 'manual_fp16'],
                         help='precision')
    general.add_argument('--seed', default=None, type=int,
                         help='master seed for random number generators, if \
                         "seed" is undefined then the master seed will be \
                         sampled from random.SystemRandom()')
    general.add_argument('--prealloc-mode', default='always', type=str,
                         choices=['off', 'once', 'always'],
                         help='controls preallocation')
    general.add_argument('--dllog-file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--affinity', type=str,
                         default='socket_unique_interleaved',
                         choices=['socket', 'single', 'single_unique',
                                  'socket_unique_interleaved',
                                  'socket_unique_continuous',
                                  'disabled'],
                         help='type of CPU affinity')
    exclusive_group(group=general, name='eval', default=True,
                    help='run validation and test after every epoch')
    exclusive_group(group=general, name='env', default=True,
                    help='print info about execution env')
    exclusive_group(group=general, name='cuda', default=True,
                    help='enables cuda')
    exclusive_group(group=general, name='cudnn', default=True,
                    help='enables cudnn')
    exclusive_group(group=general, name='log-all-ranks', default=True,
                    help='enables logging from all distributed ranks, if \
                    disabled then only logs from rank 0 are reported')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--train-max-size', default=None, type=int,
                          help='use at most TRAIN_MAX_SIZE elements from \
                          training dataset (useful for benchmarking), by \
                          default uses entire dataset')
    training.add_argument('--train-batch-size', default=128, type=int,
                          help='training batch size per worker')
    training.add_argument('--train-global-batch-size', default=None, type=int,
                          help='global training batch size, this argument \
                          does not have to be defined, if it is defined it \
                          will be used to automatically \
                          compute train_iter_size \
                          using the equation: train_iter_size = \
                          train_global_batch_size // (train_batch_size * \
                          world_size)')
    training.add_argument('--train-iter-size', metavar='N', default=1,
                          type=int,
                          help='training iter size, training loop will \
                          accumulate gradients over N iterations and execute \
                          optimizer every N steps')
    training.add_argument('--epochs', default=6, type=int,
                          help='max number of training epochs')
    training.add_argument('--grad-clip', default=5.0, type=float,
                          help='enables gradient clipping and sets maximum \
                          norm of gradients')
    training.add_argument('--train-max-length', default=50, type=int,
                          help='maximum sequence length for training \
                          (including special BOS and EOS tokens)')
    training.add_argument('--train-min-length', default=0, type=int,
                          help='minimum sequence length for training \
                          (including special BOS and EOS tokens)')
    training.add_argument('--train-loader-workers', default=2, type=int,
                          help='number of workers for training data loading')
    training.add_argument('--batching', default='bucketing', type=str,
                          choices=['random', 'sharding', 'bucketing'],
                          help='select batching algorithm')
    training.add_argument('--shard-size', default=80, type=int,
                          help='shard size for "sharding" batching algorithm, \
                          in multiples of global batch size')
    training.add_argument('--num-buckets', default=5, type=int,
                          help='number of buckets for "bucketing" batching \
                          algorithm')

    # optimizer
    optimizer = parser.add_argument_group('optimizer setup')
    optimizer.add_argument('--optimizer', type=str, default='Adam',
                           help='training optimizer')
    optimizer.add_argument('--lr', type=float, default=2.00e-3,
                           help='learning rate')
    optimizer.add_argument('--optimizer-extra', type=str,
                           default="{}",
                           help='extra options for the optimizer')

    # mixed precision loss scaling
    loss_scaling = parser.add_argument_group(
        'mixed precision loss scaling setup'
    )
    loss_scaling.add_argument('--init-scale', type=float, default=8192,
                              help='initial loss scale')
    loss_scaling.add_argument('--upscale-interval', type=float, default=128,
                              help='loss upscaling interval')

    # scheduler
    scheduler = parser.add_argument_group('learning rate scheduler setup')
    scheduler.add_argument('--warmup-steps', type=str, default='200',
                           help='number of learning rate warmup iterations')
    scheduler.add_argument('--remain-steps', type=str, default='0.666',
                           help='starting iteration for learning rate decay')
    scheduler.add_argument('--decay-interval', type=str, default='None',
                           help='interval between learning rate decay steps')
    scheduler.add_argument('--decay-steps', type=int, default=4,
                           help='max number of learning rate decay steps')
    scheduler.add_argument('--decay-factor', type=float, default=0.5,
                           help='learning rate decay factor')

    # validation
    val = parser.add_argument_group('validation setup')
    val.add_argument('--val-batch-size', default=64, type=int,
                     help='batch size for validation')
    val.add_argument('--val-max-length', default=125, type=int,
                     help='maximum sequence length for validation \
                     (including special BOS and EOS tokens)')
    val.add_argument('--val-min-length', default=0, type=int,
                     help='minimum sequence length for validation \
                     (including special BOS and EOS tokens)')
    val.add_argument('--val-loader-workers', default=0, type=int,
                     help='number of workers for validation data loading')

    # test
    test = parser.add_argument_group('test setup')
    test.add_argument('--test-batch-size', default=128, type=int,
                      help='batch size for test')
    test.add_argument('--test-max-length', default=150, type=int,
                      help='maximum sequence length for test \
                      (including special BOS and EOS tokens)')
    test.add_argument('--test-min-length', default=0, type=int,
                      help='minimum sequence length for test \
                      (including special BOS and EOS tokens)')
    test.add_argument('--beam-size', default=5, type=int,
                      help='beam size')
    test.add_argument('--len-norm-factor', default=0.6, type=float,
                      help='length normalization factor')
    test.add_argument('--cov-penalty-factor', default=0.1, type=float,
                      help='coverage penalty factor')
    test.add_argument('--len-norm-const', default=5.0, type=float,
                      help='length normalization constant')
    test.add_argument('--intra-epoch-eval', metavar='N', default=0, type=int,
                      help='evaluate within training epoch, this option will \
                      enable extra N equally spaced evaluations executed \
                      during each training epoch')
    test.add_argument('--test-loader-workers', default=0, type=int,
                      help='number of workers for test data loading')

    # checkpointing
    chkpt = parser.add_argument_group('checkpointing setup')
    chkpt.add_argument('--start-epoch', default=0, type=int,
                       help='manually set initial epoch counter')
    chkpt.add_argument('--resume', default=None, type=str, metavar='PATH',
                       help='resumes training from checkpoint from PATH')
    chkpt.add_argument('--save-all', action='store_true', default=False,
                       help='saves checkpoint after every epoch')
    chkpt.add_argument('--save-freq', default=5000, type=int,
                       help='save checkpoint every SAVE_FREQ batches')
    chkpt.add_argument('--keep-checkpoints', default=0, type=int,
                       help='keep only last KEEP_CHECKPOINTS checkpoints, \
                       affects only checkpoints controlled by --save-freq \
                       option')

    # benchmarking
    benchmark = parser.add_argument_group('benchmark setup')
    benchmark.add_argument('--target-perf', default=None, type=float,
                           help='target training performance (in tokens \
                           per second)')
    benchmark.add_argument('--target-bleu', default=None, type=float,
                           help='target accuracy')

    # distributed
    distributed = parser.add_argument_group('distributed setup')
    distributed.add_argument('--local_rank', type=int,
                             default=os.getenv('LOCAL_RANK', 0),
                             help='Used for multi-process training.')

    args = parser.parse_args()

    args.lang = {'src': args.src_lang, 'tgt': args.tgt_lang}

    args.vocab = os.path.join(args.dataset_dir, args.vocab)
    args.bpe_codes = os.path.join(args.dataset_dir, args.bpe_codes)
    args.train_src = os.path.join(args.dataset_dir, args.train_src)
    args.train_tgt = os.path.join(args.dataset_dir, args.train_tgt)
    args.val_src = os.path.join(args.dataset_dir, args.val_src)
    args.val_tgt = os.path.join(args.dataset_dir, args.val_tgt)
    args.test_src = os.path.join(args.dataset_dir, args.test_src)
    args.test_tgt = os.path.join(args.dataset_dir, args.test_tgt)
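
    # the scheduler options below arrive as strings and are parsed with
    # ast.literal_eval, so the CLI accepts ints, floats or None, e.g.
    # --remain-steps 0.666 yields the float 0.666 and --decay-interval None
    # yields None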
    args.warmup_steps = literal_eval(args.warmup_steps)
    args.remain_steps = literal_eval(args.remain_steps)
    args.decay_interval = literal_eval(args.decay_interval)
    return args


def set_iter_size(train_iter_size, train_global_batch_size, train_batch_size):
    """
    Automatically set train_iter_size based on train_global_batch_size,
    world_size and per-worker train_batch_size

    :param train_global_batch_size: global training batch size
    :param train_batch_size: local training batch size
    """
    if train_global_batch_size is not None:
        global_bs = train_global_batch_size
        bs = train_batch_size
        world_size = utils.get_world_size()
        assert global_bs % (bs * world_size) == 0
        train_iter_size = global_bs // (bs * world_size)
        logging.info(f'Global batch size was set, '
                     f'setting train_iter_size to {train_iter_size}')
    return train_iter_size
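
# For instance, with --train-global-batch-size 2048, --train-batch-size 128
# and a world size of 8, train_iter_size = 2048 // (128 * 8) = 2, i.e.
# gradients are accumulated over 2 iterations per optimizer step.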


def build_criterion(vocab_size, padding_idx, smoothing):
    if smoothing == 0.:
        logging.info('Building CrossEntropyLoss')
        # reduction='sum' replaces the deprecated size_average=False and keeps
        # the same behavior: the loss is summed over tokens, not averaged
        criterion = nn.CrossEntropyLoss(ignore_index=padding_idx,
                                        reduction='sum')
    else:
        logging.info(f'Building LabelSmoothingLoss (smoothing: {smoothing})')
        criterion = LabelSmoothing(padding_idx, smoothing)
    return criterion
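
# With the default --smoothing 0.1 training uses the label smoothing loss;
# --smoothing 0 falls back to plain cross-entropy summed over non-padding
# tokens.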


def main():
    """
    Launches data-parallel multi-gpu training.
    """
    training_start = time.time()
    args = parse_args()

    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    device = utils.set_device(args.cuda, args.local_rank)
    utils.init_distributed(args.cuda)
    args.rank = utils.get_rank()

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # create directory for results
    os.makedirs(args.save_dir, exist_ok=True)

    # setup logging
    log_filename = f'log_rank_{utils.get_rank()}.log'
    utils.setup_logging(args.log_all_ranks,
                        os.path.join(args.save_dir, log_filename))

    dllog_file = os.path.join(args.save_dir, args.dllog_file)
    utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.env:
        utils.log_env_info()

    logging.info(f'Saving results to: {args.save_dir}')
    logging.info(f'Run arguments: {args}')
    dllogger.log(step='PARAMETER', data=vars(args))

    args.train_iter_size = set_iter_size(args.train_iter_size,
                                         args.train_global_batch_size,
                                         args.train_batch_size)

    worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed,
                                                      args.epochs,
                                                      device)
    worker_seed = worker_seeds[args.rank]
    logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # build tokenizer
    pad_vocab = utils.pad_vocabulary(args.math)
    tokenizer = Tokenizer(args.vocab, args.bpe_codes, args.lang, pad_vocab)

    # build datasets
    train_data = LazyParallelDataset(
        src_fname=args.train_src,
        tgt_fname=args.train_tgt,
        tokenizer=tokenizer,
        min_len=args.train_min_length,
        max_len=args.train_max_length,
        sort=False,
        max_size=args.train_max_size,
    )

    val_data = ParallelDataset(
        src_fname=args.val_src,
        tgt_fname=args.val_tgt,
        tokenizer=tokenizer,
        min_len=args.val_min_length,
        max_len=args.val_max_length,
        sort=True,
    )

    test_data = TextDataset(
        src_fname=args.test_src,
        tokenizer=tokenizer,
        min_len=args.test_min_length,
        max_len=args.test_max_length,
        sort=True,
    )
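
    # validation and test sets are length-sorted (sort=True), which presumably
    # keeps similarly-sized sequences together and reduces padding in eval
    # batches; training data stays unsorted and ordering is delegated to the
    # selected batching algorithm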

    vocab_size = tokenizer.vocab_size

    # build GNMT model
    model_config = {'hidden_size': args.hidden_size,
                    'vocab_size': vocab_size,
                    'num_layers': args.num_layers,
                    'dropout': args.dropout,
                    'batch_first': False,
                    'share_embedding': args.share_embedding,
                    }
    model = GNMT(**model_config).to(device)
    logging.info(model)

    batch_first = model.batch_first

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD,
                                args.smoothing).to(device)

    opt_config = {'optimizer': args.optimizer, 'lr': args.lr}
    opt_config.update(literal_eval(args.optimizer_extra))
    logging.info(f'Training optimizer config: {opt_config}')
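
    # note: --optimizer-extra values, e.g. "{'betas': (0.9, 0.98), 'eps': 1e-9}"
    # (hypothetical values), are merged into opt_config above and forwarded to
    # the torch.optim class named by --optimizer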

    scheduler_config = {'warmup_steps': args.warmup_steps,
                        'remain_steps': args.remain_steps,
                        'decay_interval': args.decay_interval,
                        'decay_steps': args.decay_steps,
                        'decay_factor': args.decay_factor}
    logging.info(f'Training LR schedule config: {scheduler_config}')
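
    # the default remain_steps of 0.666 is a fraction; presumably the LR
    # scheduler in seq2seq.train resolves it against the total number of
    # updates, so decay starts about two-thirds of the way through training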

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info(f'Number of parameters: {num_parameters}')

    batching_opt = {'shard_size': args.shard_size,
                    'num_buckets': args.num_buckets}
    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.train_batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=batch_first,
                                         shuffle=True,
                                         batching=args.batching,
                                         batching_opt=batching_opt,
                                         num_workers=args.train_loader_workers)

    val_loader = val_data.get_loader(batch_size=args.val_batch_size,
                                     batch_first=batch_first,
                                     shuffle=False,
                                     num_workers=args.val_loader_workers)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=batch_first,
                                       shuffle=False,
                                       pad=True,
                                       num_workers=args.test_loader_workers)

    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.test_max_length,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            print_freq=args.print_freq,
                            reference=args.test_tgt,
                            )

    # create trainer
    total_train_iters = len(train_loader) // args.train_iter_size * args.epochs
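
    # e.g. with 1000 training batches per epoch, train_iter_size=2 and
    # 6 epochs, the trainer is sized for 1000 // 2 * 6 = 3000 optimizer steps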

    save_info = {
        'model_config': model_config,
        'config': args,
        'tokenizer': tokenizer.get_state()
    }
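
    # dynamic loss scaling options for --math fp16/manual_fp16: presumably the
    # scale starts at init_scale and grows after upscale_interval consecutive
    # overflow-free steps (handled by the fp16 optimizer wrappers in
    # seq2seq.train)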
    loss_scaling = {
        'init_scale': args.init_scale,
        'upscale_interval': args.upscale_interval
    }

    trainer_options = dict(
        model=model,
        criterion=criterion,
        grad_clip=args.grad_clip,
        iter_size=args.train_iter_size,
        save_dir=args.save_dir,
        save_freq=args.save_freq,
        save_info=save_info,
        opt_config=opt_config,
        scheduler_config=scheduler_config,
        train_iterations=total_train_iters,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        loss_scaling=loss_scaling,
        print_freq=args.print_freq,
        intra_epoch_eval=args.intra_epoch_eval,
        translator=translator,
        prealloc_mode=args.prealloc_mode,
        warmup=args.warmup,
    )

    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    train_loss = float('inf')
    val_loss = float('inf')
    best_loss = float('inf')
    training_perf = []
    break_training = False
    test_bleu = None
    for epoch in range(args.start_epoch, args.epochs):
        logging.info(f'Starting epoch {epoch}')

        train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)
        training_perf.append(train_perf)

        # evaluate on validation set
        if args.eval:
            logging.info('Running validation on dev set')
            val_loss, val_perf = trainer.evaluate(val_loader)

        # remember best validation loss and save checkpoint
        if args.rank == 0:
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            trainer.save(save_all=args.save_all, is_best=is_best)

        if args.eval:
            utils.barrier()
            eval_fname = f'eval_epoch_{epoch}'
            eval_path = os.path.join(args.save_dir, eval_fname)
            _, eval_stats = translator.run(
                calc_bleu=True,
                epoch=epoch,
                eval_path=eval_path,
            )
            test_bleu = eval_stats['bleu']
            if args.target_bleu and test_bleu >= args.target_bleu:
                logging.info('Target accuracy reached')
                break_training = True

        acc_log = []
        acc_log += [f'Summary: Epoch: {epoch}']
        acc_log += [f'Training Loss: {train_loss:.4f}']
        if args.eval:
            acc_log += [f'Validation Loss: {val_loss:.4f}']
            acc_log += [f'Test BLEU: {test_bleu:.2f}']

        perf_log = []
        perf_log += [f'Performance: Epoch: {epoch}']
        perf_log += [f'Training: {train_perf:.0f} Tok/s']
        if args.eval:
            perf_log += [f'Validation: {val_perf:.0f} Tok/s']

        if args.rank == 0:
            logging.info('\t'.join(acc_log))
            logging.info('\t'.join(perf_log))

        logging.info(f'Finished epoch {epoch}')
        if break_training:
            break

    utils.barrier()
    training_stop = time.time()
    training_time = training_stop - training_start
    logging.info(f'Total training time {training_time:.0f} s')

    table = TrainingTable()
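
    # average throughput across epochs as a harmonic mean: with (roughly)
    # equal tokens per epoch this equals total tokens / total training time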
    avg_training_perf = len(training_perf) / sum(1 / v for v in training_perf)
    table.add(utils.get_world_size(), args.train_batch_size, test_bleu,
              avg_training_perf, training_time)
    if utils.get_rank() == 0:
        table.write('Training Summary', args.math)

    summary = {
        'val_loss': val_loss,
        'train_loss': train_loss,
        'train_throughput': avg_training_perf,
        'train_elapsed': training_time,
        'test_bleu': test_bleu,
    }
    dllogger.log(step=tuple(), data=summary)

    passed = utils.benchmark(test_bleu, args.target_bleu,
                             train_perf, args.target_perf)
    if not passed:
        sys.exit(1)
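

# A typical multi-GPU launch (hypothetical paths; assumes a launcher such as
# torchrun or torch.distributed.launch that sets LOCAL_RANK for each worker):
#
#   python -m torch.distributed.launch --nproc_per_node=8 train.py \
#       --dataset-dir data/wmt16_de_en --math fp16 --epochs 6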


if __name__ == '__main__':
    main()