eval.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import json
  16. import logging
  17. import math
  18. import os
  19. import pickle
  20. import sys
  21. import time
  22. import warnings
  23. import dllogger
  24. import numpy as np
  25. import torch
  26. import yaml
  27. import data_utils
  28. import utils
  29. from data_utils import get_lm_corpus
  30. from data_utils import tokenize_raw
  31. from utils.exp_utils import AverageMeter
  32. from utils.exp_utils import benchmark
  33. from utils.exp_utils import create_exp_dir
  34. from utils.exp_utils import l2_promote
  35. from utils.exp_utils import log_env_info
def parse_args():
    """Parse command-line arguments for Transformer-XL evaluation.

    Supports a two-stage scheme: an optional YAML config file
    (``--config_file`` + ``--config``) is parsed first with a separate
    parser, and its ``eval`` section is injected as argparse defaults, so
    explicit command-line flags still override config-file values.

    Returns:
        argparse.Namespace with all evaluation options.

    Raises:
        RuntimeError: if ``--ext_len`` is negative.
    """
    # Shared parent so the config pre-parser and the real parser accept
    # the same base options without duplicating '--help'.
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
        )
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)

    # parse_known_args: ignore all the flags the main parser will handle.
    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        # Pull the 'eval' section of the named config out of the YAML file.
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['eval']
    else:
        config = {}

    parser.add_argument('--work_dir', default='LM-TFM', type=str,
                        help='experiment directory')
    parser.add_argument('--debug', action='store_true',
                        help='run in debug mode (do not create exp dir)')
    parser.add_argument('--data', type=str, default='../data/wikitext-103',
                        help='location of the data corpus')
    parser.add_argument('--manual', type=str, default=None, nargs='+',
                        help='run model on raw input data')
    parser.add_argument('--dataset', type=str, default='wt103',
                        choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                        help='dataset name')
    parser.add_argument('--split', type=str, default='all',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--affinity', type=str,
                        default='single_unique',
                        choices=['socket', 'single', 'single_unique',
                                 'socket_unique_interleaved',
                                 'socket_unique_continuous',
                                 'disabled'],
                        help='type of CPU affinity')
    parser.add_argument('--type', type=str, default='pytorch',
                        choices=['pytorch', 'torchscript'],
                        help='type of runtime to use')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=64,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=640,
                        help='length of the retained previous heads')
    parser.add_argument('--seed', type=int, default=1111,
                        help='Random seed')
    parser.add_argument('--clamp_len', type=int, default=-1,
                        help='max positional embedding index')
    parser.add_argument('--cuda', action='store_true',
                        help='Run evaluation on a GPU using CUDA')
    parser.add_argument('--model', type=str, default='',
                        help='path to the checkpoint')
    parser.add_argument('--manual_config', type=json.loads, default=None,
                        help='Manually specify config for the model')
    parser.add_argument('--manual_vocab', type=str, default='word',
                        choices=['word', 'bpe'],
                        help='Manually specify type of vocabulary')
    parser.add_argument('--fp16', action='store_true',
                        help='Run training in fp16/mixed precision')
    parser.add_argument('--log_all_ranks', action='store_true',
                        help='Enable logging for all distributed ranks')
    parser.add_argument('--dllog_file', type=str, default='eval_log.json',
                        help='Name of the DLLogger output file')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--no_env', action='store_true',
                        help='Do not print info on execution env')
    parser.add_argument('--log_interval', type=int, default=10,
                        help='Report interval')
    parser.add_argument('--target_perplexity', type=float, default=None,
                        help='target perplexity')
    parser.add_argument('--target_throughput', type=float, default=None,
                        help='target throughput')
    parser.add_argument('--save_data', action='store_true',
                        help='save latency and throughput data to a file')
    parser.add_argument('--repeat', type=int, default=1,
                        help='loop over the dataset REPEAT times')
    parser.add_argument('--max_size', type=int, default=None,
                        help='run inference on up to MAX_SIZE batches')
    parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
                        help='percentiles for latency confidence intervals')
    parser.add_argument('--save_torchscript', default=None, type=str,
                        help='save torchscript model to a file')
    parser.add_argument('--load_torchscript', default=None, type=str,
                        help='load torchscript model from a file')
    parser.add_argument('--local_rank', type=int,
                        default=os.getenv('LOCAL_RANK', 0),
                        help='Used for multi-process training.')

    # Config-file values become defaults; CLI flags still win over them.
    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    if args.manual:
        # Raw-text inference processes a single stream of tokens.
        args.batch_size = 1

    if args.same_length and args.tgt_len > args.mem_len:
        warnings.warn('--same_length is intended to be used with large '
                      'mem_len relative to tgt_len')

    if args.ext_len < 0:
        raise RuntimeError('Extended context length must be non-negative')

    return args
  138. def load_checkpoint(path):
  139. dst = f'cuda:{torch.cuda.current_device()}'
  140. logging.info(f'Loading checkpoint from {path}')
  141. checkpoint = torch.load(path, map_location=dst)
  142. return checkpoint
  143. def format_log(loss, split, args):
  144. if args.dataset in ['enwik8', 'text8']:
  145. log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
  146. split, loss, loss / math.log(2))
  147. else:
  148. log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
  149. split, loss, math.exp(loss))
  150. return log_str
