inference.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. #!/usr/bin/env python3 -u
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. #
  9. #-------------------------------------------------------------------------
  10. #
  11. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. import sys
  24. import os
  25. import time
  26. from collections import namedtuple
  27. import numpy as np
  28. import torch
  29. from torch.serialization import default_restore_location
  30. from fairseq import data, options, tokenizer, utils, log_helper
  31. from fairseq.sequence_generator import SequenceGenerator
  32. from fairseq.meters import StopwatchMeter
  33. from fairseq.models.transformer import TransformerModel
  34. import dllogger
  35. from apply_bpe import BPE
  36. Batch = namedtuple('Batch', 'srcs tokens lengths')
  37. Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
  38. def load_ensemble_for_inference(filenames):
  39. """Load an ensemble of models for inference.
  40. model_arg_overrides allows you to pass a dictionary model_arg_overrides --
  41. {'arg_name': arg} -- to override model args that were used during model
  42. training
  43. """
  44. # load model architectures and weights
  45. states = []
  46. for filename in filenames:
  47. if not os.path.exists(filename):
  48. raise IOError('Model file not found: {}'.format(filename))
  49. state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
  50. states.append(state)
  51. ensemble = []
  52. for state in states:
  53. args = state['args']
  54. # build model for ensemble
  55. model = TransformerModel.build_model(args)
  56. model.load_state_dict(state['model'], strict=True)
  57. ensemble.append(model)
  58. src_dict = states[0]['extra_state']['src_dict']
  59. tgt_dict = states[0]['extra_state']['tgt_dict']
  60. return ensemble, args, src_dict, tgt_dict
  61. def buffered_read(buffer_size, data_descriptor):
  62. buffer = []
  63. for src_str in data_descriptor:
  64. buffer.append(src_str.strip())
  65. if len(buffer) >= buffer_size:
  66. yield buffer
  67. buffer = []
  68. if buffer:
  69. yield buffer
  70. def make_batches(lines, args, src_dict, max_positions, bpe=None):
  71. tokens = [
  72. tokenizer.Tokenizer.tokenize(
  73. src_str,
  74. src_dict,
  75. tokenize=tokenizer.tokenize_en,
  76. add_if_not_exist=False,
  77. bpe=bpe
  78. ).long()
  79. for src_str in lines
  80. ]
  81. lengths = np.array([t.numel() for t in tokens])
  82. itr = data.EpochBatchIterator(
  83. dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
  84. max_tokens=args.max_tokens,
  85. max_sentences=args.max_sentences,
  86. max_positions=max_positions,
  87. ).next_epoch_itr(shuffle=False)
  88. for batch in itr:
  89. yield Batch(
  90. srcs=[lines[i] for i in batch['id']],
  91. tokens=batch['net_input']['src_tokens'],
  92. lengths=batch['net_input']['src_lengths'],
  93. ), batch['id']
  94. def setup_logger(args):
  95. if not args.no_dllogger:
  96. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=1, filename=args.stat_file)])
  97. for k, v in vars(args).items():
  98. dllogger.log(step='PARAMETER', data={k:v}, verbosity=0)
  99. container_setup_info = log_helper.get_framework_env_vars()
  100. dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
  101. dllogger.metadata('throughput',
  102. {'unit':'tokens/s', 'format':':/3f', 'GOAL':'MAXIMIZE', 'STAGE':'INFER'})
  103. else:
  104. dllogger.init(backends=[])
  105. def main(args):
  106. setup_logger(args)
  107. args.interactive = sys.stdin.isatty() and not args.file # Just make the code more understendable
  108. if args.file:
  109. data_descriptor = open(args.file, 'r')
  110. else:
  111. data_descriptor = sys.stdin
  112. if args.interactive:
  113. args.buffer_size = 1
  114. if args.max_tokens is None and args.max_sentences is None:
  115. args.max_sentences = 1
  116. if args.buffer_size > 50000:
  117. print("WARNING: To prevent memory exhaustion buffer size is set to 50000", file=sys.stderr)
  118. args.buffer_size = 50000
  119. assert not args.sampling or args.nbest == args.beam, \
  120. '--sampling requires --nbest to be equal to --beam'
  121. assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
  122. '--max-sentences/--batch-size cannot be larger than --buffer-size'
  123. print(args, file=sys.stderr)
  124. use_cuda = torch.cuda.is_available() and not args.cpu
  125. torch.cuda.synchronize()
  126. processing_start = time.time()
  127. # Load ensemble
  128. print('| loading model(s) from {}'.format(args.path), file=sys.stderr)
  129. model_paths = args.path.split(':')
  130. models, model_args, src_dict, tgt_dict = load_ensemble_for_inference(model_paths)
  131. if args.fp16:
  132. for model in models:
  133. model.half()
  134. # Optimize ensemble for generation
  135. for model in models:
  136. model.make_generation_fast_(need_attn=args.print_alignment)
  137. # Initialize generator
  138. translator = SequenceGenerator(
  139. models,
  140. tgt_dict.get_metadata(),
  141. maxlen=args.max_target_positions,
  142. beam_size=args.beam,
  143. stop_early=(not args.no_early_stop),
  144. normalize_scores=(not args.unnormalized),
  145. len_penalty=args.lenpen,
  146. unk_penalty=args.unkpen,
  147. sampling=args.sampling,
  148. sampling_topk=args.sampling_topk,
  149. minlen=args.min_len,
  150. sampling_temperature=args.sampling_temperature
  151. )
  152. if use_cuda:
  153. translator.cuda()
  154. # Load BPE codes file
  155. if args.bpe_codes:
  156. codes = open(args.bpe_codes, 'r')
  157. bpe = BPE(codes)
  158. # Load alignment dictionary for unknown word replacement
  159. # (None if no unknown word replacement, empty if no path to align dictionary)
  160. align_dict = utils.load_align_dict(args.replace_unk)
  161. def make_result(src_str, hypos):
  162. result = Translation(
  163. src_str=src_str,
  164. hypos=[],
  165. pos_scores=[],
  166. alignments=[],
  167. )
  168. # Process top predictions
  169. for hypo in hypos[:min(len(hypos), args.nbest)]:
  170. hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
  171. hypo_tokens=hypo['tokens'].int().cpu(),
  172. src_str=src_str,
  173. alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
  174. align_dict=align_dict,
  175. tgt_dict=tgt_dict,
  176. remove_bpe=args.remove_bpe,
  177. )
  178. hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de').strip()
  179. result.hypos.append((hypo['score'], hypo_str))
  180. result.pos_scores.append('P\t' + ' '.join(f'{x:.4f}' for x in hypo['positional_scores'].tolist()))
  181. result.alignments.append('A\t' + ' '.join(str(utils.item(x)) for x in alignment)
  182. if args.print_alignment else None
  183. )
  184. return result
  185. gen_timer = StopwatchMeter()
  186. def process_batch(batch):
  187. tokens = batch.tokens
  188. lengths = batch.lengths
  189. if use_cuda:
  190. tokens = tokens.cuda()
  191. lengths = lengths.cuda()
  192. torch.cuda.synchronize()
  193. translation_start = time.time()
  194. gen_timer.start()
  195. translations = translator.generate(
  196. tokens,
  197. lengths,
  198. maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
  199. )
  200. gen_timer.stop(sum(len(h[0]['tokens']) for h in translations))
  201. torch.cuda.synchronize()
  202. dllogger.log(step='infer', data={'latency': time.time() - translation_start})
  203. return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
  204. if args.interactive:
  205. print('| Type the input sentence and press return:')
  206. for inputs in buffered_read(args.buffer_size, data_descriptor):
  207. indices = []
  208. results = []
  209. for batch, batch_indices in make_batches(inputs, args, src_dict, args.max_positions, bpe):
  210. indices.extend(batch_indices)
  211. results += process_batch(batch)
  212. for i in np.argsort(indices):
  213. result = results[i]
  214. print(result.src_str, file=sys.stderr)
  215. for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
  216. print(f'Score {hypo[0]}', file=sys.stderr)
  217. print(hypo[1])
  218. if align is not None:
  219. print(align, file=sys.stderr)
  220. if args.file:
  221. data_descriptor.close()
  222. torch.cuda.synchronize()
  223. log_dict = {
  224. 'throughput': 1./gen_timer.avg,
  225. 'latency_avg': sum(gen_timer.intervals)/len(gen_timer.intervals),
  226. 'latency_p90': gen_timer.p(90),
  227. 'latency_p95': gen_timer.p(95),
  228. 'latency_p99': gen_timer.p(99),
  229. 'total_infernece_time': gen_timer.sum,
  230. 'total_run_time': time.time() - processing_start,
  231. }
  232. print('Translation time: {} s'.format(log_dict['total_infernece_time']),
  233. file=sys.stderr)
  234. print('Model throughput (beam {}): {} tokens/s'.format(args.beam, log_dict['throughput']),
  235. file=sys.stderr)
  236. print('Latency:\n\tAverage {:.3f}s\n\tp90 {:.3f}s\n\tp95 {:.3f}s\n\tp99 {:.3f}s'.format(
  237. log_dict['latency_avg'], log_dict['latency_p90'], log_dict['latency_p95'], log_dict['latency_p99']),
  238. file=sys.stderr)
  239. print('End to end time: {} s'.format(log_dict['total_run_time']), file=sys.stderr)
  240. dllogger.log(step=(), data=log_dict)
  241. if __name__ == '__main__':
  242. parser = options.get_inference_parser()
  243. parser.add_argument('--no-dllogger', action='store_true')
  244. ARGS = options.parse_args_and_arch(parser)
  245. main(ARGS)