# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import json
  16. import logging
  17. import math
  18. import os
  19. import pickle
  20. import sys
  21. import time
  22. import warnings
  23. import dllogger
  24. import numpy as np
  25. import torch
  26. import yaml
  27. try:
  28. import pyprof
  29. except ModuleNotFoundError:
  30. warnings.warn('PyProf is unavailable')
  31. import data_utils
  32. import utils
  33. from data_utils import get_lm_corpus
  34. from data_utils import tokenize_raw
  35. from utils.exp_utils import AverageMeter
  36. from utils.exp_utils import benchmark
  37. from utils.exp_utils import create_exp_dir
  38. from utils.exp_utils import l2_promote
  39. from utils.exp_utils import log_env_info
  40. def parse_args():
  41. parent_parser = argparse.ArgumentParser(
  42. description='PyTorch Transformer-XL Language Model',
  43. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  44. add_help=False,
  45. )
  46. parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
  47. cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
  48. cfg_parser.add_argument('--config', default='default')
  49. cfg_parser.add_argument('--config_file', default=None)
  50. config_args, _ = cfg_parser.parse_known_args()
  51. if config_args.config is not None and config_args.config_file is not None:
  52. with open(config_args.config_file) as f:
  53. config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['eval']
  54. else:
  55. config = {}
  56. parser.add_argument('--work_dir', default='LM-TFM', type=str,
  57. help='experiment directory')
  58. parser.add_argument('--debug', action='store_true',
  59. help='run in debug mode (do not create exp dir)')
  60. parser.add_argument('--data', type=str, default='../data/wikitext-103',
  61. help='location of the data corpus')
  62. parser.add_argument('--manual', type=str, default=None, nargs='+',
  63. help='run model on raw input data')
  64. parser.add_argument('--dataset', type=str, default='wt103',
  65. choices=['wt103', 'lm1b', 'enwik8', 'text8'],
  66. help='dataset name')
  67. parser.add_argument('--split', type=str, default='all',
  68. choices=['all', 'valid', 'test'],
  69. help='which split to evaluate')
  70. parser.add_argument('--affinity', type=str,
  71. default='single_unique',
  72. choices=['socket', 'single', 'single_unique',
  73. 'socket_unique_interleaved',
  74. 'socket_unique_continuous',
  75. 'disabled'],
  76. help='type of CPU affinity')
  77. parser.add_argument('--profile', action='store_true',
  78. help='Enable profiling with DLProf')
  79. parser.add_argument('--type', type=str, default='pytorch',
  80. choices=['pytorch', 'torchscript'],
  81. help='type of runtime to use')
  82. parser.add_argument('--batch_size', type=int, default=16,
  83. help='batch size')
  84. parser.add_argument('--tgt_len', type=int, default=64,
  85. help='number of tokens to predict')
  86. parser.add_argument('--ext_len', type=int, default=0,
  87. help='length of the extended context')
  88. parser.add_argument('--mem_len', type=int, default=640,
  89. help='length of the retained previous heads')
  90. parser.add_argument('--seed', type=int, default=1111,
  91. help='Random seed')
  92. parser.add_argument('--clamp_len', type=int, default=-1,
  93. help='max positional embedding index')
  94. parser.add_argument('--cuda', action='store_true',
  95. help='Run evaluation on a GPU using CUDA')
  96. parser.add_argument('--model', type=str, default='',
  97. help='path to the checkpoint')
  98. parser.add_argument('--manual_config', type=json.loads, default=None,
  99. help='Manually specify config for the model')
  100. parser.add_argument('--manual_vocab', type=str, default='word',
  101. choices=['word', 'bpe'],
  102. help='Manually specify type of vocabulary')
  103. parser.add_argument('--fp16', action='store_true',
  104. help='Run training in fp16/mixed precision')
  105. parser.add_argument('--log_all_ranks', action='store_true',
  106. help='Enable logging for all distributed ranks')
  107. parser.add_argument('--dllog_file', type=str, default='eval_log.json',
  108. help='Name of the DLLogger output file')
  109. parser.add_argument('--same_length', action='store_true',
  110. help='set same length attention with masking')
  111. parser.add_argument('--no_env', action='store_true',
  112. help='Do not print info on execution env')
  113. parser.add_argument('--log_interval', type=int, default=10,
  114. help='Report interval')
  115. parser.add_argument('--target_perplexity', type=float, default=None,
  116. help='target perplexity')
  117. parser.add_argument('--target_throughput', type=float, default=None,
  118. help='target throughput')
  119. parser.add_argument('--save_data', action='store_true',
  120. help='save latency and throughput data to a file')
  121. parser.add_argument('--repeat', type=int, default=1,
  122. help='loop over the dataset REPEAT times')
  123. parser.add_argument('--max_size', type=int, default=None,
  124. help='run inference on up to MAX_SIZE batches')
  125. parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
  126. help='percentiles for latency confidence intervals')
  127. parser.add_argument('--save_torchscript', default=None, type=str,
  128. help='save torchscript model to a file')
  129. parser.add_argument('--load_torchscript', default=None, type=str,
  130. help='load torchscript model from a file')
  131. parser.add_argument('--local_rank', type=int,
  132. default=os.getenv('LOCAL_RANK', 0),
  133. help='Used for multi-process training.')
  134. parser.set_defaults(**config)
  135. args, _ = parser.parse_known_args()
  136. if args.manual:
  137. args.batch_size = 1
  138. if args.same_length and args.tgt_len > args.mem_len:
  139. warnings.warn('--same_length is intended to be used with large '
  140. 'mem_len relative to tgt_len')
  141. if args.ext_len < 0:
  142. raise RuntimeError('Extended context length must be non-negative')
  143. return args
  144. def load_checkpoint(path):
  145. dst = f'cuda:{torch.cuda.current_device()}'
  146. logging.info(f'Loading checkpoint from {path}')
  147. checkpoint = torch.load(path, map_location=dst)
  148. return checkpoint
  149. def format_log(loss, split, args):
  150. if args.dataset in ['enwik8', 'text8']:
  151. log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
  152. split, loss, loss / math.log(2))
  153. else:
  154. log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
  155. split, loss, math.exp(loss))
  156. return log_str
  157. def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
  158. total_len, total_loss = 0, 0.
  159. eval_step = 0
  160. log_throughput = 0
  161. log_latency = 0
  162. log_loss = 0
  163. torch.cuda.synchronize()
  164. start_time = time.time()
  165. with torch.no_grad():
  166. mems = None
  167. for _ in range(repeat):
  168. for idx, (data, target, seq_len, warm) in enumerate(eval_iter):
  169. if max_size and idx >= max_size:
  170. break
  171. eval_step += 1
  172. torch.cuda.synchronize()
  173. start_iter = time.time()
  174. loss, mems = model(data, target, mems)
  175. torch.cuda.synchronize()
  176. elapsed = time.time() - start_iter
  177. loss = loss.float().mean()
  178. log_loss += loss.item()
  179. if warm:
  180. total_loss += seq_len * loss.item()
  181. total_len += seq_len
  182. meters['eval_latency'].update(elapsed)
  183. log_latency += elapsed
  184. target_tokens = target.numel()
  185. throughput = target_tokens / elapsed
  186. throughput = utils.distributed.all_reduce_item(throughput, op='sum')
  187. meters['eval_throughput'].update(throughput)
  188. log_throughput += throughput
  189. if eval_step % log_interval == 0:
  190. log_throughput /= log_interval
  191. log_latency /= log_interval
  192. log_loss /= log_interval
  193. log_ppl = math.exp(log_loss)
  194. log_str = '| step {:>8d} | batches {:>6d} / {:d} ' \
  195. '| ms/batch {:5.2f} | tok/s {:7.0f} | loss {:5.2f} | ppl {:5.2f}'.format(
  196. eval_step,
  197. idx+1,
  198. eval_iter.n_batch,
  199. log_latency * 1000,
  200. log_throughput,
  201. log_loss,
  202. log_ppl,
  203. )
  204. logging.info(log_str)
  205. dllogger_data = {
  206. 'eval_latency': log_latency * 1000,
  207. 'eval_throughput': log_throughput,
  208. 'eval_loss': log_loss,
  209. 'eval_perplexity': log_ppl,
  210. }
  211. dllogger.log(step=tuple([eval_step]), data=dllogger_data)
  212. log_throughput = 0
  213. log_latency = 0
  214. log_loss = 0
  215. utils.distributed.barrier()
  216. torch.cuda.synchronize()
  217. total_time = time.time() - start_time
  218. logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
  219. total_time, 1000 * total_time / (idx+1)))
  220. avg_loss = total_loss / total_len
  221. avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
  222. return avg_loss
  223. def compile_model(model, device, args):
  224. inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  225. tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  226. start = time.time()
  227. with torch.no_grad():
  228. mems = None
  229. for _ in range(2):
  230. _, mems = model(inp, tgt, mems)
  231. torch.cuda.synchronize()
  232. stop = time.time()
  233. logging.info(f'Building the model took {stop - start:.2f} seconds')
  234. def main():
  235. args = parse_args()
  236. if args.affinity != 'disabled':
  237. nproc_per_node = torch.cuda.device_count()
  238. affinity = utils.gpu_affinity.set_affinity(
  239. args.local_rank,
  240. nproc_per_node,
  241. args.affinity
  242. )
  243. print(f'{args.local_rank}: thread affinity: {affinity}')
  244. if args.type == 'pytorch':
  245. from mem_transformer import MemTransformerLM
  246. else:
  247. from inference.mem_transformer_jit import MemTransformerLM
  248. torch.cuda.set_device(args.local_rank)
  249. l2_promote()
  250. device = torch.device('cuda' if args.cuda else 'cpu')
  251. utils.distributed.init_distributed(args.cuda)
  252. with utils.distributed.sync_workers() as rank:
  253. if rank == 0:
  254. create_exp_dir(args.work_dir, debug=args.debug)
  255. # Setup logging
  256. if args.log_all_ranks:
  257. log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
  258. else:
  259. log_file = f'eval_log.log'
  260. dllog_file = args.dllog_file
  261. log_file = os.path.join(args.work_dir, log_file)
  262. dllog_file = os.path.join(args.work_dir, dllog_file)
  263. if args.debug:
  264. log_file = os.devnull
  265. dllog_file = os.devnull
  266. utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
  267. filename=log_file,
  268. filemode='a',
  269. )
  270. utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
  271. if args.profile:
  272. try:
  273. pyprof.init(enable_function_stack=True)
  274. except NameError:
  275. warnings.warn('Called pyprof.init() but pyprof is not available')
  276. logging.info(args)
  277. dllogger.log(step='PARAMETER', data=vars(args))
  278. if not args.no_env:
  279. log_env_info()
  280. # Set the random seed manually for reproducibility.
  281. np.random.seed(args.seed)
  282. torch.manual_seed(args.seed)
  283. if args.model:
  284. model_path = args.model
  285. elif args.work_dir:
  286. model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
  287. else:
  288. raise RuntimeError('Specify path to checkpoint using --model or --work_dir')
  289. if not args.manual_config:
  290. checkpoint = load_checkpoint(model_path)
  291. vocab_type = checkpoint['args'].vocab
  292. else:
  293. checkpoint = None
  294. vocab_type = args.manual_vocab
  295. if args.manual:
  296. vocab = checkpoint['vocab']
  297. if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
  298. vocab.unk_idx = vocab.sym2idx['<unk>']
  299. text = " ".join(args.manual)
  300. tokenized = tokenize_raw(text)
  301. symbols = vocab.tokenize(tokenized, add_eos=True)
  302. tensor = vocab.convert_to_tensor(symbols)
  303. iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
  304. bptt=args.tgt_len, device=device,
  305. ext_len=args.ext_len, warmup=False)
  306. else:
  307. # Load dataset
  308. corpus = get_lm_corpus(args.data, args.dataset, vocab_type)
  309. if args.split == 'valid' or args.split == 'test':
  310. iter = corpus.get_iterator(args.split, args.batch_size, args.tgt_len,
  311. device=device, mem_len=args.mem_len,
  312. ext_len=args.ext_len)
  313. else:
  314. raise RuntimeError('Unknown split')
  315. if args.fp16:
  316. dtype = torch.float16
  317. math_str = 'fp16'
  318. else:
  319. dtype = torch.float32
  320. math_str = 'fp32'
  321. if args.load_torchscript:
  322. model = torch.jit.load(args.load_torchscript)
  323. elif not args.manual_config:
  324. checkpoint['model_config']['tgt_len'] = args.tgt_len
  325. checkpoint['model_config']['ext_len'] = args.ext_len
  326. checkpoint['model_config']['mem_len'] = args.mem_len
  327. checkpoint['model_config']['clamp_len'] = args.clamp_len
  328. checkpoint['model_config']['same_length'] = args.same_length
  329. checkpoint['model_config']['dtype'] = dtype
  330. model = MemTransformerLM(**checkpoint['model_config'])
  331. if args.type == 'pytorch':
  332. model.load_state_dict(checkpoint['model_state'])
  333. elif args.type == 'torchscript':
  334. model.load_state_dict(checkpoint['model_state'], strict=False)
  335. elif args.manual_config:
  336. args.manual_config['tgt_len'] = args.tgt_len
  337. args.manual_config['ext_len'] = args.ext_len
  338. args.manual_config['mem_len'] = args.mem_len
  339. args.manual_config['clamp_len'] = args.clamp_len
  340. args.manual_config['same_length'] = args.same_length
  341. args.manual_config['dtype'] = dtype
  342. model = MemTransformerLM(**args.manual_config)
  343. model = model.eval()
  344. model = model.to(device)
  345. model = model.to(dtype)
  346. if args.type == 'torchscript' and not args.manual_config:
  347. state = checkpoint['model_state']
  348. tie_projs = checkpoint['model_config']['tie_projs']
  349. tie_weight = checkpoint['model_config']['tie_weight']
  350. div_val = checkpoint['model_config']['div_val']
  351. d_model = checkpoint['model_config']['d_model']
  352. d_embed = checkpoint['model_config']['d_embed']
  353. if div_val != 1 or d_model != d_embed:
  354. for i in range(len(model.word_emb.emb_projs)):
  355. model.word_emb.emb_projs[i] = state[f'word_emb.emb_projs.{i}'].to(dtype)
  356. for i in range(len(model.crit.out_projs)):
  357. if div_val == 1:
  358. src = 0
  359. else:
  360. src = i
  361. if model.crit.out_projs[i] is not None:
  362. if tie_projs[i]:
  363. model.crit.out_projs[i] = state[f'word_emb.emb_projs.{src}'].to(dtype)
  364. else:
  365. model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(dtype)
  366. for i in range(len(model.crit.out_layers_biases)):
  367. model.crit.out_layers_biases[i] = state[f'crit.out_layers_biases.{i}'].to(dtype)
  368. if tie_weight:
  369. for i in range(len(model.crit.out_layers_weights)):
  370. model.crit.out_layers_weights[i] = state[f'word_emb.emb_layers.{i}.weight'].to(dtype)
  371. else:
  372. for i in range(len(model.crit.out_layers_weights)):
  373. model.crit.out_layers_weights[i] = state[f'crit.out_layers_weights.{i}'].to(dtype)
  374. model = torch.jit.script(model)
  375. if args.type != 'pytorch':
  376. compile_model(model, device, args)
  377. if args.type == 'torchscript' and args.save_torchscript:
  378. torch.jit.save(model, args.save_torchscript)
  379. logging.info(f'Evaluating with: math {math_str} type {args.type} '
  380. f'bsz {args.batch_size} tgt_len {args.tgt_len} '
  381. f'ext_len {args.ext_len} mem_len {args.mem_len} '
  382. f'clamp_len {args.clamp_len}')
  383. meters = {}
  384. warmup = args.mem_len // args.tgt_len + 2
  385. meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
  386. meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
  387. with torch.autograd.profiler.emit_nvtx(enabled=args.profile):
  388. loss = evaluate(iter, model, meters, args.log_interval, args.max_size,
  389. args.repeat)
  390. perplexity = math.exp(loss)
  391. log_str = format_log(loss, args.split, args)
  392. summary = {
  393. 'eval_loss': loss,
  394. 'eval_ppl': perplexity,
  395. }
  396. logging.info('=' * 100)
  397. logging.info(log_str)
  398. logging.info('=' * 100)
  399. if args.save_data:
  400. latency_data = np.array(meters['eval_latency'].vals)
  401. throughput_data = np.array(meters['eval_throughput'].vals)
  402. precision = 'fp16' if args.fp16 else 'fp32'
  403. data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
  404. data_path = os.path.join(args.work_dir, data_fname)
  405. data = {
  406. 'args': args,
  407. 'throughput': throughput_data,
  408. 'latency': latency_data,
  409. }
  410. with open(data_path, 'wb') as f:
  411. pickle.dump(data, f)
  412. logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
  413. logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
  414. for p in args.percentiles:
  415. logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')
  416. logging.info('=' * 100)
  417. summary.update({
  418. 'eval_throughput': throughput_data.mean(),
  419. 'eval_avg_latency': 1000 * latency_data.mean(),
  420. })
  421. for p in args.percentiles:
  422. summary[f'eval_{p}%_latency'] = 1000 * np.percentile(latency_data, p)
  423. dllogger.log(step=tuple(), data=summary)
  424. passed = benchmark(target_perplexity=args.target_perplexity,
  425. test_perplexity=perplexity,
  426. target_throughput=args.target_throughput,
  427. test_throughput=meters['eval_throughput'].avg,
  428. )
  429. if not passed:
  430. sys.exit(1)
  431. if __name__ == "__main__":
  432. # Disable profiling executor
  433. try:
  434. torch._C._jit_set_profiling_executor(False)
  435. torch._C._jit_set_profiling_mode(False)
  436. except AttributeError:
  437. pass
  438. main()