# coding: utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import functools
import itertools
import logging
import math
import os
import shutil
import sys
import time

import dllogger
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from apex import amp
from torch.nn.parallel import DistributedDataParallel

import lamb
import utils
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import AverageMeter
from utils.exp_utils import TimeoutHandler
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
from utils.exp_utils import l2_promote
from utils.exp_utils import log_env_info
from utils.exp_utils import register_ignoring_timeout_handler


def parse_args():
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
        )

    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)

    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}
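
    # The chosen section of the YAML file becomes parser defaults via
    # parser.set_defaults(**config) below, so explicit command-line flags
    # still take precedence. Illustrative (hypothetical) config_file layout:
    #
    #   default:
    #     train:
    #       lr: 0.01
    #       max_step: 40000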
    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--amp_mode', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word',
                         choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al, '
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Embedding initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Embeddings initialized by U(-emb_init_range, emb_init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Projections initialized by N(0, proj_init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size; this setting '
                          'overrides global --batch_size and sets batch_size '
                          'to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPUs')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    dist.add_argument('--local_rank', type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model
    assert args.ext_len >= 0, 'extended context length must be non-negative'
    assert args.batch_size % args.batch_chunk == 0, \
        'batch_size must be divisible by batch_chunk'

    return args


def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
                    epoch, batch, last_iter, train_step, best_val_loss,
                    is_best, work_dir):
    if args.fp16:
        amp_state = amp.state_dict()
    else:
        amp_state = None

    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        # scheduler is None for the 'constant' schedule
        'scheduler_state': scheduler.state_dict() if scheduler is not None else None,
        'vocab': vocab,
        'amp_state': amp_state,
        'epoch': epoch,
        'batch': batch,
        'last_iter': last_iter,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
        }

    last_chkpt_fname = 'checkpoint_last.pt'
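    # NOTE: sync_workers() is assumed to act as a barrier that yields the
    # distributed rank, so non-zero ranks wait here until rank 0 has finished
    # writing the checkpoint files.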
    with utils.distributed.sync_workers() as rank:
        last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
        if rank == 0:
            # always save last checkpoint
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            torch.save(state, last_chkpt_path)

            # save best checkpoint if better than previous best
            if is_best:
                best_chkpt_fname = 'checkpoint_best.pt'
                best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
                logging.info(f'Saving checkpoint to {best_chkpt_path}')
                shutil.copy(last_chkpt_path, best_chkpt_path)

            # save every checkpoint if save_all is true
            if args.save_all:
                step_chkpt_fname = f'checkpoint_{train_step}.pt'
                step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
                logging.info(f'Saving checkpoint to {step_chkpt_path}')
                shutil.copy(last_chkpt_path, step_chkpt_path)


def load_checkpoint(path):
    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    dst = f'cuda:{torch.cuda.current_device()}'
    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=dst)
    return checkpoint


def init_weight(weight, args):
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)


def init_bias(bias):
    nn.init.constant_(bias, 0.0)


def weights_init(m, args):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight, args)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for i in range(len(m.emb_projs)):
                if m.emb_projs[i] is not None:
                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if hasattr(m, 'weight'):
            init_weight(m.weight, args)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight, args)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
        if hasattr(m, 'out_layers_weights'):
            for i in range(len(m.out_layers_weights)):
                if m.out_layers_weights[i] is not None:
                    init_weight(m.out_layers_weights[i], args)
    elif classname.find('LayerNorm') != -1:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb, args)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias, args)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias, args)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)


def update_dropout(m, args):
    classname = m.__class__.__name__
    if classname.find('Dropout') != -1:
        if hasattr(m, 'p'):
            m.p = args.dropout


def update_dropatt(m, args):
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt


def evaluate(eval_iter, model, args):
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
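    # Either way the total attention span is preserved when eval_tgt_len
    # differs from tgt_len. E.g. with tgt_len=192, mem_len=192 and
    # eval_tgt_len=128, eval memory grows to 192 + 192 - 128 = 256, so
    # mem_len + tgt_len == 384 in both modes.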
    if args.mem_len == 0:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                           mem_len=args.mem_len
                           )
    else:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
                           )

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = None
        for i, (data, target, seq_len, warm) in enumerate(eval_iter):
            if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                break
            loss, mems = model(data, target, mems)
            loss = loss.float().mean()
            if warm:
                assert (mems is None) or mems.size(1) == model.mem_len
                total_loss += seq_len * loss.item()
                total_len += seq_len

    # Switch back to the training mode
    model.reset_length(tgt_len=args.tgt_len,
                       ext_len=args.ext_len,
                       mem_len=args.mem_len
                       )
    model.train()

    return total_loss / total_len


def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, args):
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    mems = [None for _ in range(args.batch_chunk)]
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()
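        # Drop gradients instead of zeroing them (skips a memset), then split
        # the (tgt_len, batch)-shaped tensors along dim 1 so each chunk is
        # trained separately with gradient accumulation; the per-chunk loss
        # below is divided by batch_chunk so the accumulated gradients match
        # the full-batch gradient.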
        for param in model.parameters():
            param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            loss, mems[i] = para_model(data_i, target_i, mems[i])
            loss = loss.float().mean().type_as(loss) / args.batch_chunk

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            train_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if optimizer_sparse:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
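            # e.g. with lr=0.01 and warmup_step=1000, step 250 runs at
            # lr=0.0025; the sparse optimizer (sampled-softmax embeddings) is
            # warmed up at twice the dense lr, mirroring the 2x lr it is
            # given when constructed for SGD in main().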
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']

            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )
            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          (time.time() - eval_start_time),
                          val_loss,
                          )
            dllogger_data = {
                'valid_elapsed': (time.time() - eval_start_time),
                'valid_loss': val_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        if interrupted:
            logging.info('Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break

    return train_step, best_val_loss


def main():
    args = parse_args()
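    # Assumed behaviour of utils.gpu_affinity: pin this process to the CPU
    # cores (NUMA node) closest to its GPU to speed up host<->device traffic.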
    utils.gpu_affinity.set_affinity(args.local_rank)

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens
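    # Mirror the reset_length() adjustment performed in evaluate(), keeping
    # mem_len + tgt_len constant at eval time; the iterator is assumed to use
    # this to mark which segments are 'warm' (memory fully populated), cf. the
    # warm flag consumed in evaluate().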
    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
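    # cutoffs are vocabulary-index boundaries for the adaptive-softmax
    # clusters (tokens are ordered by frequency, so cluster 0 holds the most
    # frequent ones); the leading False in tie_projs is for the head cluster,
    # whose projection is never tied.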
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None
    model = model.to(device)
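
    # apex amp opt_level 'O2' casts model weights/activations to fp16 while
    # keeping fp32 master weights inside the optimizer; 'O1' instead patches
    # whitelisted functions to cast on the fly.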
    if args.fp16:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.amp_mode,
            )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             device_ids=[args.local_rank],
                                             output_device=args.local_rank,
                                             broadcast_buffers=False,
                                             find_unused_parameters=True,
                                             )
    elif args.multi_gpu == 'dp':
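        # DataParallel splits along dim=1 because batches are laid out
        # (tgt_len, batch). gpu0_bsz caps the share placed on GPU 0, which
        # also gathers the outputs and therefore needs extra memory headroom.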
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)
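        # At step == warmup_step both branches evaluate to
        # warmup_step ** -0.5, so the warmup ramp meets the decay curve
        # without a discontinuity.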
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        # no scheduler object: after warmup the learning rate stays fixed
        scheduler = None
        scheduler_sparse = None

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            # scheduler is None for the 'constant' schedule
            if scheduler is not None:
                scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
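    # Skip the first few throughput measurements: performance is not
    # representative while the recurrence memory is still filling up
    # (roughly mem_len // tgt_len iterations).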
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss = train(
                    tr_iter, va_iter, model, para_model, model_config,
                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
                    vocab, epoch, last_batch, last_iter, train_step,
                    best_val_loss, meters, timeout_handler, args
                    )

                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    # Disable profiling executor
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        pass

    # Before we do anything with models, we want to ensure that we get fp16
    # execution of torch.einsum. Otherwise it'll default to "promote" mode,
    # and we'll get fp32 operations. Note that running `--amp_mode O2` will
    # remove the need for this code, but it is still valid.
    amp.register_half_function(torch, 'einsum')

    main()