# eval.py
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import json
  16. import logging
  17. import math
  18. import os
  19. import pickle
  20. import sys
  21. import time
  22. import dllogger
  23. import numpy as np
  24. import torch
  25. import yaml
  26. import data_utils
  27. import utils
  28. from data_utils import get_lm_corpus
  29. from data_utils import tokenize_raw
  30. from utils.exp_utils import AverageMeter
  31. from utils.exp_utils import benchmark
  32. from utils.exp_utils import create_exp_dir
  33. from utils.exp_utils import l2_promote
  34. from utils.exp_utils import log_env_info
  35. def parse_args():
  36. parent_parser = argparse.ArgumentParser(
  37. description='PyTorch Transformer-XL Language Model',
  38. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  39. add_help=False,
  40. )
  41. parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
  42. cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
  43. cfg_parser.add_argument('--config', default='default')
  44. cfg_parser.add_argument('--config_file', default=None)
  45. config_args, _ = cfg_parser.parse_known_args()
  46. if config_args.config is not None and config_args.config_file is not None:
  47. with open(config_args.config_file) as f:
  48. config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['eval']
  49. else:
  50. config = {}
  51. parser.add_argument('--work_dir', default='LM-TFM', type=str,
  52. help='experiment directory')
  53. parser.add_argument('--debug', action='store_true',
  54. help='run in debug mode (do not create exp dir)')
  55. parser.add_argument('--data', type=str, default='../data/wikitext-103',
  56. help='location of the data corpus')
  57. parser.add_argument('--manual', type=str, default=None, nargs='+',
  58. help='run model on raw input data')
  59. parser.add_argument('--dataset', type=str, default='wt103',
  60. choices=['wt103', 'lm1b', 'enwik8', 'text8'],
  61. help='dataset name')
  62. parser.add_argument('--split', type=str, default='all',
  63. choices=['all', 'valid', 'test'],
  64. help='which split to evaluate')
  65. parser.add_argument('--type', type=str, default='pytorch',
  66. choices=['pytorch', 'torchscript'],
  67. help='type of runtime to use')
  68. parser.add_argument('--batch_size', type=int, default=16,
  69. help='batch size')
  70. parser.add_argument('--tgt_len', type=int, default=64,
  71. help='number of tokens to predict')
  72. parser.add_argument('--ext_len', type=int, default=0,
  73. help='length of the extended context')
  74. parser.add_argument('--mem_len', type=int, default=640,
  75. help='length of the retained previous heads')
  76. parser.add_argument('--seed', type=int, default=1111,
  77. help='Random seed')
  78. parser.add_argument('--clamp_len', type=int, default=-1,
  79. help='max positional embedding index')
  80. parser.add_argument('--cuda', action='store_true',
  81. help='Run evaluation on a GPU using CUDA')
  82. parser.add_argument('--model', type=str, default='',
  83. help='path to the checkpoint')
  84. parser.add_argument('--manual_config', type=json.loads, default=None,
  85. help='Manually specify config for the model')
  86. parser.add_argument('--manual_vocab', type=str, default='word',
  87. choices=['word', 'bpe'],
  88. help='Manually specify type of vocabulary')
  89. parser.add_argument('--fp16', action='store_true',
  90. help='Run training in fp16/mixed precision')
  91. parser.add_argument('--log_all_ranks', action='store_true',
  92. help='Enable logging for all distributed ranks')
  93. parser.add_argument('--dllog_file', type=str, default='eval_log.json',
  94. help='Name of the DLLogger output file')
  95. parser.add_argument('--same_length', action='store_true',
  96. help='set same length attention with masking')
  97. parser.add_argument('--no_env', action='store_true',
  98. help='Do not print info on execution env')
  99. parser.add_argument('--log_interval', type=int, default=10,
  100. help='Report interval')
  101. parser.add_argument('--target_perplexity', type=float, default=None,
  102. help='target perplexity')
  103. parser.add_argument('--target_throughput', type=float, default=None,
  104. help='target throughput')
  105. parser.add_argument('--save_data', action='store_true',
  106. help='save latency and throughput data to a file')
  107. parser.add_argument('--repeat', type=int, default=1,
  108. help='loop over the dataset REPEAT times')
  109. parser.add_argument('--max_size', type=int, default=None,
  110. help='run inference on up to MAX_SIZE batches')
  111. parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
  112. help='percentiles for latency confidence intervals')
  113. parser.add_argument('--save_torchscript', default=None, type=str,
  114. help='save torchscript model to a file')
  115. parser.add_argument('--load_torchscript', default=None, type=str,
  116. help='load torchscript model from a file')
  117. parser.add_argument('--local_rank', type=int,
  118. default=os.getenv('LOCAL_RANK', 0),
  119. help='Used for multi-process training.')
  120. parser.set_defaults(**config)
  121. args, _ = parser.parse_known_args()
  122. if args.manual:
  123. args.batch_size = 1
  124. assert args.ext_len >= 0, 'extended context length must be non-negative'
  125. return args
  126. def load_checkpoint(path):
  127. dst = f'cuda:{torch.cuda.current_device()}'
  128. logging.info(f'Loading checkpoint from {path}')
  129. checkpoint = torch.load(path, map_location=dst)
  130. return checkpoint
  131. def format_log(loss, split, args):
  132. if args.dataset in ['enwik8', 'text8']:
  133. log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
  134. split, loss, loss / math.log(2))
  135. else:
  136. log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
  137. split, loss, math.exp(loss))
  138. return log_str
