  1. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import io
  15. import math
  16. import os
  17. import random
  18. import time
  19. import warnings
  20. from argparse import ArgumentParser
  21. from heapq import nlargest
  22. from itertools import chain, repeat
  23. from pathlib import Path
  24. from tqdm import tqdm
  25. import dllogger
  26. import numpy as np
  27. import torch
  28. import torch.distributed as distrib
  29. from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
  30. import wav2vec2.arg_parser
  31. import wav2vec2.utils
  32. import common.fairseq.utils as utils
  33. from common.fairseq.data import Dictionary
  34. from common.helpers import (gather_predictions, gather_transcripts,
  35. load_wrapped_state, process_evaluation_epoch)
  36. from common.tb_dllogger import stdout_metric_format, unique_log_fpath
  37. from common.utils import print_once
  38. from torch.utils.data import DataLoader, DistributedSampler
  39. from wav2vec2.logging import init_infer_metadata
  40. def durs_to_percentiles(durations, ratios):
  41. durations = np.asarray(durations) * 1000 # in ms
  42. latency = durations
  43. latency = latency[5:]
  44. mean_latency = np.mean(latency)
  45. latency_worst = nlargest(math.ceil((1 - min(ratios)) * len(latency)),
  46. latency)
  47. latency_ranges = get_percentile(ratios, latency_worst, len(latency))
  48. latency_ranges[0.5] = mean_latency
  49. return latency_ranges
  50. def get_percentile(ratios, arr, nsamples):
  51. res = {}
  52. for a in ratios:
  53. idx = max(int(nsamples * (1 - a)), 0)
  54. res[a] = arr[idx]
  55. return res
  56. def fp_convert_batch(batch, precision):
  57. dt = {'fp32': torch.float32, 'fp16': torch.half,
  58. 'bf16': torch.bfloat16}[precision]
  59. def maybe_cast(t):
  60. if t.dtype is torch.float32:
  61. return t.to(dtype=dt)
  62. return t
  63. return utils.apply_to_sample(maybe_cast, batch)