def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
    """Run the model over *eval_iter* and return the average warm-up-excluded loss.

    Args:
        eval_iter: batch iterator yielding ``(data, target, seq_len, warm)``
            tuples; exposes ``n_batch`` (used for progress logging).
            ``warm`` presumably marks batches past the memory warm-up
            phase — TODO confirm against the iterator implementation.
        model: callable ``model(data, target, mems) -> (loss, mems)``.
        meters: dict with 'eval_latency' and 'eval_throughput' AverageMeters.
        log_interval: emit a log line every this many eval steps.
        max_size: if set, stop each pass after this many batches.
        repeat: loop over the dataset this many times (mems carried across).

    Returns:
        Mean loss over warm batches, all-reduced (mean) across ranks.

    NOTE(review): ``idx`` is referenced after the loops for timing output; if
    the iterator yields no batches this raises NameError, and ``total_len``
    of 0 would raise ZeroDivisionError — callers are assumed to pass a
    non-empty iterator.
    """
    total_len, total_loss = 0, 0.
    eval_step = 0
    # Running accumulators for the periodic log line; reset every interval.
    log_throughput = 0
    log_latency = 0
    log_loss = 0

    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        mems = None
        for _ in range(repeat):
            for idx, (data, target, seq_len, warm) in enumerate(eval_iter):
                if max_size and idx >= max_size:
                    break
                eval_step += 1

                # Synchronize around the forward pass so the measured
                # wall time covers only this batch's GPU work.
                torch.cuda.synchronize()
                start_iter = time.time()
                loss, mems = model(data, target, mems)
                torch.cuda.synchronize()
                elapsed = time.time() - start_iter

                loss = loss.float().mean()
                log_loss += loss.item()
                if warm:
                    # Only warm batches count toward the reported loss.
                    total_loss += seq_len * loss.item()
                    total_len += seq_len

                meters['eval_latency'].update(elapsed)
                log_latency += elapsed

                # Throughput is summed across distributed ranks.
                target_tokens = target.numel()
                throughput = target_tokens / elapsed
                throughput = utils.distributed.all_reduce_item(throughput, op='sum')
                meters['eval_throughput'].update(throughput)
                log_throughput += throughput

                if eval_step % log_interval == 0:
                    # Convert running sums to per-interval averages.
                    log_throughput /= log_interval
                    log_latency /= log_interval
                    log_loss /= log_interval
                    log_ppl = math.exp(log_loss)

                    log_str = '| step {:>8d} | batches {:>6d} / {:d} ' \
                        '| ms/batch {:5.2f} | tok/s {:7.0f} | loss {:5.2f} | ppl {:5.2f}'.format(
                            eval_step,
                            idx+1,
                            eval_iter.n_batch,
                            log_latency * 1000,
                            log_throughput,
                            log_loss,
                            log_ppl,
                            )
                    logging.info(log_str)

                    dllogger_data = {
                        'eval_latency': log_latency * 1000,
                        'eval_throughput': log_throughput,
                        'eval_loss': log_loss,
                        'eval_perplexity': log_ppl,
                        }
                    dllogger.log(step=tuple([eval_step]), data=dllogger_data)

                    log_throughput = 0
                    log_latency = 0
                    log_loss = 0

    utils.distributed.barrier()
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))

    avg_loss = total_loss / total_len
    avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
    return avg_loss
  217. def compile_model(model, device, args):
  218. inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  219. tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  220. start = time.time()
  221. with torch.no_grad():
  222. mems = None
  223. for _ in range(2):
  224. _, mems = model(inp, tgt, mems)
  225. torch.cuda.synchronize()
  226. stop = time.time()
  227. logging.info(f'Building the model took {stop - start:.2f} seconds')
