# coding: utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
import itertools
import logging
import math
import os
import time
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from apex.parallel import DistributedDataParallel

import lamb
import utils
from apex import amp
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import create_exp_dir
from utils.exp_utils import benchmark
from utils.exp_utils import AverageMeter


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Use CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--save-all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al, '
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks to save memory')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    dist.add_argument('--local_rank', default=0, type=int,
                      help='Used for multi-process training. ' +
                      'Can either be manually set ' +
                      'or automatically set by using \'python -m multiproc\'')

    args = parser.parse_args()
    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    assert args.ext_len >= 0, 'extended context length must be non-negative'
    assert args.batch_size % args.batch_chunk == 0

    return args


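# Illustrative invocation only (a sketch, not taken from this repo's docs): every flag
# below is defined in parse_args() above; the values are arbitrary and the data path
# is a placeholder.
#
#   python train.py --cuda --fp16 --dataset wt103 --data ../data/wikitext-103 \
#       --max_step 40000 --batch_size 256 --lr 0.01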
def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
                    train_step, best_val_loss, work_dir, name='checkpoint.pt'):
    if args.fp16:
        amp_state = amp.state_dict()
    else:
        amp_state = None

    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'vocab': vocab,
        'amp_state': amp_state,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
        }

    with utils.distributed.sync_workers() as rank:
        path = os.path.join(work_dir, name)
        logging.info(f'Saving checkpoint to {path}')
        if rank == 0:
            torch.save(state, path)


def load_checkpoint(path):
    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    dst = f'cuda:{torch.cuda.current_device()}'
    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=dst)
    return checkpoint


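# Illustrative only: the checkpoint written by save_checkpoint() is a plain dict, so
# it can also be inspected offline via the keys defined above (the path below is a
# placeholder; 'LM-TFM' is just the default --work_dir):
#
#   ckpt = torch.load('LM-TFM/checkpoint_last.pt', map_location='cpu')
#   print(ckpt['train_step'], ckpt['best_val_loss'])
#   model.load_state_dict(ckpt['model_state'])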
def init_weight(weight, args):
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)


def init_bias(bias):
    nn.init.constant_(bias, 0.0)


def weights_init(m, args):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight, args)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for i in range(len(m.emb_projs)):
                if m.emb_projs[i] is not None:
                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if hasattr(m, 'weight'):
            init_weight(m.weight, args)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight, args)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
    elif classname.find('LayerNorm') != -1:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb, args)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias, args)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias, args)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)


def update_dropout(m, args):
    classname = m.__class__.__name__
    if classname.find('Dropout') != -1:
        if hasattr(m, 'p'):
            m.p = args.dropout


def update_dropatt(m, args):
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt


def evaluate(eval_iter, model, args):
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                           mem_len=args.mem_len
                           )
    else:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
                           )

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = None
        for i, (data, target, seq_len) in enumerate(eval_iter):
            if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                break
            ret = model(data, target, mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.float().item()
            total_len += seq_len

    # Switch back to the training mode
    model.reset_length(tgt_len=args.tgt_len,
                       ext_len=args.ext_len,
                       mem_len=args.mem_len
                       )
    model.train()

    return total_loss / total_len


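# Note on the reset_length() arithmetic above (a worked example using the argparse
# defaults, added here for illustration): evaluation keeps the total attended window
# equal to the training window by moving the difference tgt_len - eval_tgt_len into
# mem_len (or into ext_len when mem_len == 0). With tgt_len=192 and mem_len=192, an
# eval_tgt_len of 128 gives mem_len = 192 + 192 - 128 = 256, so
# eval_tgt_len + mem_len = 384 = tgt_len + mem_len.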
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch, train_step,
          best_val_loss, meters, args):
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    mems = [None for _ in range(args.batch_chunk)]
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

    for batch, (data, target, seq_len) in enumerate(train_iter):
        log_step += 1
        target_tokens += target.numel()

        model.zero_grad()

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            ret = para_model(data_i, target_i, mems[i])
            loss, mems[i] = ret[0], ret[1:]
            loss = loss.float().mean().type_as(loss) / args.batch_chunk
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            train_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if optimizer_sparse:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']

            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:>7d} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch+1,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    int(throughput),
                    cur_loss,
                    )

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
            logging.info(log_str)

        if train_step % args.eval_interval == 0:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                '| valid loss {:5.2f}'.format(
                    train_step // args.eval_interval,
                    train_step,
                    (time.time() - eval_start_time),
                    val_loss,
                    )
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
            logging.info(log_str)
            logging.info('-' * 100)

            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                if not args.debug:
                    name = 'checkpoint_best.pt'
                    save_checkpoint(args, model, model_config, optimizer,
                                    scheduler, vocab, train_step,
                                    best_val_loss, args.work_dir, name)

            # Always save after eval if save_all is true and not debug
            if not args.debug and args.save_all:
                name = f'checkpoint_{train_step}.pt'
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, train_step, best_val_loss,
                                args.work_dir, name)

            # Save last checkpoint if not debug and not save_all
            if not args.debug and not args.save_all:
                name = 'checkpoint_last.pt'
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, train_step, best_val_loss,
                                args.work_dir, name)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        if train_step == args.max_step:
            break

    return train_step, best_val_loss


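# Note on gradient accumulation in train() above (a worked example added for
# illustration, not part of the original code): data of shape (tgt_len, batch) is
# chunked along dim 1 into args.batch_chunk pieces, and each chunk's mean loss is
# divided by batch_chunk, so the summed backward passes reproduce the full-batch
# mean. E.g. with --batch_size 256 --batch_chunk 4, each optimizer step runs four
# forward/backward passes over 64 sequences each.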
def main():
    args = parse_args()

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = 'log.log'
    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        log_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    logging.info(args)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed + utils.distributed.get_rank())
    torch.manual_seed(args.seed + utils.distributed.get_rank())

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size, args.eval_tgt_len,
                                  device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size, args.eval_tgt_len,
                                  device=device, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2',
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             delay_allreduce=True,
                                             )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step, eta_min=args.eta_min
            )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step, eta_min=args.eta_min
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        # scheduler_sparse was left undefined for this branch; define it so that
        # train() does not fail when --scheduler inv_sqrt is selected.
        scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

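    # Worked example for the default schedule (illustrative comment, values taken from
    # the argparse defaults above): with --lr 0.01, --warmup_step 1000, --max_step 40000
    # and --eta_min 0.001, train() sets lr = 0.01 * step / 1000 during warmup (e.g.
    # 0.005 at step 500); afterwards CosineAnnealingLR anneals it as
    # eta_min + (lr - eta_min) * (1 + cos(pi * step / max_step)) / 2,
    # reaching eta_min at step 40000.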
    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info(' - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    best_val_loss = None

    if args.restart:
        checkpoint = load_checkpoint(args.restart)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        if args.fp16:
            amp.load_state_dict(checkpoint['amp_state'])
        train_step = checkpoint['train_step']
        best_val_loss = checkpoint['best_val_loss']

        model.apply(functools.partial(update_dropout, args=args))
        model.apply(functools.partial(update_dropatt, args=args))

    meters = {}
    warmup = args.mem_len // args.tgt_len + 1
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    try:
        for epoch in itertools.count(start=1):
            if args.roll:
                tr_iter.roll()
            train_step, best_val_loss = train(
                tr_iter, va_iter, model, para_model, model_config, optimizer,
                optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
                train_step, best_val_loss, meters, args
                )

            if train_step == args.max_step:
                logging.info('-' * 100)
                logging.info('End of training')
                break
    except KeyboardInterrupt:
        logging.info('-' * 100)
        logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                time.time() - test_start_time, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                time.time() - test_start_time, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()
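# Illustrative launch sketch (an assumption based on the --local_rank / --multi_gpu
# flags and the 'python -m multiproc' hint in parse_args above; the exact launcher
# depends on the surrounding repo):
#
#   python -m multiproc train.py --cuda --fp16 --multi_gpu ddp
#
# A single-GPU run is simply 'python train.py --cuda'.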