def main():
    """Run wav2vec 2.0 CTC inference: transcribe audio, compute WER, or
    benchmark latency, depending on the CLI arguments."""
    parser = ArgumentParser(description='wav2vec2.0 inference')
    wav2vec2.arg_parser.populate_infer(parser)
    args = parser.parse_args()

    # Load the checkpoint on CPU first; the model is moved to `device` later.
    ckpt = torch.load(args.w2v_path, map_location=torch.device("cpu"))
    train_args = wav2vec2.utils.get_ckpt_args(ckpt)
    is_nv_ckpt = "mode" in train_args

    if is_nv_ckpt:
        print("Loaded a model trained with NVIDIA DLE")
        args.fp32_pos_conv = train_args.get("fp32_pos_conv",
                                            args.fp16 or args.bf16)
        args.fp32_conv_norms = train_args.get("fp32_conv_norms", args.fp16)
    else:
        args.fp32_pos_conv = args.fp16
        args.fp32_conv_norms = args.fp16

    # NOTE(review): these unconditionally override the values derived above,
    # forcing fp32 positional conv and conv norms in every case — confirm
    # this is intended and not leftover debugging.
    args.fp32_pos_conv = True
    args.fp32_conv_norms = True

    log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
    dllogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath, append=True),
        JSONStreamBackend(Verbosity.DEFAULT, unique_log_fpath(log_fpath)),
        StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format)
    ])
    # Log every CLI argument as a separate PARAMETER record.
    [dllogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]

    init_infer_metadata()

    # Warn on precision mismatches between training and inference.
    if ((train_args.get("fp16", False) or train_args.get("amp", False))
            and args.bf16):
        warnings.warn('Using FP16 ckpts in BF16 precision.')

    if train_args.get("bf16", False) and args.fp16:
        warnings.warn('Using BF16 ckpts in FP16 precision.')

    # load output labels - either from a file, or stored inside an nv ckpt
    assert args.labels_path is not None or is_nv_ckpt
    if args.labels_path is None:
        f = io.StringIO(ckpt["output_labels"])
    else:
        f = open(args.labels_path)
    target_dictionary = Dictionary.load(f)
    f.close()

    w2v_path_for_args = args.w2v_path_for_args or args.w2v_path
    wav2vec2.utils.update_args_for_finetuning(args, w2v_path_for_args)

    # "default" GroupNorm might leak padding
    args.masked_feature_extractor = True

    if args.torchscript:
        from common.fairseq.modules import layer_norm
        layer_norm.TORCHSCRIPT = True

    model, *_ = wav2vec2.utils.build_model(args, "infer", target_dictionary)

    load_wrapped_state(model, ckpt["model"])

    # Patch the model for inference: drop conv weight-norm and route
    # forward() through the padding-aware / inference-only code paths.
    model.w2v_encoder.w2v_model.remove_conv_wn()
    model.w2v_encoder.w2v_model.feature_extractor.forward = \
        model.w2v_encoder.w2v_model.feature_extractor.masked_forward
    model.w2v_encoder.forward = model.w2v_encoder.infer
    model.w2v_encoder.w2v_model.forward = model.w2v_encoder.w2v_model.infer

    if args.cpu:
        device = torch.device('cpu')
    else:
        assert torch.cuda.is_available()
        device = torch.device('cuda')
        torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # Offset the seed by rank so each process draws different randomness.
    if args.seed is not None:
        torch.manual_seed(args.seed + args.local_rank)
        np.random.seed(args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

    # set up distributed training
    multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
    if multi_gpu:
        torch.cuda.set_device(args.local_rank)
        distrib.init_process_group(backend='nccl', init_method='env://')
        print_once(f'Inference with {distrib.get_world_size()} GPUs')

    # steps > 0 switches the script into latency-benchmark mode.
    measure_perf = args.steps > 0

    # Compliance with fairseq dataloader
    assert args.batch_size is not None
    args.min_sample_size = None
    args.max_sample_size = None

    if args.transcribe_wav or args.transcribe_filelist:
        # Transcription mode: single GPU, no labels, no perf measurement.
        assert args.max_duration is None and not measure_perf
        assert not (args.transcribe_wav and args.transcribe_filelist)
        assert args.labels is None, "Labels won't be used during trainscribing"
        assert not multi_gpu, (
            "multigpu is currently supported only for WER/perf measurements")

        if args.transcribe_wav:
            dataset = wav2vec2.utils.single_audio_dataset(args.transcribe_wav,
                                                          args)
        else:
            dataset = wav2vec2.utils.load_dataset(args.transcribe_filelist,
                                                  args, target_dictionary)

        data_loader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=dataset.collater,
            num_workers=args.num_workers,
            pin_memory=True,
            persistent_workers=args.num_workers > 0,
            drop_last=False,
        )

    else:  # compute WER or measure perf
        assert args.labels is not None or measure_perf

        dataset = wav2vec2.utils.load_dataset(args.valid_subset, args,
                                              target_dictionary,
                                              with_labels=True)

        sampler = DistributedSampler(
            dataset,
            shuffle=False,
            drop_last=False
        ) if multi_gpu else None

        data_loader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            shuffle=False,
            collate_fn=dataset.collater,
            num_workers=args.num_workers,
            pin_memory=True,
            persistent_workers=args.num_workers > 0,
            # Keep full batches only when benchmarking so timings are uniform.
            drop_last=(True if measure_perf else False),
        )

    model.to(device)
    model.eval()

    assert args.amp == args.fp16, 'During inference these are equivalent'
    if args.fp16:
        model = model.half()
    if args.bf16:
        model = model.to(dtype=torch.bfloat16)

    # Positional conv stays in fp32 even under reduced precision (see flags above).
    if (args.fp16 or args.bf16) and args.fp32_pos_conv:
        model.w2v_encoder.w2v_model.encoder.pos_conv.to(dtype=torch.float32)

    if args.torchscript:
        print("Attempting TorchScript export...")
        model = torch.jit.script(model)

    # Accumulators: transcripts/predictions/logits, and per-stage timings.
    agg = {'txts': [], 'preds': [], 'logits': [], 'ids': []}
    dur = {'data': [], 'dnn': [], 'data+dnn': []}

    # Cycle the loader forever; the loop below stops after `steps` iterations.
    looped_loader = chain.from_iterable(repeat(data_loader))

    sync = lambda: torch.cuda.synchronize() if device.type == 'cuda' else None
    # Binds as (steps + warmup_steps) or len(...): falls back to one epoch
    # when both step counts are 0.
    steps = args.steps + args.warmup_steps or len(data_loader)

    desc = 'warmup' if args.warmup_steps > 0 else 'inference'
    pbar = tqdm(looped_loader, initial=1, total=steps, desc=desc)
    for it, batch in enumerate(pbar):
        if it == args.warmup_steps:
            pbar.set_description('inference')

        batch = utils.move_to_cuda(batch)

        # t0 (end of previous iteration) .. t1: data-loading time;
        # t1 .. t2: model (dnn) time.
        sync()
        t1 = time.time()

        if args.fp16:
            batch = fp_convert_batch(batch, 'fp16')
        if args.bf16:
            batch = fp_convert_batch(batch, 'bf16')

        with torch.no_grad():
            enc_out, padding_mask = model(batch["net_input"]["source"],
                                          batch["net_input"]["padding_mask"])
            logp = model.get_normalized_probs(enc_out,
                                              padding_mask,
                                              log_probs=True).contiguous()
            # greedy decoding
            preds = logp.argmax(dim=-1, keepdim=False).int()

        sync()
        t2 = time.time()

        # burn-in period; wait for a new loader due to num_workers
        if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
            dur['data'].append(t1 - t0)
            dur['dnn'].append(t2 - t1)
            dur['data+dnn'].append(t2 - t0)

        # (time, batch) -> (batch, time) before decoding to strings.
        preds = preds.transpose(0, 1)

        agg['preds'] += gather_predictions([preds],
                                           target_dictionary,
                                           blank_id=0)
        agg['logits'].append(logp)

        if 'target' in batch:
            agg['txts'] += gather_transcripts([batch['target']],
                                              [batch['target_lengths']],
                                              target_dictionary)
        if multi_gpu:
            # ids are needed to remove duplicates in multi_gpu inference
            agg['ids'] += batch['id'].tolist()

        if it + 1 == steps:
            break

        sync()
        t0 = time.time()

    # Replace the word-boundary token (first non-special dictionary symbol)
    # with a space in both predictions and references.
    tdict = target_dictionary
    agg['preds'] = [pred.replace(tdict[tdict.nspecial], ' ')
                    for pred in agg['preds']]
    agg['txts'] = [txt.replace(tdict[tdict.nspecial], ' ')
                   for txt in agg['txts']]

    # communicate the results
    if args.transcribe_wav or args.transcribe_filelist:
        for idx, p in enumerate(agg['preds']):
            print_once(f'Prediction {idx + 1: >3}: {p}')

    elif args.valid_subset and not measure_perf:
        wer, _ = process_evaluation_epoch(agg)
        if not multi_gpu or distrib.get_rank() == 0:
            dllogger.log(step=(), data={'eval_wer': 100 * wer})

    if args.save_predictions and (not multi_gpu or distrib.get_rank() == 0):
        with open(args.save_predictions, 'w') as f:
            f.write('\n'.join(agg['preds']))

    if args.save_logits and (not multi_gpu or distrib.get_rank() == 0):
        logits = torch.cat(agg['logits'], dim=0).cpu()
        torch.save(logits, args.save_logits)

    # report timings
    if len(dur['data']) >= 20 and (not multi_gpu or distrib.get_rank() == 0):
        ratios = [0.9, 0.95, 0.99]
        for stage in dur:
            lat = durs_to_percentiles(dur[stage], ratios)
            for k in [0.99, 0.95, 0.9, 0.5]:
                k_ = str(k).replace('.', '_')
                dllogger.log(step=(), data={f'{stage}_latency_{k_}': lat[k]})
    else:
        print_once('Not enough samples to measure latencies.')
# Script entry point.
if __name__ == "__main__":
    main()