eval.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import logging
  16. import math
  17. import os
  18. import pickle
  19. import sys
  20. import time
  21. import numpy as np
  22. import torch
  23. import data_utils
  24. import utils
  25. from data_utils import get_lm_corpus
  26. from data_utils import tokenize_raw
  27. from utils.exp_utils import AverageMeter
  28. from utils.exp_utils import benchmark
  29. from utils.exp_utils import create_exp_dir
  30. def parse_args():
  31. parser = argparse.ArgumentParser(
  32. description='PyTorch Transformer Language Model',
  33. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  34. parser.add_argument('--work_dir', default='LM-TFM', type=str,
  35. help='experiment directory')
  36. parser.add_argument('--debug', action='store_true',
  37. help='run in debug mode (do not create exp dir)')
  38. parser.add_argument('--data', type=str, default='../data/wikitext-103',
  39. help='location of the data corpus')
  40. parser.add_argument('--manual', type=str, default=None, nargs='+',
  41. help='run model on raw input data')
  42. parser.add_argument('--dataset', type=str, default='wt103',
  43. choices=['wt103', 'lm1b', 'enwik8', 'text8'],
  44. help='dataset name')
  45. parser.add_argument('--split', type=str, default='all',
  46. choices=['all', 'valid', 'test'],
  47. help='which split to evaluate')
  48. parser.add_argument('--type', type=str, default='pytorch',
  49. choices=['pytorch', 'torchscript', 'onnx'],
  50. help='type of runtime to use')
  51. parser.add_argument('--batch_size', type=int, default=16,
  52. help='batch size')
  53. parser.add_argument('--tgt_len', type=int, default=64,
  54. help='number of tokens to predict')
  55. parser.add_argument('--ext_len', type=int, default=0,
  56. help='length of the extended context')
  57. parser.add_argument('--mem_len', type=int, default=640,
  58. help='length of the retained previous heads')
  59. parser.add_argument('--clamp_len', type=int, default=-1,
  60. help='max positional embedding index')
  61. parser.add_argument('--cuda', action='store_true',
  62. help='use CUDA')
  63. parser.add_argument('--model', type=str, default='',
  64. help='path to the checkpoint')
  65. parser.add_argument('--fp16', action='store_true',
  66. help='Run training in fp16/mixed precision')
  67. parser.add_argument('--log_all_ranks', action='store_true',
  68. help='Enable logging for all distributed ranks')
  69. parser.add_argument('--same_length', action='store_true',
  70. help='set same length attention with masking')
  71. parser.add_argument('--target_perplexity', type=float, default=None,
  72. help='target perplexity')
  73. parser.add_argument('--target_throughput', type=float, default=None,
  74. help='target throughput')
  75. parser.add_argument('--save_data', action='store_true',
  76. help='save latency and throughput data to a file')
  77. parser.add_argument('--repeat', type=int, default=1,
  78. help='loop over the dataset REPEAT times')
  79. parser.add_argument('--max_size', type=int, default=None,
  80. help='run inference on up to MAX_SIZE batches')
  81. parser.add_argument('--percentiles', nargs='+', default=[90, 95, 99],
  82. help='percentiles for latency confidence intervals')
  83. parser.add_argument('--save_torchscript', default=None, type=str,
  84. help='save torchscript model to a file')
  85. parser.add_argument('--load_torchscript', default=None, type=str,
  86. help='load torchscript model from a file')
  87. parser.add_argument('--local_rank', default=0, type=int,
  88. help='Used for multi-process training. ' +
  89. 'Can either be manually set ' +
  90. 'or automatically set by using \'python -m multiproc\'.')
  91. args = parser.parse_args()
  92. assert args.ext_len >= 0, 'extended context length must be non-negative'
  93. return args
  94. def load_checkpoint(path):
  95. dst = f'cuda:{torch.cuda.current_device()}'
  96. logging.info(f'Loading checkpoint from {path}')
  97. checkpoint = torch.load(path, map_location=dst)
  98. return checkpoint
  99. def format_log(loss, split, args):
  100. if args.dataset in ['enwik8', 'text8']:
  101. log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
  102. split, loss, loss / math.log(2))
  103. else:
  104. log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
  105. split, loss, math.exp(loss))
  106. return log_str
