# coding: utf-8
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import itertools
import logging
import math
import os
import shutil
import sys
import time
import warnings

import dllogger
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml

try:
    from apex import amp
except ModuleNotFoundError:
    warnings.warn('APEX AMP is unavailable')

from torch.nn.parallel import DistributedDataParallel

import lamb
import utils
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import AverageMeter
from utils.exp_utils import TimeoutHandler
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
from utils.exp_utils import l2_promote
from utils.exp_utils import log_env_info
from utils.exp_utils import register_ignoring_timeout_handler

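# Example invocation (illustrative sketch; the config file name below is an
# assumption, only the flags themselves are defined in parse_args()):
#   python train.py --config_file wt103_base.yaml --config default \
#       --cuda --fp16 --amp pytorch --work_dir LM-TFM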
def parse_args():
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
        )

    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)

    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--apex_amp_opt_level', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')
    general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
                         help='Implementation of automatic mixed precision')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al,'
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting \
                          overrides global --batch_size and sets batch_size \
                          to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')
    training.add_argument('--swap_mem', action='store_true',
                          help='Swap memory tensors to cpu')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    dist.add_argument('--local_rank', type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    if args.ext_len < 0:
        raise RuntimeError('Extended context length must be non-negative')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by batch chunk')

    if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
        raise RuntimeError(
            'APEX AMP unavailable, install APEX or switch to pytorch AMP'
            )

    return args

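# Checkpoints bundle the full training state: model, optimizer, scheduler,
# AMP/GradScaler state, vocab, and iteration counters. Only rank 0 writes to
# disk; 'checkpoint_best.pt' (and per-step copies with --save_all) are created
# by copying 'checkpoint_last.pt' rather than re-serializing the state.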
def save_checkpoint(args, model, model_config, optimizer, scheduler, scaler,
                    vocab, epoch, batch, last_iter, train_step, best_val_loss,
                    is_best, work_dir):
    if args.fp16:
        if args.amp == 'pytorch':
            amp_state = scaler.state_dict()
        elif args.amp == 'apex':
            amp_state = amp.state_dict()
    else:
        amp_state = None

    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'vocab': vocab,
        'amp_state': amp_state,
        'epoch': epoch,
        'batch': batch,
        'last_iter': last_iter,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
        }

    last_chkpt_fname = 'checkpoint_last.pt'

    with utils.distributed.sync_workers() as rank:
        last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
        if rank == 0:
            # always save last checkpoint
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            torch.save(state, last_chkpt_path)

            # save best checkpoint if better than previous best
            if is_best:
                best_chkpt_fname = 'checkpoint_best.pt'
                best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
                logging.info(f'Saving checkpoint to {best_chkpt_path}')
                shutil.copy(last_chkpt_path, best_chkpt_path)

            # save every checkpoint if save_all is true
            if args.save_all:
                step_chkpt_fname = f'checkpoint_{train_step}.pt'
                step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
                logging.info(f'Saving checkpoint to {step_chkpt_path}')
                shutil.copy(last_chkpt_path, step_chkpt_path)

def load_checkpoint(path):
    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    dst = f'cuda:{torch.cuda.current_device()}'
    logging.info(f'Loading checkpoint from {path}')
    checkpoint = torch.load(path, map_location=dst)
    return checkpoint

def init_weight(weight, args):
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)


def init_bias(bias):
    nn.init.constant_(bias, 0.0)


def weights_init(m, args):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight, args)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for i in range(len(m.emb_projs)):
                if m.emb_projs[i] is not None:
                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if hasattr(m, 'weight'):
            init_weight(m.weight, args)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight, args)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
        if hasattr(m, 'out_layers_weights'):
            for i in range(len(m.out_layers_weights)):
                if m.out_layers_weights[i] is not None:
                    init_weight(m.out_layers_weights[i], args)
    elif classname.find('LayerNorm') != -1:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb, args)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias, args)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias, args)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)


def update_dropout(m, args):
    classname = m.__class__.__name__
    if classname.find('Dropout') != -1:
        if hasattr(m, 'p'):
            m.p = args.dropout


def update_dropatt(m, args):
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt

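# evaluate() returns the average per-token loss over an evaluation iterator.
# Because evaluation may use a different target length than training, the
# memory (or extended-context) length is stretched so the total attention span
# stays constant, and the training lengths are restored before returning.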
def evaluate(eval_iter, model, args):
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                           mem_len=args.mem_len
                           )
    else:
        model.reset_length(tgt_len=args.eval_tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
                           )

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = None
        for i, (data, target, seq_len, warm) in enumerate(eval_iter):
            if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                break
            loss, mems = model(data, target, mems)
            loss = loss.float().mean()
            if warm:
                assert (mems is None) or mems.size(1) == model.mem_len
                total_loss += seq_len * loss.item()
                total_len += seq_len

    # Switch back to the training mode
    model.reset_length(tgt_len=args.tgt_len,
                       ext_len=args.ext_len,
                       mem_len=args.mem_len
                       )
    model.train()

    return total_loss / total_len

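# train_iteration() runs forward/backward for a single gradient-accumulation
# chunk: it optionally moves the chunk's memory tensors between CPU and GPU
# (--swap_mem), divides the loss by --batch_chunk so accumulated gradients
# average correctly, and dispatches backward through PyTorch AMP, APEX AMP,
# or plain fp32 depending on --fp16/--amp.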
def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, args):
    cpu = torch.device('cpu')
    data_i = data_chunks[i].contiguous()
    target_i = target_chunks[i].contiguous()

    if args.swap_mem and mems[i] is not None:
        mems[i] = mems[i].to(device, non_blocking=True)

    enable_autocast = args.fp16 and args.amp == 'pytorch'
    with torch.cuda.amp.autocast(enable_autocast):
        loss, mems[i] = model(data_i, target_i, mems[i])
        loss = loss.float().mean().type_as(loss) / args.batch_chunk

    if args.swap_mem and mems[i] is not None:
        mems[i] = mems[i].to(cpu, non_blocking=True)

    if args.fp16:
        if args.amp == 'pytorch':
            scaler.scale(loss).backward()
        elif args.amp == 'apex':
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
    else:
        loss.backward()

    train_loss = loss.float().item()
    return train_loss

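# train() consumes one epoch of the training iterator. Each global batch is
# split into --batch_chunk chunks (DDP gradient synchronization is skipped for
# all but the last chunk), gradients are clipped, the optimizer and LR
# schedule are stepped, progress is logged every --log_interval steps, and
# validation plus checkpointing run every --eval_interval steps.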
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, device, args):
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    mems = [None for _ in range(args.batch_chunk)]
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()

        for param in model.parameters():
            param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
                with para_model.no_sync():
                    train_loss_chunk = train_iteration(
                        para_model, i, mems, data_chunks, target_chunks, scaler,
                        optimizer, device, args
                    )
            else:
                train_loss_chunk = train_iteration(
                    para_model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, args
                )

            train_loss += train_loss_chunk

        if args.fp16:
            if args.amp == 'pytorch':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            elif args.amp == 'apex':
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        if args.fp16 and args.amp == 'pytorch':
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
            if optimizer_sparse:
                optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )

            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          (time.time() - eval_start_time),
                          val_loss,
                          )

            dllogger_data = {
                'valid_elapsed': (time.time() - eval_start_time),
                'valid_loss': val_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer, scheduler,
                                scaler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        if interrupted:
            logging.info(f'Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break

    return train_step, best_val_loss

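# main() wires everything together: device/distributed setup and logging,
# corpus and iterator construction, model/optimizer/scheduler creation,
# optional AMP and DDP/DataParallel wrapping, checkpoint restore via
# --restart, the epoch loop, and a final evaluation on the test split using
# the best checkpoint.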
def main():
    args = parse_args()

    utils.gpu_affinity.set_affinity(args.local_rank)

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    if args.optim.lower() == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
            optimizer_sparse = None
    elif args.optim.lower() == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = [], []
            for param in model.parameters():
                if param.size() == model.word_emb.weight.size():
                    sparse_params.append(param)
                else:
                    dense_params.append(param)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
            optimizer_sparse = None
    elif args.optim.lower() == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
        optimizer_sparse = None
    elif args.optim.lower() == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
        optimizer_sparse = None
    elif args.optim.lower() == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
        optimizer_sparse = None

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
                )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             device_ids=[args.local_rank],
                                             output_device=args.local_rank,
                                             broadcast_buffers=False,
                                             find_unused_parameters=True,
                                             )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler
    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
        else:
            scheduler_sparse = None
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if args.sample_softmax > 0 and optimizer_sparse is not None:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
        else:
            scheduler_sparse = None
    elif args.scheduler == 'constant':
        pass

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info(' - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)

    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    start_time = time.time()
    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss = train(
                    tr_iter, va_iter, model, para_model, model_config,
                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
                    scaler, vocab, epoch, last_batch, last_iter, train_step,
                    best_val_loss, meters, timeout_handler, device, args
                    )

                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if not args.debug and not args.no_eval and os.path.exists(test_path):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        test_start_time = time.time()
        test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    if best_val_loss:
        val_perplexity = math.exp(best_val_loss)
    else:
        val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'valid_loss': best_val_loss,
        'valid_perplexity': val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)

if __name__ == "__main__":
    # Disable profiling executor
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        pass

    # Before we do anything with models, we want to ensure that we get fp16
    # execution of torch.einsum in APEX AMP.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
    # Note that running `--apex_amp_opt_level O2` will remove the need for this
    # code, but it is still valid.
    if 'apex' in sys.modules:
        amp.register_half_function(torch, 'einsum')

    main()