def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
    """Run the model over an evaluation iterator and return the average loss.

    Args:
        eval_iter: iterator yielding ``(data, target, seq_len, warm)`` tuples;
            must expose ``n_batch`` (total batch count) for log messages.
        model: callable ``model(data, target, mems) -> (loss, mems)``.
        meters: dict with 'eval_latency' and 'eval_throughput' meters
            supporting ``.update(value)``.
        log_interval: emit a progress line every this many steps.
        max_size: if set, stop each pass after this many batches.
        repeat: number of passes over the dataset.

    Returns:
        Average per-token loss over "warm" batches, all-reduced (mean)
        across distributed ranks.
    """
    # Totals for the final average; only "warm" batches contribute.
    total_len, total_loss = 0, 0.
    eval_step = 0

    # Interval accumulators, reset after every log_interval steps.
    log_throughput = 0
    log_latency = 0
    log_loss = 0

    # NOTE(review): unconditional CUDA sync -- this function assumes a GPU
    # and would raise on a CPU-only run; confirm callers always use CUDA.
    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        mems = None  # recurrent memory carried across batches
        for _ in range(repeat):
            for idx, (data, target, seq_len, warm) in enumerate(eval_iter):
                if max_size and idx >= max_size:
                    break
                eval_step += 1

                # Synchronize on both sides of the forward pass so `elapsed`
                # measures completed device work, not just kernel launches.
                torch.cuda.synchronize()
                start_iter = time.time()
                loss, mems = model(data, target, mems)
                torch.cuda.synchronize()
                elapsed = time.time() - start_iter

                loss = loss.float().mean()
                log_loss += loss.item()
                # Only batches flagged "warm" by the iterator count toward
                # the reported loss (memory assumed fully populated).
                if warm:
                    # assert all([m.size(0) == model.mem_len for m in mems])
                    total_loss += seq_len * loss.item()
                    total_len += seq_len

                meters['eval_latency'].update(elapsed)
                log_latency += elapsed

                # Per-rank throughput, summed across all distributed ranks.
                target_tokens = target.numel()
                throughput = target_tokens / elapsed
                throughput = utils.distributed.all_reduce_item(throughput, op='sum')
                meters['eval_throughput'].update(throughput)
                log_throughput += throughput

                if eval_step % log_interval == 0:
                    # Turn the running sums into per-interval averages.
                    log_throughput /= log_interval
                    log_latency /= log_interval
                    log_loss /= log_interval
                    log_ppl = math.exp(log_loss)

                    log_str = '| step {:>8d} | batches {:>6d} / {:d} ' \
                        '| ms/batch {:5.2f} | tok/s {:7.0f} | loss {:5.2f} | ppl {:5.2f}'.format(
                            eval_step,
                            idx+1,
                            eval_iter.n_batch,
                            log_latency * 1000,
                            log_throughput,
                            log_loss,
                            log_ppl,
                            )
                    logging.info(log_str)

                    dllogger_data = {
                        'eval_latency': log_latency * 1000,
                        'eval_throughput': log_throughput,
                        'eval_loss': log_loss,
                        'eval_perplexity': log_ppl,
                        }
                    dllogger.log(step=tuple([eval_step]), data=dllogger_data)

                    # Reset interval accumulators for the next window.
                    log_throughput = 0
                    log_latency = 0
                    log_loss = 0

    utils.distributed.barrier()
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    # NOTE(review): `idx` is reused from the last loop pass -- NameError if
    # the iterator is empty; likewise `total_len` can be 0 if no batch is warm.
    logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))

    avg_loss = total_loss / total_len
    avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
    return avg_loss
  206. def compile_model(model, device, args):
  207. inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  208. tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  209. start = time.time()
  210. with torch.no_grad():
  211. mems = None
  212. for _ in range(2):
  213. _, mems = model(inp, tgt, mems)
  214. torch.cuda.synchronize()
  215. stop = time.time()
  216. logging.info(f'Building the model took {stop - start:.2f} seconds')
