# translate.py
#!/usr/bin/env python
# Copyright (c) 2017 Elad Hoffer
# Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
  22. import os
  23. os.environ['KMP_AFFINITY'] = 'disabled'
  24. import argparse
  25. import itertools
  26. import logging
  27. import sys
  28. import warnings
  29. from itertools import product
  30. import dllogger
  31. import numpy as np
  32. import torch
  33. import seq2seq.gpu_affinity as gpu_affinity
  34. import seq2seq.utils as utils
  35. from seq2seq.data.dataset import RawTextDataset
  36. from seq2seq.data.dataset import SyntheticDataset
  37. from seq2seq.data.tokenizer import Tokenizer
  38. from seq2seq.inference import tables
  39. from seq2seq.inference.translator import Translator
  40. from seq2seq.models.gnmt import GNMT
  41. def parse_args():
  42. """
  43. Parse commandline arguments.
  44. """
  45. def exclusive_group(group, name, default, help):
  46. destname = name.replace('-', '_')
  47. subgroup = group.add_mutually_exclusive_group(required=False)
  48. subgroup.add_argument(f'--{name}', dest=f'{destname}',
  49. action='store_true',
  50. help=f'{help} (use \'--no-{name}\' to disable)')
  51. subgroup.add_argument(f'--no-{name}', dest=f'{destname}',
  52. action='store_false', help=argparse.SUPPRESS)
  53. subgroup.set_defaults(**{destname: default})
  54. parser = argparse.ArgumentParser(
  55. description='GNMT Translate',
  56. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  57. # dataset
  58. dataset = parser.add_argument_group('data setup')
  59. dataset.add_argument('-o', '--output', required=False,
  60. help='full path to the output file \
  61. if not specified, then the output will be printed')
  62. dataset.add_argument('-r', '--reference', default=None,
  63. help='full path to the file with reference \
  64. translations (for sacrebleu, raw text)')
  65. dataset.add_argument('-m', '--model', type=str, default=None,
  66. help='full path to the model checkpoint file')
  67. dataset.add_argument('--synthetic', action='store_true',
  68. help='use synthetic dataset')
  69. dataset.add_argument('--synthetic-batches', type=int, default=64,
  70. help='number of synthetic batches to generate')
  71. dataset.add_argument('--synthetic-vocab', type=int, default=32320,
  72. help='size of synthetic vocabulary')
  73. dataset.add_argument('--synthetic-len', type=int, default=50,
  74. help='sequence length of synthetic samples')
  75. source = dataset.add_mutually_exclusive_group(required=False)
  76. source.add_argument('-i', '--input', required=False,
  77. help='full path to the input file (raw text)')
  78. source.add_argument('-t', '--input-text', nargs='+', required=False,
  79. help='raw input text')
  80. exclusive_group(group=dataset, name='sort', default=False,
  81. help='sorts dataset by sequence length')
  82. # parameters
  83. params = parser.add_argument_group('inference setup')
  84. params.add_argument('--batch-size', nargs='+', default=[128], type=int,
  85. help='batch size per GPU')
  86. params.add_argument('--beam-size', nargs='+', default=[5], type=int,
  87. help='beam size')
  88. params.add_argument('--max-seq-len', default=80, type=int,
  89. help='maximum generated sequence length')
  90. params.add_argument('--len-norm-factor', default=0.6, type=float,
  91. help='length normalization factor')
  92. params.add_argument('--cov-penalty-factor', default=0.1, type=float,
  93. help='coverage penalty factor')
  94. params.add_argument('--len-norm-const', default=5.0, type=float,
  95. help='length normalization constant')
  96. # general setup
  97. general = parser.add_argument_group('general setup')
  98. general.add_argument('--math', nargs='+', default=['fp16'],
  99. choices=['fp16', 'fp32', 'tf32'], help='precision')
  100. exclusive_group(group=general, name='env', default=False,
  101. help='print info about execution env')
  102. exclusive_group(group=general, name='bleu', default=True,
  103. help='compares with reference translation and computes \
  104. BLEU')
  105. exclusive_group(group=general, name='cuda', default=True,
  106. help='enables cuda')
  107. exclusive_group(group=general, name='cudnn', default=True,
  108. help='enables cudnn')
  109. batch_first_parser = general.add_mutually_exclusive_group(required=False)
  110. batch_first_parser.add_argument('--batch-first', dest='batch_first',
  111. action='store_true',
  112. help='uses (batch, seq, feature) data \
  113. format for RNNs')
  114. batch_first_parser.add_argument('--seq-first', dest='batch_first',
  115. action='store_false',
  116. help='uses (seq, batch, feature) data \
  117. format for RNNs')
  118. batch_first_parser.set_defaults(batch_first=True)
  119. general.add_argument('--save-dir', default='gnmt',
  120. help='path to directory with results, it will be \
  121. automatically created if it does not exist')
  122. general.add_argument('--dllog-file', type=str, default='eval_log.json',
  123. help='Name of the DLLogger output file')
  124. general.add_argument('--print-freq', '-p', default=1, type=int,
  125. help='print log every PRINT_FREQ batches')
  126. general.add_argument('--affinity', type=str,
  127. default='single_unique',
  128. choices=['socket', 'single', 'single_unique',
  129. 'socket_unique_interleaved',
  130. 'socket_unique_continuous',
  131. 'disabled'],
  132. help='type of CPU affinity')
  133. # benchmarking
  134. benchmark = parser.add_argument_group('benchmark setup')
  135. benchmark.add_argument('--target-perf', default=None, type=float,
  136. help='target inference performance (in tokens \
  137. per second)')
  138. benchmark.add_argument('--target-bleu', default=None, type=float,
  139. help='target accuracy')
  140. benchmark.add_argument('--repeat', nargs='+', default=[1], type=float,
  141. help='loops over the dataset REPEAT times, flag \
  142. accepts multiple arguments, one for each specified \
  143. batch size')
  144. benchmark.add_argument('--warmup', default=0, type=int,
  145. help='warmup iterations for performance counters')
  146. benchmark.add_argument('--percentiles', nargs='+', type=int,
  147. default=(90, 95, 99),
  148. help='Percentiles for confidence intervals for \
  149. throughput/latency benchmarks')
  150. exclusive_group(group=benchmark, name='tables', default=False,
  151. help='print accuracy, throughput and latency results in \
  152. tables')
  153. # distributed
  154. distributed = parser.add_argument_group('distributed setup')
  155. distributed.add_argument('--local_rank', type=int,
  156. default=os.getenv('LOCAL_RANK', 0),
  157. help='Used for multi-process training.')
  158. args = parser.parse_args()
  159. if args.input_text:
  160. args.bleu = False
  161. if args.bleu and args.reference is None:
  162. parser.error('--bleu requires --reference')
  163. if ('fp16' in args.math or 'tf32' in args.math) and not args.cuda:
  164. parser.error(f'--math {args.math} requires --cuda')
  165. if len(list(product(args.math, args.batch_size, args.beam_size))) > 1:
  166. args.target_bleu = None
  167. args.target_perf = None
  168. args.repeat = dict(itertools.zip_longest(args.batch_size,
  169. args.repeat,
  170. fillvalue=1))
  171. return args
  172. def main():
  173. """
  174. Launches translation (inference).
  175. Inference is executed on a single GPU, implementation supports beam search
  176. with length normalization and coverage penalty.
  177. """
  178. args = parse_args()
  179. if args.affinity != 'disabled':
  180. nproc_per_node = torch.cuda.device_count()
  181. affinity = gpu_affinity.set_affinity(
  182. args.local_rank,
  183. nproc_per_node,
  184. args.affinity
  185. )
  186. print(f'{args.local_rank}: thread affinity: {affinity}')
  187. device = utils.set_device(args.cuda, args.local_rank)
  188. utils.init_distributed(args.cuda)
  189. args.rank = utils.get_rank()
  190. os.makedirs(args.save_dir, exist_ok=True)
  191. utils.setup_logging()
  192. dllog_file = os.path.join(args.save_dir, args.dllog_file)
  193. utils.setup_dllogger(enabled=True, filename=dllog_file)
  194. if args.env:
  195. utils.log_env_info()
  196. logging.info(f'Run arguments: {args}')
  197. dllogger.log(step='PARAMETER', data=vars(args))
  198. if not args.cuda and torch.cuda.is_available():
  199. warnings.warn('cuda is available but not enabled')
  200. if not args.cudnn:
  201. torch.backends.cudnn.enabled = False
  202. # load checkpoint and deserialize to CPU (to save GPU memory)
  203. if args.model:
  204. checkpoint = torch.load(args.model, map_location={'cuda:0': 'cpu'})
  205. # build GNMT model
  206. tokenizer = Tokenizer()
  207. tokenizer.set_state(checkpoint['tokenizer'])
  208. model_config = checkpoint['model_config']
  209. model_config['batch_first'] = args.batch_first
  210. model_config['vocab_size'] = tokenizer.vocab_size
  211. model = GNMT(**model_config)
  212. model.load_state_dict(checkpoint['state_dict'])
  213. elif args.synthetic:
  214. model = GNMT(args.synthetic_vocab, batch_first=args.batch_first)
  215. tokenizer = None
  216. else:
  217. raise RuntimeError('Specify model either with --synthetic or with --model flag')
  218. # construct the dataset
  219. if args.input:
  220. data = RawTextDataset(raw_datafile=args.input,
  221. tokenizer=tokenizer,
  222. sort=args.sort,
  223. )
  224. elif args.input_text:
  225. data = RawTextDataset(raw_data=args.input_text,
  226. tokenizer=tokenizer,
  227. sort=args.sort,
  228. )
  229. elif args.synthetic:
  230. data = SyntheticDataset(args.synthetic_vocab, args.synthetic_len, args.batch_size[0] * args.synthetic_batches)
  231. latency_table = tables.LatencyTable(args.percentiles)
  232. throughput_table = tables.ThroughputTable(args.percentiles)
  233. accuracy_table = tables.AccuracyTable('BLEU')
  234. dtype = {
  235. 'fp32': torch.FloatTensor,
  236. 'tf32': torch.FloatTensor,
  237. 'fp16': torch.HalfTensor
  238. }
  239. for (math, batch_size, beam_size) in product(args.math, args.batch_size,
  240. args.beam_size):
  241. logging.info(f'math: {math}, batch size: {batch_size}, '
  242. f'beam size: {beam_size}')
  243. model.type(dtype[math])
  244. model = model.to(device)
  245. model.eval()
  246. # build the data loader
  247. loader = data.get_loader(
  248. batch_size=batch_size,
  249. batch_first=args.batch_first,
  250. pad=True,
  251. repeat=args.repeat[batch_size],
  252. num_workers=0,
  253. )
  254. # build the translator object
  255. translator = Translator(
  256. model=model,
  257. tokenizer=tokenizer,
  258. loader=loader,
  259. beam_size=beam_size,
  260. max_seq_len=args.max_seq_len,
  261. len_norm_factor=args.len_norm_factor,
  262. len_norm_const=args.len_norm_const,
  263. cov_penalty_factor=args.cov_penalty_factor,
  264. print_freq=args.print_freq,
  265. )
  266. # execute the inference
  267. output, stats = translator.run(
  268. calc_bleu=args.bleu,
  269. eval_path=args.output,
  270. summary=True,
  271. warmup=args.warmup,
  272. reference_path=args.reference,
  273. )
  274. # print translated outputs
  275. if not args.synthetic and (not args.output and args.rank == 0):
  276. logging.info(f'Translated output:')
  277. for out in output:
  278. print(out)
  279. key = (batch_size, beam_size)
  280. latency_table.add(key, {math: stats['runtimes']})
  281. throughput_table.add(key, {math: stats['throughputs']})
  282. accuracy_table.add(key, {math: stats['bleu']})
  283. if args.tables:
  284. accuracy_table.write('Inference accuracy', args.math)
  285. if 'fp16' in args.math and 'fp32' in args.math:
  286. relative = 'fp32'
  287. elif 'fp16' in args.math and 'tf32' in args.math:
  288. relative = 'tf32'
  289. else:
  290. relative = None
  291. if 'fp32' in args.math:
  292. throughput_table.write('Inference throughput', 'fp32')
  293. if 'tf32' in args.math:
  294. throughput_table.write('Inference throughput', 'tf32')
  295. if 'fp16' in args.math:
  296. throughput_table.write('Inference throughput', 'fp16',
  297. relative=relative)
  298. if 'fp32' in args.math:
  299. latency_table.write('Inference latency', 'fp32')
  300. if 'tf32' in args.math:
  301. latency_table.write('Inference latency', 'tf32')
  302. if 'fp16' in args.math:
  303. latency_table.write('Inference latency', 'fp16',
  304. relative=relative, reverse_speedup=True)
  305. summary = {
  306. 'eval_throughput': stats['tokens_per_sec'],
  307. 'eval_bleu': stats['bleu'],
  308. 'eval_avg_latency': np.array(stats['runtimes']).mean(),
  309. }
  310. for p in args.percentiles:
  311. summary[f'eval_{p}%_latency'] = np.percentile(stats['runtimes'], p)
  312. dllogger.log(step=tuple(), data=summary)
  313. passed = utils.benchmark(stats['bleu'], args.target_bleu,
  314. stats['tokens_per_sec'], args.target_perf)
  315. return passed
  316. if __name__ == '__main__':
  317. passed = main()
  318. if not passed:
  319. sys.exit(1)