def main():
    """Entry point: set up the environment, load the model, and evaluate.

    Orchestrates CPU affinity, distributed init, logging, checkpoint/config
    loading, optional TorchScript conversion, dataset iteration, and final
    metric reporting. Exits with status 1 if perplexity/throughput targets
    (when given) are not met.
    """
    args = parse_args()

    # Pin this process to CPUs near its GPU (NUMA locality).
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # The TorchScript runtime uses a scripting-compatible model variant.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    # Only rank 0 creates the experiment directory; others wait.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)
    if args.debug:
        # Debug mode: discard all log output.
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  filemode='a',
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    dllogger.metadata('eval_throughput', {'unit': 'tokens/s'})
    dllogger.metadata('eval_loss', {'unit': None})
    dllogger.metadata('eval_perplexity', {'unit': None})
    dllogger.metadata('eval_latency', {'unit': 'ms'})
    dllogger.metadata('eval_avg_latency', {'unit': 'ms'})
    for p in args.percentiles:
        dllogger.metadata(f'eval_{p}%_latency', {'unit': 'ms'})

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Resolve the checkpoint path: explicit --model wins over work_dir.
    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError('Specify path to checkpoint using --model or --work_dir')

    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        # Raw-text mode: tokenize the command-line text with the
        # checkpoint's vocabulary and build a single-stream iterator.
        vocab = checkpoint['vocab']

        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len, warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split, args.batch_size, args.tgt_len,
                                       device=device, mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    # Build/load the model: a pre-scripted module, a checkpointed model,
    # or a freshly configured one (--manual_config).
    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        # Override sequence-length settings saved in the checkpoint with
        # the evaluation-time values.
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            # Non-strict: projection/bias tensors are copied manually below.
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        # TorchScript cannot script the adaptive-softmax parameter sharing,
        # so the tied/projected tensors are copied into the module by hand
        # before scripting.
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            if div_val == 1:
                src = 0
            else:
                src = i

            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    # Tied projection: share the embedding projection tensor.
                    model.crit.out_projs[i] = state[f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        # Warm up JIT compilation so it doesn't skew the timings.
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    # Skip the first mem_len/tgt_len + 2 batches in the meters: memory is
    # still filling up and latency includes one-off costs.
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
        }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        # Persist raw latency/throughput samples for offline analysis.
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
            }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
            })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    # Optional pass/fail gating against quality/performance targets.
    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['eval_throughput'].avg,
                       )
    if not passed:
        sys.exit(1)
if __name__ == "__main__":
    # Disable profiling executor
    # The private _jit_set_profiling_* hooks are not present in every
    # torch version, hence the AttributeError guard.
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        pass

    main()