def evaluate(eval_iter, model, meters, max_size=None, repeat=1):
    """Run inference over eval_iter and return the length-weighted average loss.

    Args:
        eval_iter: iterator yielding (data, target, seq_len) batches
            (presumably an LMOrderedIterator — see main(); verify against caller).
        model: forward pass returns (loss, *new_mems); mems are threaded
            between consecutive batches.
        meters: dict with 'eval_latency' and 'eval_throughput' AverageMeters,
            updated per batch.
        max_size: if set, stop each pass after max_size batches.
        repeat: iterate over the whole dataset this many times.

    Returns:
        Average loss per token, all-reduced (mean) across distributed ranks.

    NOTE(review): mems are NOT reset between repeat passes, so with repeat > 1
    the first batch of a later pass sees memory from the end of the previous
    pass — confirm this is intended for benchmarking.
    """
    total_len, total_loss = 0, 0.
    # Synchronize so start_time does not include previously queued GPU work.
    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        mems = None
        for _ in range(repeat):
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                if max_size and idx >= max_size:
                    break
                # Bracket the forward pass with syncs to measure wall-clock
                # latency of the GPU work for this batch only.
                torch.cuda.synchronize()
                start_iter = time.time()
                ret = model(data, target, mems)
                torch.cuda.synchronize()
                elapsed = time.time() - start_iter
                # ret[0] is the per-position loss; ret[1:] are the new mems.
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                # Weight by seq_len so the final average is per-token.
                total_loss += seq_len * loss.item()
                total_len += seq_len
                meters['eval_latency'].update(elapsed)
                target_tokens = target.numel()
                throughput = target_tokens / elapsed
                # Sum throughput over ranks: aggregate tokens/s of the job.
                throughput = utils.distributed.all_reduce_item(throughput, op='sum')
                meters['eval_throughput'].update(throughput)

    utils.distributed.barrier()
    torch.cuda.synchronize()
    total_time = time.time() - start_time
    # NOTE(review): idx is the batch index from the LAST repeat pass only, so
    # ms/segment overstates per-segment time when repeat > 1; it also raises
    # NameError if eval_iter yields no batches.
    logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))

    avg_loss = total_loss / total_len
    avg_loss = utils.distributed.all_reduce_item(avg_loss, op='mean')
    return avg_loss
  139. def compile_model(model, device, args):
  140. inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  141. tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
  142. start = time.time()
  143. with torch.no_grad():
  144. mems = None
  145. for _ in range(2):
  146. ret = model(inp, tgt, mems)
  147. _, mems = ret[0], ret[1:]
  148. torch.cuda.synchronize()
  149. stop = time.time()
  150. logging.info(f'Building the model took {stop - start:.2f} seconds')