def main():
    """Entry point: configure logging/distributed state, build the model
    from a checkpoint (or a manually supplied config), run evaluation on
    the requested split, and report/persist metrics.

    Exits with status 1 if the optional perplexity/throughput targets from
    the command line are not met.
    """
    args = parse_args()

    utils.gpu_affinity.set_affinity(args.local_rank)

    # Pick the eager or TorchScript-compatible model implementation.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_jit import MemTransformerLM

    # NOTE(review): runs unconditionally -- appears to assume a CUDA build
    # even when --cuda is not passed; confirm CPU-only runs are supported.
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    # Only rank 0 creates the experiment dir; other ranks wait at the barrier.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'eval_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'eval_log.log'

    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    # Debug mode discards all log output instead of creating files.
    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  filemode='a',
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    if not args.no_env:
        log_env_info()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Resolve the checkpoint path: explicit --model wins over --work_dir.
    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError('Specify path to checkpoint using --model or --work_dir')

    # With --manual_config the model is built from CLI-supplied JSON and no
    # checkpoint is loaded.
    if not args.manual_config:
        checkpoint = load_checkpoint(model_path)
        vocab_type = checkpoint['args'].vocab
    else:
        checkpoint = None
        vocab_type = args.manual_vocab

    if args.manual:
        # Evaluate raw text given on the command line instead of a corpus.
        vocab = checkpoint['vocab']

        # Older vocab objects may lack unk_idx; rebuild it from sym2idx.
        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']

        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        # NOTE(review): `iter` shadows the builtin of the same name.
        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len, warmup=False)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, vocab_type)

        if args.split == 'valid' or args.split == 'test':
            iter = corpus.get_iterator(args.split, args.batch_size, args.tgt_len,
                                       device=device, mem_len=args.mem_len,
                                       ext_len=args.ext_len)
        else:
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    elif not args.manual_config:
        # Override the sequence/memory geometry stored in the checkpoint
        # with the values requested for this evaluation run.
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        if args.type == 'pytorch':
            model.load_state_dict(checkpoint['model_state'])
        elif args.type == 'torchscript':
            # strict=False: some tensors are re-attached manually below.
            model.load_state_dict(checkpoint['model_state'], strict=False)
    elif args.manual_config:
        args.manual_config['tgt_len'] = args.tgt_len
        args.manual_config['ext_len'] = args.ext_len
        args.manual_config['mem_len'] = args.mem_len
        args.manual_config['clamp_len'] = args.clamp_len
        args.manual_config['same_length'] = args.same_length
        args.manual_config['dtype'] = dtype

        model = MemTransformerLM(**args.manual_config)

    model = model.eval()
    model = model.to(device)
    model = model.to(dtype)

    if args.type == 'torchscript' and not args.manual_config:
        # Re-attach (possibly tied) projection/embedding tensors that
        # load_state_dict(strict=False) skipped, cast to the eval dtype.
        state = checkpoint['model_state']

        tie_projs = checkpoint['model_config']['tie_projs']
        tie_weight = checkpoint['model_config']['tie_weight']
        div_val = checkpoint['model_config']['div_val']
        d_model = checkpoint['model_config']['d_model']
        d_embed = checkpoint['model_config']['d_embed']

        if div_val != 1 or d_model != d_embed:
            for i in range(len(model.word_emb.emb_projs)):
                model.word_emb.emb_projs[i] = state[f'word_emb.emb_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_projs)):
            # With div_val == 1 all clusters share projection 0.
            if div_val == 1:
                src = 0
            else:
                src = i
            if model.crit.out_projs[i] is not None:
                if tie_projs[i]:
                    model.crit.out_projs[i] = state[f'word_emb.emb_projs.{src}'].to(dtype)
                else:
                    model.crit.out_projs[i] = state[f'crit.out_projs.{i}'].to(dtype)

        for i in range(len(model.crit.out_layers_biases)):
            model.crit.out_layers_biases[i] = state[f'crit.out_layers_biases.{i}'].to(dtype)

        if tie_weight:
            # Output weights tied to the input embedding layers.
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'word_emb.emb_layers.{i}.weight'].to(dtype)
        else:
            for i in range(len(model.crit.out_layers_weights)):
                model.crit.out_layers_weights[i] = state[f'crit.out_layers_weights.{i}'].to(dtype)

        model = torch.jit.script(model)

    if args.type != 'pytorch':
        # Warm up lazy compilation so it is not measured by evaluate().
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    # Number of initial batches the meters skip while the recurrent memory
    # fills up.
    warmup = args.mem_len // args.tgt_len + 2
    meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    summary = {
        'eval_loss': loss,
        'eval_ppl': perplexity,
        }

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        # Persist raw latency/throughput samples for offline analysis.
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
            }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)

        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')

        logging.info('=' * 100)

        summary.update({
            'eval_throughput': throughput_data.mean(),
            'eval_avg_latency': 1000 * latency_data.mean(),
            })
        for p in args.percentiles:
            summary[f'eval_{p}%_latency'] = 1000 * np.percentile(latency_data, p)

    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['eval_throughput'].avg,
                       )
    if not passed:
        sys.exit(1)
if __name__ == "__main__":
    # Disable profiling executor
    # The TorchScript profiling executor adds warm-up/profiling overhead;
    # these are private torch._C hooks, so older torch builds may not have
    # them -- hence the AttributeError guard.
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        # Hooks absent in this torch version; nothing to disable.
        pass

    main()