train.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. # coding: utf-8
  2. # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import functools
  17. import itertools
  18. import logging
  19. import math
  20. import os
  21. import shutil
  22. import sys
  23. import time
  24. import warnings
  25. import dllogger
  26. import numpy as np
  27. import torch
  28. import torch.nn as nn
  29. import torch.optim as optim
  30. import yaml
  31. try:
  32. from apex import amp
  33. except ModuleNotFoundError:
  34. warnings.warn('APEX AMP is unavailable')
  35. try:
  36. import pyprof
  37. except ModuleNotFoundError:
  38. warnings.warn('PyProf is unavailable')
  39. from torch.nn.parallel import DistributedDataParallel
  40. import lamb
  41. import utils
  42. from data_utils import get_lm_corpus
  43. from mem_transformer import MemTransformerLM
  44. from utils.data_parallel import BalancedDataParallel
  45. from utils.exp_utils import AverageMeter
  46. from utils.exp_utils import TimeoutHandler
  47. from utils.exp_utils import benchmark
  48. from utils.exp_utils import create_exp_dir
  49. from utils.exp_utils import l2_promote
  50. from utils.exp_utils import log_env_info
  51. from utils.exp_utils import register_ignoring_timeout_handler
  52. def parse_args():
  53. parent_parser = argparse.ArgumentParser(
  54. description='PyTorch Transformer-XL Language Model',
  55. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  56. add_help=False,
  57. )
  58. parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
  59. cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
  60. cfg_parser.add_argument('--config', default='default')
  61. cfg_parser.add_argument('--config_file', default=None)
  62. config_args, _ = cfg_parser.parse_known_args()
  63. if config_args.config is not None and config_args.config_file is not None:
  64. with open(config_args.config_file) as f:
  65. config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
  66. else:
  67. config = {}
  68. general = parser.add_argument_group('general setup')
  69. general.add_argument('--work_dir', default='LM-TFM', type=str,
  70. help='Directory for the results')
  71. general.add_argument('--append_dataset', action='store_true',
  72. help='Automatically append dataset name to work_dir')
  73. general.add_argument('--append_time', action='store_true',
  74. help='Automatically append current time to work_dir')
  75. general.add_argument('--cuda', action='store_true',
  76. help='Run training on a GPU using CUDA')
  77. general.add_argument('--fp16', action='store_true',
  78. help='Run training in fp16/mixed precision')
  79. general.add_argument('--restart', type=str, default='',
  80. help='Restart training from the saved checkpoint')
  81. general.add_argument('--debug', action='store_true',
  82. help='Run in debug mode (do not create exp dir)')
  83. general.add_argument('--log_all_ranks', action='store_true',
  84. help='Enable logging from all distributed ranks')
  85. general.add_argument('--dllog_file', type=str, default='train_log.json',
  86. help='Name of the DLLogger output file')
  87. general.add_argument('--txtlog_file', type=str, default='train_log.log',
  88. help='Name of the txt log file')
  89. general.add_argument('--save_all', action='store_true',
  90. help='Save all checkpoints')
  91. general.add_argument('--no_env', action='store_true',
  92. help='Do not print info on execution env')
  93. general.add_argument('--no_eval', action='store_true',
  94. help='Disable model evaluation')
  95. general.add_argument('--log_interval', type=int, default=10,
  96. help='Report interval')
  97. general.add_argument('--target_throughput', type=float, default=None,
  98. help='Target training throughput (for benchmarking)')
  99. general.add_argument('--target_perplexity', type=float, default=None,
  100. help='Target validation perplexity (for benchmarking)')
  101. general.add_argument('--apex_amp_opt_level', type=str, default='O2',
  102. choices=['O0', 'O1', 'O2', 'O3'],
  103. help='Optimization level for apex amp')
  104. general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
  105. help='Implementation of automatic mixed precision')
  106. general.add_argument('--affinity', type=str,
  107. default='socket_unique_interleaved',
  108. choices=['socket', 'single', 'single_unique',
  109. 'socket_unique_interleaved',
  110. 'socket_unique_continuous',
  111. 'disabled'],
  112. help='type of CPU affinity')
  113. general.add_argument('--profile', action='store_true',
  114. help='Enable profiling with DLProf')
  115. dataset = parser.add_argument_group('dataset setup')
  116. dataset.add_argument('--data', type=str, default='../data/wikitext-103',
  117. help='Location of the data corpus')
  118. dataset.add_argument('--dataset', type=str, default='wt103',
  119. choices=['wt103', 'lm1b', 'enwik8', 'text8'],
  120. help='Dataset name')
  121. dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
  122. help='Type of vocabulary')
  123. model = parser.add_argument_group('model setup')
  124. model.add_argument('--n_layer', type=int, default=16,
  125. help='Number of total layers')
  126. model.add_argument('--n_head', type=int, default=8,
  127. help='Number of heads')
  128. model.add_argument('--d_head', type=int, default=64,
  129. help='Head dimension')
  130. model.add_argument('--d_embed', type=int, default=-1,
  131. help='Embedding dimension')
  132. model.add_argument('--d_model', type=int, default=512,
  133. help='Model dimension')
  134. model.add_argument('--d_inner', type=int, default=2048,
  135. help='Inner dimension in feedforward layer')
  136. model.add_argument('--dropout', type=float, default=0.1,
  137. help='Global dropout rate')
  138. model.add_argument('--dropatt', type=float, default=0.0,
  139. help='Attention probability dropout rate')
  140. model.add_argument('--pre_lnorm', action='store_true',
  141. help='Apply LayerNorm to the input instead of the output')
  142. model.add_argument('--attn_type', type=int, default=0,
  143. help='Attention type. 0 for ours, 1 for Shaw et al,'
  144. '2 for Vaswani et al, 3 for Al Rfou et al.')
  145. model.add_argument('--not_tied', action='store_true',
  146. help='Do not tie the word embedding and softmax weights')
  147. model.add_argument('--clamp_len', type=int, default=-1,
  148. help='Use the same pos embeddings after clamp_len')
  149. model.add_argument('--adaptive', action='store_true',
  150. help='Use adaptive softmax')
  151. model.add_argument('--div_val', type=int, default=1,
  152. help='Dividend value for adaptive input and softmax')
  153. model.add_argument('--sample_softmax', type=int, default=-1,
  154. help='Number of samples in sampled softmax')
  155. model.add_argument('--init', default='normal', type=str,
  156. help='Parameter initializer to use')
  157. model.add_argument('--emb_init', default='normal', type=str,
  158. help='Parameter initializer to use')
  159. model.add_argument('--init_range', type=float, default=0.1,
  160. help='Parameters initialized by U(-init_range, init_range)')
  161. model.add_argument('--emb_init_range', type=float, default=0.01,
  162. help='Parameters initialized by U(-init_range, init_range)')
  163. model.add_argument('--init_std', type=float, default=0.02,
  164. help='Parameters initialized by N(0, init_std)')
  165. model.add_argument('--proj_init_std', type=float, default=0.01,
  166. help='Parameters initialized by N(0, init_std)')
  167. opt = parser.add_argument_group('optimizer setup')
  168. opt.add_argument('--optim', default='jitlamb', type=str,
  169. choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
  170. help='Optimizer to use')
  171. opt.add_argument('--lr', type=float, default=0.01,
  172. help='Initial learning rate')
  173. opt.add_argument('--mom', type=float, default=0.0,
  174. help='Momentum for sgd')
  175. opt.add_argument('--scheduler', default='cosine', type=str,
  176. choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
  177. help='LR scheduler to use')
  178. opt.add_argument('--max_step_scheduler', type=int, default=None,
  179. help='Max number of training steps for LR scheduler')
  180. opt.add_argument('--warmup_step', type=int, default=1000,
  181. help='Number of iterations for LR warmup')
  182. opt.add_argument('--decay_rate', type=float, default=0.5,
  183. help='Decay factor when ReduceLROnPlateau is used')
  184. opt.add_argument('--lr_min', type=float, default=0.0,
  185. help='Minimum learning rate during annealing')
  186. opt.add_argument('--clip', type=float, default=0.25,
  187. help='Gradient clipping')
  188. opt.add_argument('--weight_decay', type=float, default=0.0,
  189. help='Weight decay for adam|lamb')
  190. opt.add_argument('--clip_nonemb', action='store_true',
  191. help='Only clip the gradient of non-embedding params')
  192. opt.add_argument('--patience', type=int, default=0,
  193. help='Patience')
  194. opt.add_argument('--eta_min', type=float, default=0.001,
  195. help='Min learning rate for cosine scheduler')
  196. training = parser.add_argument_group('training setup')
  197. training.add_argument('--max_step', type=int, default=40000,
  198. help='Max number of training steps')
  199. training.add_argument('--batch_size', type=int, default=256,
  200. help='Global batch size')
  201. training.add_argument('--local_batch_size', type=int, default=None,
  202. help='Local (per-device) batch size, this setting \
  203. overrides global --batch_size and sets batch_size \
  204. to local_batch_size * world_size')
  205. training.add_argument('--batch_chunk', type=int, default=1,
  206. help='Split batch into chunks and train with '
  207. 'gradient accumulation')
  208. training.add_argument('--roll', action='store_true',
  209. help='Enable random shifts within each data stream')
  210. training.add_argument('--tgt_len', type=int, default=192,
  211. help='Number of tokens to predict')
  212. training.add_argument('--ext_len', type=int, default=0,
  213. help='Length of the extended context')
  214. training.add_argument('--mem_len', type=int, default=192,
  215. help='Length of the retained previous heads')
  216. training.add_argument('--seed', type=int, default=1111,
  217. help='Random seed')
  218. training.add_argument('--multi_gpu', default=None, type=str,
  219. choices=['ddp', 'dp'],
  220. help='Use multiple GPU')
  221. training.add_argument('--gpu0_bsz', type=int, default=-1,
  222. help='Batch size on gpu 0 (for "dp" backend)')
  223. training.add_argument('--same_length', action='store_true',
  224. help='Use the same attn length for all tokens')
  225. training.add_argument('--varlen', action='store_true',
  226. help='Use variable length')
  227. training.add_argument('--swap_mem', action='store_true',
  228. help='Swap memory tensors to cpu')
  229. val = parser.add_argument_group('validation setup')
  230. val.add_argument('--eval_tgt_len', type=int, default=192,
  231. help='Number of tokens to predict for evaluation')
  232. val.add_argument('--eval_batch_size', type=int, default=16,
  233. help='Eval batch size')
  234. val.add_argument('--eval_max_steps', type=int, default=-1,
  235. help='Max eval steps')
  236. val.add_argument('--eval_interval', type=int, default=5000,
  237. help='Evaluation interval')
  238. dist = parser.add_argument_group('distributed setup')
  239. dist.add_argument('--local_rank', type=int,
  240. default=os.getenv('LOCAL_RANK', 0),
  241. help='Used for multi-process training.')
  242. parser.set_defaults(**config)
  243. args, _ = parser.parse_known_args()
  244. args.tied = not args.not_tied
  245. if args.d_embed < 0:
  246. args.d_embed = args.d_model
  247. if args.ext_len < 0:
  248. raise RuntimeError('Extended context length must be non-negative')
  249. if args.mem_len == 0:
  250. if args.eval_tgt_len > args.ext_len + args.tgt_len:
  251. raise RuntimeError('eval_tgt_len should be <= tgt_len + ext_len; '
  252. f'eval_tgt_len: {args.eval_tgt_len}, '
  253. f'tgt_len: {args.tgt_len}, '
  254. f'ext_len: {args.ext_len}')
  255. else:
  256. if args.eval_tgt_len > args.mem_len + args.tgt_len:
  257. raise RuntimeError('eval_tgt_len should be <= tgt_len + mem_len; '
  258. f'eval_tgt_len: {args.eval_tgt_len}, '
  259. f'tgt_len: {args.tgt_len}, '
  260. f'mem_len: {args.mem_len}')
  261. if args.batch_size % args.batch_chunk != 0:
  262. raise RuntimeError('Batch size needs to be divisible by batch chunk')
  263. if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
  264. raise RuntimeError(
  265. 'APEX AMP unavailable, install APEX or switch to pytorch AMP'
  266. )
  267. return args
  268. def save_checkpoint(args, model, model_config, optimizer, scheduler, scaler,
  269. vocab, epoch, batch, last_iter, train_step, best_val_loss,
  270. is_best, work_dir):
  271. if args.fp16:
  272. if args.amp == 'pytorch':
  273. amp_state = scaler.state_dict()
  274. elif args.amp == 'apex':
  275. amp_state = amp.state_dict()
  276. else:
  277. amp_state = None
  278. state = {
  279. 'args': args,
  280. 'model_config': model_config,
  281. 'model_state': model.state_dict(),
  282. 'optimizer_state': optimizer.state_dict(),
  283. 'scheduler_state': scheduler.state_dict(),
  284. 'vocab': vocab,
  285. 'amp_state': amp_state,
  286. 'epoch': epoch,
  287. 'batch': batch,
  288. 'last_iter': last_iter,
  289. 'train_step': train_step,
  290. 'best_val_loss': best_val_loss,
  291. }
  292. last_chkpt_fname = 'checkpoint_last.pt'
  293. with utils.distributed.sync_workers() as rank:
  294. last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
  295. if rank == 0:
  296. # always save last checkpoint
  297. logging.info(f'Saving checkpoint to {last_chkpt_path}')
  298. torch.save(state, last_chkpt_path)
  299. # save best checkpoint if better than previous best
  300. if is_best:
  301. best_chkpt_fname = 'checkpoint_best.pt'
  302. best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
  303. logging.info(f'Saving checkpoint to {best_chkpt_path}')
  304. shutil.copy(last_chkpt_path, best_chkpt_path)
  305. # save every checkpoint if save_all is true
  306. if args.save_all:
  307. step_chkpt_fname = f'checkpoint_{train_step}.pt'
  308. step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
  309. logging.info(f'Saving checkpoint to {step_chkpt_path}')
  310. shutil.copy(last_chkpt_path, step_chkpt_path)
  311. def load_checkpoint(path):
  312. if os.path.isdir(path):
  313. path = os.path.join(path, 'checkpoint_last.pt')
  314. dst = f'cuda:{torch.cuda.current_device()}'
  315. logging.info(f'Loading checkpoint from {path}')
  316. checkpoint = torch.load(path, map_location=dst)
  317. return checkpoint
  318. def init_weight(weight, args):
  319. if args.init == 'uniform':
  320. nn.init.uniform_(weight, -args.init_range, args.init_range)
  321. elif args.init == 'normal':
  322. nn.init.normal_(weight, 0.0, args.init_std)
  323. def init_bias(bias):
  324. nn.init.constant_(bias, 0.0)
  325. def weights_init(m, args):
  326. classname = m.__class__.__name__
  327. if classname.find('Linear') != -1:
  328. if hasattr(m, 'weight') and m.weight is not None:
  329. init_weight(m.weight, args)
  330. if hasattr(m, 'bias') and m.bias is not None:
  331. init_bias(m.bias)
  332. elif classname.find('AdaptiveEmbedding') != -1:
  333. if hasattr(m, 'emb_projs'):
  334. for i in range(len(m.emb_projs)):
  335. if m.emb_projs[i] is not None:
  336. nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
  337. elif classname.find('Embedding') != -1:
  338. if hasattr(m, 'weight'):
  339. init_weight(m.weight, args)
  340. elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
  341. if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
  342. init_weight(m.cluster_weight, args)
  343. if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
  344. init_bias(m.cluster_bias)
  345. if hasattr(m, 'out_projs'):
  346. for i in range(len(m.out_projs)):
  347. if m.out_projs[i] is not None:
  348. nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
  349. if hasattr(m, 'out_layers_weights'):
  350. for i in range(len(m.out_layers_weights)):
  351. if m.out_layers_weights[i] is not None:
  352. init_weight(m.out_layers_weights[i], args)
  353. elif classname.find('LayerNorm') != -1:
  354. if hasattr(m, 'weight'):
  355. nn.init.normal_(m.weight, 1.0, args.init_std)
  356. if hasattr(m, 'bias') and m.bias is not None:
  357. init_bias(m.bias)
  358. elif classname.find('TransformerLM') != -1:
  359. if hasattr(m, 'r_emb'):
  360. init_weight(m.r_emb, args)
  361. if hasattr(m, 'r_w_bias'):
  362. init_weight(m.r_w_bias, args)
  363. if hasattr(m, 'r_r_bias'):
  364. init_weight(m.r_r_bias, args)
  365. if hasattr(m, 'r_bias'):
  366. init_bias(m.r_bias)
  367. def update_dropout(m, args):
  368. classname = m.__class__.__name__
  369. if classname.find('Dropout') != -1:
  370. if hasattr(m, 'p'):
  371. m.p = args.dropout
  372. def update_dropatt(m, args):
  373. if hasattr(m, 'dropatt'):
  374. m.dropatt.p = args.dropatt
  375. def evaluate(eval_iter, model, args):
  376. # Turn on evaluation mode which disables dropout.
  377. model.eval()
  378. # If the model does not use memory at all, make the ext_len longer.
  379. # Otherwise, make the mem_len longer and keep the ext_len the same.
  380. if args.mem_len == 0:
  381. model.reset_length(tgt_len=args.eval_tgt_len,
  382. ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
  383. mem_len=args.mem_len
  384. )
  385. else:
  386. model.reset_length(tgt_len=args.eval_tgt_len,
  387. ext_len=args.ext_len,
  388. mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
  389. )
  390. # Evaluation
  391. total_len, total_loss = 0, 0.
  392. with torch.no_grad():
  393. mems = None
  394. for i, (data, target, seq_len, warm) in enumerate(eval_iter):
  395. if args.eval_max_steps > 0 and i >= args.eval_max_steps:
  396. break
  397. loss, mems = model(data, target, mems)
  398. loss = loss.float().mean()
  399. if warm:
  400. # assert (mems is None) or mems.size(1) == model.mem_len
  401. total_loss += seq_len * loss.item()
  402. total_len += seq_len
  403. # Switch back to the training mode
  404. model.reset_length(tgt_len=args.tgt_len,
  405. ext_len=args.ext_len,
  406. mem_len=args.mem_len
  407. )
  408. model.train()
  409. return total_loss / total_len
  410. def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
  411. optimizer, device, delay_unscale, args):
  412. cpu = torch.device('cpu')
  413. data_i = data_chunks[i].contiguous()
  414. target_i = target_chunks[i].contiguous()
  415. if args.swap_mem and mems[i] is not None:
  416. mems[i] = mems[i].to(device, non_blocking=True)
  417. enable_autocast = args.fp16 and args.amp == 'pytorch'
  418. with torch.cuda.amp.autocast(enable_autocast):
  419. loss, mems[i] = model(data_i, target_i, mems[i])
  420. loss = loss.float().mean().type_as(loss) / args.batch_chunk
  421. if args.swap_mem and mems[i] is not None:
  422. mems[i] = mems[i].to(cpu, non_blocking=True)
  423. if args.fp16:
  424. if args.amp == 'pytorch':
  425. scaler.scale(loss).backward()
  426. elif args.amp == 'apex':
  427. with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:
  428. scaled_loss.backward()
  429. else:
  430. loss.backward()
  431. train_loss = loss.float().item()
  432. return train_loss
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, device, args):
    """Run one pass over the training iterator (one epoch, or until max_step).

    Each batch is split into ``args.batch_chunk`` micro-batches along dim 1
    with gradient accumulation. Gradients are clipped and the optimizer(s)
    stepped, then the step-wise LR schedule is applied. The function logs
    every ``args.log_interval`` steps. It evaluates and checkpoints every
    ``args.eval_interval`` steps, at ``args.max_step``, and when the timeout
    handler reports an interruption.

    Args:
        tr_iter / va_iter: training / validation iterators (project types).
        model: the underlying model; para_model: its parallel wrapper used
            for the forward/backward (may be DistributedDataParallel).
        last_batch, last_iter: resume position in the data stream.
        train_step: global optimizer step counter at entry.
        best_val_loss: best validation loss so far (falsy if none yet).
        meters: dict of meters; ``meters['train_throughput']`` is updated.

    Returns:
        ``(train_step, best_val_loss)`` after this pass.

    Side effects: calls ``sys.exit(0)`` after an interruption (SIGTERM).
    """
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    # One memory slot per micro-batch; train_iteration updates mems[i] in place.
    mems = [None for _ in range(args.batch_chunk)]
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()

        # Setting grads to None (rather than zeroing) frees gradient memory.
        for param in model.parameters():
            param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            # Under DDP, skip the gradient all-reduce for all but the last
            # micro-batch; the final backward synchronizes accumulated grads.
            if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
                with para_model.no_sync():
                    train_loss_chunk = train_iteration(
                        para_model, i, mems, data_chunks, target_chunks, scaler,
                        optimizer, device, True, args
                    )
            else:
                train_loss_chunk = train_iteration(
                    para_model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, False, args
                )

            train_loss += train_loss_chunk

        # Gradient clipping. PyTorch AMP must unscale first; apex keeps
        # separate master params that are clipped instead.
        if args.fp16:
            if args.amp == 'pytorch':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            elif args.amp == 'apex':
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        if args.fp16 and args.amp == 'pytorch':
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
            # NOTE(review): with this (reconstructed) nesting the sparse
            # optimizer is not stepped on the PyTorch-AMP path — confirm.
            if optimizer_sparse:
                optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                # Only cosine steps per-iteration here; 'constant' keeps the
                # LR and 'dev_perf' is stepped after validation below.
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            # The all_reduce_item calls are collectives, so every rank must
            # reach them in the same order.
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                )

            # NOTE(review): 'train_batch' logs batch+1 while log_str shows
            # batch (already counted from last_batch+1) — possible off-by-one.
            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
            }

            # Character-level datasets report bits per character, word-level
            # datasets report perplexity.
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        # Checkpoints are only written from inside this branch, so with
        # --no_eval nothing is saved here.
        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                '| valid loss {:5.2f}'.format(
                    train_step // args.eval_interval,
                    train_step,
                    (time.time() - eval_start_time),
                    val_loss,
                )

            dllogger_data = {
                'valid_elapsed': (time.time() - eval_start_time),
                'valid_loss': val_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            # Record the data-stream position so a restart resumes here.
            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer, scheduler,
                                scaler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        if interrupted:
            logging.info(f'Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break
    return train_step, best_val_loss
  594. def main():
  595. args = parse_args()
  596. if args.affinity != 'disabled':
  597. nproc_per_node = torch.cuda.device_count()
  598. affinity = utils.gpu_affinity.set_affinity(
  599. args.local_rank,
  600. nproc_per_node,
  601. args.affinity
  602. )
  603. print(f'{args.local_rank}: thread affinity: {affinity}')
  604. # Initialize device and distributed backend
  605. torch.cuda.set_device(args.local_rank)
  606. l2_promote()
  607. device = torch.device('cuda' if args.cuda else 'cpu')
  608. utils.distributed.init_distributed(args.cuda)
  609. args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
  610. args.dataset,
  611. args.append_dataset,
  612. args.append_time,
  613. )
  614. with utils.distributed.sync_workers() as rank:
  615. if rank == 0:
  616. create_exp_dir(args.work_dir,
  617. scripts_to_save=['train.py', 'mem_transformer.py'],
  618. debug=args.debug)
  619. # Setup logging
  620. if args.log_all_ranks:
  621. log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
  622. else:
  623. log_file = args.txtlog_file
  624. dllog_file = args.dllog_file
  625. log_file = os.path.join(args.work_dir, log_file)
  626. dllog_file = os.path.join(args.work_dir, dllog_file)
  627. if args.debug:
  628. log_file = os.devnull
  629. dllog_file = os.devnull
  630. utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
  631. filename=log_file,
  632. )
  633. utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
  634. if args.local_batch_size is not None:
  635. world_size = utils.distributed.get_world_size()
  636. args.batch_size = world_size * args.local_batch_size
  637. logging.info(f'--local_batch_size was set, adjusting global batch size'
  638. f' to {args.batch_size} (local_batch_size * world_size)')
  639. if args.profile:
  640. try:
  641. pyprof.init(enable_function_stack=True)
  642. except NameError:
  643. warnings.warn('Called pyprof.init() but pyprof is not available')
  644. logging.info(args)
  645. dllogger.log(step='PARAMETER', data=vars(args))
  646. logging.info(f'world size: {utils.distributed.get_world_size()}')
  647. if not args.no_env:
  648. log_env_info()
  649. register_ignoring_timeout_handler()
  650. # Set the random seed manually for reproducibility.
  651. np.random.seed(args.seed)
  652. torch.manual_seed(args.seed)
  653. ###########################################################################
  654. # Load data
  655. ###########################################################################
  656. corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
  657. ntokens = len(corpus.vocab)
  658. vocab = corpus.vocab
  659. args.n_token = ntokens
  660. if args.mem_len == 0:
  661. eval_mem_len = 0
  662. else:
  663. eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len
  664. tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
  665. device=device, ext_len=args.ext_len)
  666. va_iter = corpus.get_iterator('valid', args.eval_batch_size,
  667. args.eval_tgt_len, device=device,
  668. mem_len=eval_mem_len, ext_len=args.ext_len)
  669. te_iter = corpus.get_iterator('test', args.eval_batch_size,
  670. args.eval_tgt_len, device=device,
  671. mem_len=eval_mem_len, ext_len=args.ext_len)
  672. # adaptive softmax / embedding
  673. cutoffs, tie_projs = [], [False]
  674. if args.adaptive:
  675. assert args.dataset in ['wt103', 'lm1b']
  676. if args.dataset == 'wt103':
  677. cutoffs = [19997, 39997, 199997]
  678. tie_projs += [True] * len(cutoffs)
  679. elif args.dataset == 'lm1b':
  680. cutoffs = [59997, 99997, 639997]
  681. tie_projs += [False] * len(cutoffs)
  682. ###########################################################################
  683. # Build the model
  684. ###########################################################################
  685. model_config = {
  686. 'n_token': ntokens,
  687. 'n_layer': args.n_layer,
  688. 'n_head': args.n_head,
  689. 'd_model': args.d_model,
  690. 'd_head': args.d_head,
  691. 'd_inner': args.d_inner,
  692. 'dropout': args.dropout,
  693. 'dropatt': args.dropatt,
  694. 'dtype': None,
  695. 'tie_weight': args.tied,
  696. 'd_embed': args.d_embed,
  697. 'div_val': args.div_val,
  698. 'tie_projs': tie_projs,
  699. 'pre_lnorm': args.pre_lnorm,
  700. 'tgt_len': args.tgt_len,
  701. 'ext_len': args.ext_len,
  702. 'mem_len': args.mem_len,
  703. 'cutoffs': cutoffs,
  704. 'same_length': args.same_length,
  705. 'attn_type': args.attn_type,
  706. 'clamp_len': args.clamp_len,
  707. 'sample_softmax': args.sample_softmax,
  708. }
  709. model = MemTransformerLM(**model_config)
  710. model.apply(functools.partial(weights_init, args=args))
  711. # ensure embedding init is not overridden by out_layer in case of weight sharing
  712. model.word_emb.apply(functools.partial(weights_init, args=args))
  713. args.n_all_param = sum([p.nelement() for p in model.parameters()])
  714. args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
  715. # optimizer
  716. if args.optim.lower() == 'sgd':
  717. if args.sample_softmax > 0:
  718. dense_params, sparse_params = [], []
  719. for param in model.parameters():
  720. if param.size() == model.word_emb.weight.size():
  721. sparse_params.append(param)
  722. else:
  723. dense_params.append(param)
  724. optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
  725. optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
  726. else:
  727. optimizer = optim.SGD(model.parameters(), lr=args.lr,
  728. momentum=args.mom)
  729. optimizer_sparse = None
  730. elif args.optim.lower() == 'adam':
  731. if args.sample_softmax > 0:
  732. dense_params, sparse_params = [], []
  733. for param in model.parameters():
  734. if param.size() == model.word_emb.weight.size():
  735. sparse_params.append(param)
  736. else:
  737. dense_params.append(param)
  738. optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
  739. optimizer = optim.Adam(dense_params, lr=args.lr,
  740. weight_decay=args.weight_decay)
  741. else:
  742. optimizer = optim.Adam(model.parameters(), lr=args.lr,
  743. weight_decay=args.weight_decay)
  744. optimizer_sparse = None
  745. elif args.optim.lower() == 'adagrad':
  746. optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
  747. optimizer_sparse = None
  748. elif args.optim.lower() == 'lamb':
  749. optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
  750. weight_decay=args.weight_decay)
  751. optimizer_sparse = None
  752. elif args.optim.lower() == 'jitlamb':
  753. optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
  754. weight_decay=args.weight_decay)
  755. optimizer_sparse = None
  756. model = model.to(device)
  757. scaler = None
  758. if args.fp16:
  759. if args.amp == 'pytorch':
  760. scaler = torch.cuda.amp.GradScaler()
  761. elif args.amp == 'apex':
  762. model, optimizer = amp.initialize(
  763. model,
  764. optimizer,
  765. opt_level=args.apex_amp_opt_level,
  766. )
  767. if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
  768. para_model = DistributedDataParallel(model,
  769. device_ids=[args.local_rank],
  770. output_device=args.local_rank,
  771. broadcast_buffers=False,
  772. find_unused_parameters=True,
  773. )
  774. elif args.multi_gpu == 'dp':
  775. if args.gpu0_bsz >= 0:
  776. para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
  777. model, dim=1).to(device)
  778. else:
  779. para_model = nn.DataParallel(model, dim=1).to(device)
  780. else:
  781. para_model = model
  782. # scheduler
  783. if args.scheduler == 'cosine':
  784. if args.max_step_scheduler:
  785. max_step = args.max_step_scheduler
  786. else:
  787. max_step = args.max_step
  788. scheduler = optim.lr_scheduler.CosineAnnealingLR(
  789. optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
  790. if args.sample_softmax > 0 and optimizer_sparse is not None:
  791. scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
  792. optimizer_sparse, max_step - args.warmup_step,
  793. eta_min=args.eta_min)
  794. else:
  795. scheduler_sparse = None
  796. elif args.scheduler == 'inv_sqrt':
  797. # originally used for Transformer (in Attention is all you need)
  798. def lr_lambda(step):
  799. # return a multiplier instead of a learning rate
  800. if step == 0 and args.warmup_step == 0:
  801. return 1.
  802. else:
  803. return 1. / (step ** 0.5) if step > args.warmup_step \
  804. else step / (args.warmup_step ** 1.5)
  805. scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
  806. if args.sample_softmax > 0 and optimizer_sparse is not None:
  807. scheduler_sparse = optim.lr_scheduler.LambdaLR(
  808. optimizer_sparse,
  809. lr_lambda=lr_lambda
  810. )
  811. else:
  812. scheduler_sparse = None
  813. elif args.scheduler == 'dev_perf':
  814. scheduler = optim.lr_scheduler.ReduceLROnPlateau(
  815. optimizer, factor=args.decay_rate, patience=args.patience,
  816. min_lr=args.lr_min,
  817. )
  818. if args.sample_softmax > 0 and optimizer_sparse is not None:
  819. scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
  820. optimizer_sparse, factor=args.decay_rate, patience=args.patience,
  821. min_lr=args.lr_min,
  822. )
  823. else:
  824. scheduler_sparse = None
  825. elif args.scheduler == 'constant':
  826. pass
  827. logging.info('=' * 100)
  828. for k, v in args.__dict__.items():
  829. logging.info(' - {} : {}'.format(k, v))
  830. logging.info('=' * 100)
  831. logging.info('#params = {}'.format(args.n_all_param))
  832. logging.info('#non emb params = {}'.format(args.n_nonemb_param))
  833. train_step = 0
  834. start_epoch = 1
  835. last_batch = 0
  836. last_iter = 0
  837. best_val_loss = None
  838. if args.restart:
  839. try:
  840. checkpoint = load_checkpoint(args.restart)
  841. model.load_state_dict(checkpoint['model_state'])
  842. optimizer.load_state_dict(checkpoint['optimizer_state'])
  843. scheduler.load_state_dict(checkpoint['scheduler_state'])
  844. if args.fp16:
  845. if args.amp == 'pytorch':
  846. scaler.load_state_dict(checkpoint['amp_state'])
  847. elif args.amp == 'apex':
  848. amp.load_state_dict(checkpoint['amp_state'])
  849. train_step = checkpoint['train_step']
  850. start_epoch = checkpoint['epoch']
  851. last_batch = checkpoint['batch']
  852. last_iter = checkpoint['last_iter']
  853. best_val_loss = checkpoint['best_val_loss']
  854. if train_step >= args.max_step:
  855. logging.info(f'Loaded checkpoint after {train_step} steps, but '
  856. f'this run was scheduled for a total of '
  857. f'{args.max_step} steps, exiting')
  858. sys.exit(1)
  859. model.apply(functools.partial(update_dropout, args=args))
  860. model.apply(functools.partial(update_dropatt, args=args))
  861. except FileNotFoundError:
  862. logging.info(f'Could not load checkpoint from {args.restart}, '
  863. f'starting training from random init')
  864. meters = {}
  865. warmup = args.mem_len // args.tgt_len + 2
  866. meters['train_throughput'] = AverageMeter(warmup=warmup)
  867. ###########################################################################
  868. # Train
  869. ###########################################################################
  870. # Loop over epochs.
  871. # At any point you can hit Ctrl + C to break out of training early.
  872. start_time = time.time()
  873. with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
  874. with TimeoutHandler() as timeout_handler:
  875. try:
  876. for epoch in itertools.count(start=start_epoch):
  877. if args.roll:
  878. tr_iter.roll(seed=args.seed + epoch)
  879. train_step, best_val_loss = train(
  880. tr_iter, va_iter, model, para_model, model_config,
  881. optimizer, optimizer_sparse, scheduler,
  882. scheduler_sparse, scaler, vocab, epoch, last_batch,
  883. last_iter, train_step, best_val_loss, meters,
  884. timeout_handler, device, args
  885. )
  886. last_batch = 0
  887. last_iter = 0
  888. if train_step == args.max_step:
  889. logging.info('-' * 100)
  890. logging.info('End of training')
  891. break
  892. except KeyboardInterrupt:
  893. logging.info('-' * 100)
  894. logging.info('Exiting from training early')
  895. elapsed = time.time() - start_time
  896. ###########################################################################
  897. # Test
  898. ###########################################################################
  899. summary = {}
  900. test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
  901. if not args.debug and not args.no_eval and os.path.exists(test_path):
  902. # Load the best saved model.
  903. checkpoint = load_checkpoint(test_path)
  904. model.load_state_dict(checkpoint['model_state'])
  905. # Run on test data.
  906. test_start_time = time.time()
  907. with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
  908. test_loss = evaluate(te_iter, model, args)
  909. test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
  910. test_elapsed = time.time() - test_start_time
  911. logging.info('=' * 100)
  912. if args.dataset in ['enwik8', 'text8']:
  913. logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
  914. test_elapsed, test_loss, test_loss / math.log(2)))
  915. else:
  916. logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
  917. test_elapsed, test_loss, math.exp(test_loss)))
  918. logging.info('=' * 100)
  919. summary.update({
  920. 'test_elapsed': test_elapsed,
  921. 'test_loss': test_loss,
  922. })
  923. if args.dataset in ['enwik8', 'text8']:
  924. summary['test_bits_per_character'] = test_loss / math.log(2)
  925. else:
  926. summary['test_perplexity'] = math.exp(test_loss)
  927. logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
  928. logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')
  929. if best_val_loss:
  930. val_perplexity = math.exp(best_val_loss)
  931. else:
  932. val_perplexity = None
  933. summary.update({
  934. 'train_throughput': meters['train_throughput'].avg,
  935. 'train_elapsed': elapsed / 60,
  936. 'valid_loss': best_val_loss,
  937. 'valid_perplexity': val_perplexity,
  938. })
  939. dllogger.log(step=tuple(), data=summary)
  940. passed = benchmark(
  941. target_perplexity=args.target_perplexity,
  942. test_perplexity=val_perplexity,
  943. target_throughput=args.target_throughput,
  944. test_throughput=meters['train_throughput'].avg
  945. )
  946. if not passed:
  947. sys.exit(1)
  948. if __name__ == "__main__":
  949. # Disable profiling executor
  950. try:
  951. torch._C._jit_set_profiling_executor(False)
  952. torch._C._jit_set_profiling_mode(False)
  953. except AttributeError:
  954. pass
  955. # Before we do anything with models, we want to ensure that we get fp16
  956. # execution of torch.einsum in APEX AMP.
  957. # Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
  958. # Note that running `--apex_amp_opt_level O2` will remove the need for this
  959. # code, but it is still valid.
  960. if 'apex' in sys.modules:
  961. amp.register_half_function(torch, 'einsum')
  962. main()