def main():
    """Entry point: load a checkpoint, build the model, evaluate, and benchmark.

    Exits with status 1 if the optional perplexity/throughput targets are
    not met.
    """
    args = parse_args()

    # The torchscript/onnx runtimes use a JIT-friendly model variant.
    if args.type == 'pytorch':
        from mem_transformer import MemTransformerLM
    else:
        from inference.mem_transformer_base_jit import MemTransformerLM

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    # Only rank 0 creates the experiment directory; others wait at the barrier.
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir, debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = f'log.log'

    log_file = os.path.join(args.work_dir, log_file)
    if args.debug:
        # Debug mode discards the log file entirely.
        log_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  filemode='a',
                                  )
    logging.info(args)

    # Resolve the checkpoint path: explicit --model wins over the work dir.
    if args.model:
        model_path = args.model
    elif args.work_dir:
        model_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    else:
        raise RuntimeError('Specify path to checkpoint using --model or --work_dir')
    checkpoint = load_checkpoint(model_path)

    if args.manual:
        # Inference on raw text passed via --manual; forces batch size 1.
        args.batch_size = 1
        vocab = checkpoint['vocab']
        # Older checkpoints may lack unk_idx; recover it from the symbol table.
        if hasattr(vocab, 'sym2idx') and not hasattr(vocab, 'unk_idx'):
            vocab.unk_idx = vocab.sym2idx['<unk>']
        text = " ".join(args.manual)
        tokenized = tokenize_raw(text)
        symbols = vocab.tokenize(tokenized, add_eos=True)
        tensor = vocab.convert_to_tensor(symbols)

        iter = data_utils.LMOrderedIterator(tensor, bsz=args.batch_size,
                                            bptt=args.tgt_len, device=device,
                                            ext_len=args.ext_len)
    else:
        # Load dataset
        corpus = get_lm_corpus(args.data, args.dataset, checkpoint['args'].vocab)

        if args.split == 'valid':
            iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                                       device=device, ext_len=args.ext_len)
        elif args.split == 'test':
            iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                                       device=device, ext_len=args.ext_len)
        else:
            # NOTE(review): --split all reaches this branch and raises, so in
            # practice only valid/test are usable here — confirm intended.
            raise RuntimeError('Unknown split')

    if args.fp16:
        dtype = torch.float16
        math_str = 'fp16'
    else:
        dtype = torch.float32
        math_str = 'fp32'

    if args.load_torchscript:
        model = torch.jit.load(args.load_torchscript)
    else:
        # Override the saved model config with the evaluation-time settings
        # (context/memory lengths may differ from training).
        checkpoint['model_config']['tgt_len'] = args.tgt_len
        checkpoint['model_config']['ext_len'] = args.ext_len
        checkpoint['model_config']['mem_len'] = args.mem_len
        checkpoint['model_config']['clamp_len'] = args.clamp_len
        checkpoint['model_config']['same_length'] = args.same_length
        checkpoint['model_config']['dtype'] = dtype

        model = MemTransformerLM(**checkpoint['model_config'])
        model.load_state_dict(checkpoint['model_state'])

    model = model.eval()
    model = model.to(device)
    # Cast to fp32 first, then down to fp16 if requested.
    model = model.float()
    if args.fp16:
        model = model.half()

    if args.type != 'pytorch':
        # Warm up / JIT-build non-pytorch runtimes before timing.
        compile_model(model, device, args)

    if args.type == 'torchscript' and args.save_torchscript:
        torch.jit.save(model, args.save_torchscript)

    logging.info(f'Evaluating with: math {math_str} type {args.type} '
                 f'bsz {args.batch_size} tgt_len {args.tgt_len} '
                 f'ext_len {args.ext_len} mem_len {args.mem_len} '
                 f'clamp_len {args.clamp_len}')

    meters = {}
    # Skip the first mem_len//tgt_len+1 batches in the stats: until the
    # memory is full, latency/throughput are not representative.
    warmup = args.mem_len // args.tgt_len + 1
    meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
    meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)

    loss = evaluate(iter, model, meters, args.max_size, args.repeat)
    perplexity = math.exp(loss)
    log_str = format_log(loss, args.split, args)

    logging.info('=' * 100)
    logging.info(log_str)
    logging.info('=' * 100)

    if args.save_data:
        # Dump raw per-batch latency/throughput samples for offline analysis.
        latency_data = np.array(meters['eval_latency'].vals)
        throughput_data = np.array(meters['eval_throughput'].vals)
        precision = 'fp16' if args.fp16 else 'fp32'
        data_fname = f'eval_data_{args.batch_size}_{precision}_{args.type}'
        data_path = os.path.join(args.work_dir, data_fname)
        data = {
            'args': args,
            'throughput': throughput_data,
            'latency': latency_data,
            }
        with open(data_path, 'wb') as f:
            pickle.dump(data, f)
        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
        logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
        for p in args.percentiles:
            logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')

        logging.info('=' * 100)

    # benchmark() compares measured perplexity/throughput against optional
    # targets; a miss makes the script exit non-zero for CI gating.
    passed = benchmark(target_perplexity=args.target_perplexity,
                       test_perplexity=perplexity,
                       target_throughput=args.target_throughput,
                       test_throughput=meters['eval_throughput'].avg,
                       )
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()