# inference.py
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import math
  16. import os
  17. import random
  18. import time
  19. from heapq import nlargest
  20. from itertools import chain, repeat
  21. from pathlib import Path
  22. from tqdm import tqdm
  23. import dllogger
  24. import torch
  25. import numpy as np
  26. import torch.distributed as distrib
  27. import torch.nn.functional as F
  28. from apex import amp
  29. from apex.parallel import DistributedDataParallel
  30. from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
  31. from jasper import config
  32. from common import helpers
  33. from common.dali.data_loader import DaliDataLoader
  34. from common.dataset import (AudioDataset, FilelistDataset, get_data_loader,
  35. SingleAudioDataset)
  36. from common.features import BaseFeatures, FilterbankFeatures
  37. from common.helpers import print_once, process_evaluation_epoch
  38. from jasper.model import GreedyCTCDecoder, Jasper
  39. from common.tb_dllogger import stdout_metric_format, unique_log_fpath
  40. def get_parser():
  41. parser = argparse.ArgumentParser(description='Jasper')
  42. parser.add_argument('--batch_size', default=16, type=int,
  43. help='Data batch size')
  44. parser.add_argument('--steps', default=0, type=int,
  45. help='Eval this many steps for every worker')
  46. parser.add_argument('--warmup_steps', default=0, type=int,
  47. help='Burn-in period before measuring latencies')
  48. parser.add_argument('--model_config', type=str, required=True,
  49. help='Relative model config path given dataset folder')
  50. parser.add_argument('--dataset_dir', type=str,
  51. help='Absolute path to dataset folder')
  52. parser.add_argument('--val_manifests', type=str, nargs='+',
  53. help='Relative path to evaluation dataset manifest files')
  54. parser.add_argument('--ckpt', default=None, type=str,
  55. help='Path to model checkpoint')
  56. parser.add_argument('--pad_leading', type=int, default=16,
  57. help='Pads every batch with leading zeros '
  58. 'to counteract conv shifts of the field of view')
  59. parser.add_argument('--amp', '--fp16', action='store_true',
  60. help='Use FP16 precision')
  61. parser.add_argument('--cudnn_benchmark', action='store_true',
  62. help='Enable cudnn benchmark')
  63. parser.add_argument('--cpu', action='store_true',
  64. help='Run inference on CPU')
  65. parser.add_argument("--seed", default=None, type=int, help='Random seed')
  66. parser.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0),
  67. type=int, help='GPU id used for distributed training')
  68. io = parser.add_argument_group('feature and checkpointing setup')
  69. io.add_argument('--dali_device', type=str, choices=['none', 'cpu', 'gpu'],
  70. default='gpu', help='Use DALI pipeline for fast data processing')
  71. io.add_argument('--save_predictions', type=str, default=None,
  72. help='Save predictions in text form at this location')
  73. io.add_argument('--save_logits', default=None, type=str,
  74. help='Save output logits under specified path')
  75. io.add_argument('--transcribe_wav', type=str,
  76. help='Path to a single .wav file (16KHz)')
  77. io.add_argument('--transcribe_filelist', type=str,
  78. help='Path to a filelist with one .wav path per line')
  79. io.add_argument('-o', '--output_dir', default='results/',
  80. help='Output folder to save audio (file per phrase)')
  81. io.add_argument('--log_file', type=str, default=None,
  82. help='Path to a DLLogger log file')
  83. io.add_argument('--ema', action='store_true',
  84. help='Load averaged model weights')
  85. io.add_argument('--torchscript', action='store_true',
  86. help='Evaluate with a TorchScripted model')
  87. io.add_argument('--torchscript_export', action='store_true',
  88. help='Export the model with torch.jit to the output_dir')
  89. io.add_argument('--override_config', type=str, action='append',
  90. help='Overrides a value from a config .yaml.'
  91. ' Syntax: `--override_config nested.config.key=val`.')
  92. return parser
  93. def durs_to_percentiles(durations, ratios):
  94. durations = np.asarray(durations) * 1000 # in ms
  95. latency = durations
  96. latency = latency[5:]
  97. mean_latency = np.mean(latency)
  98. latency_worst = nlargest(math.ceil((1 - min(ratios)) * len(latency)), latency)
  99. latency_ranges = get_percentile(ratios, latency_worst, len(latency))
  100. latency_ranges[0.5] = mean_latency
  101. return latency_ranges
  102. def get_percentile(ratios, arr, nsamples):
  103. res = {}
  104. for a in ratios:
  105. idx = max(int(nsamples * (1 - a)), 0)
  106. res[a] = arr[idx]
  107. return res
def torchscript_export(data_loader, audio_processor, model, greedy_decoder,
                       output_dir, use_amp, use_conv_masks, model_config, device,
                       save):
    """Export the featurizer, acoustic model and decoder via TorchScript.

    Pulls a single batch from ``data_loader`` to serve as example inputs,
    traces ``audio_processor`` and ``model`` with them, and scripts
    ``greedy_decoder``.  When ``save`` is true, the three modules are
    written to ``output_dir`` under a name derived from ``model_config``
    and the precision ("fp16" when ``use_amp`` else "fp32").

    Returns the triple ``(ts_feat, ts_acoustic, ts_decoder)``.

    NOTE(review): ``use_conv_masks`` is accepted but never read in this
    function body — confirm whether it is still needed by callers.
    """
    audio_processor.to(device)

    # Grab exactly one real batch to provide example inputs for tracing.
    for batch in data_loader:
        batch = [t.to(device, non_blocking=True) for t in batch]
        audio, audio_len, _, _ = batch
        feats, feat_lens = audio_processor(audio, audio_len)
        break

    print("\nExporting featurizer...")
    print("\nNOTE: Dithering causes warnings about non-determinism.\n")
    ts_feat = torch.jit.trace(audio_processor, (audio, audio_len))

    print("\nExporting acoustic model...")
    # Warm-up forward pass before tracing the acoustic model.
    model(feats, feat_lens)
    ts_acoustic = torch.jit.trace(model, (feats, feat_lens))

    print("\nExporting decoder...")
    log_probs = model(feats, feat_lens)
    # NOTE(review): torch.jit.script's second positional parameter is not an
    # example-input argument; confirm that passing log_probs here is intended.
    ts_decoder = torch.jit.script(greedy_decoder, log_probs)
    print("\nJIT export complete.")

    if save:
        precision = "fp16" if use_amp else "fp32"
        module_name = f'{os.path.basename(model_config)}_{precision}'
        ts_feat.save(os.path.join(output_dir, module_name + "_feat.pt"))
        ts_acoustic.save(os.path.join(output_dir, module_name + "_acoustic.pt"))
        ts_decoder.save(os.path.join(output_dir, module_name + "_decoder.pt"))

    return ts_feat, ts_acoustic, ts_decoder
  134. def main():
  135. parser = get_parser()
  136. args = parser.parse_args()
  137. log_fpath = args.log_file or str(Path(args.output_dir, 'nvlog_infer.json'))
  138. dllogger.init(backends=[
  139. JSONStreamBackend(Verbosity.DEFAULT, log_fpath, append=True),
  140. JSONStreamBackend(Verbosity.DEFAULT, unique_log_fpath(log_fpath)),
  141. StdOutBackend(Verbosity.VERBOSE, metric_format=stdout_metric_format)
  142. ])
  143. [dllogger.log("PARAMETER", {k: v}) for k, v in vars(args).items()]
  144. for step in ['DNN', 'data+DNN', 'data']:
  145. for c in [0.99, 0.95, 0.9, 0.5]:
  146. cs = 'avg' if c == 0.5 else f'{int(100*c)}%'
  147. dllogger.metadata(f'{step.lower()}_latency_{c}',
  148. {'name': f'{step} latency {cs}',
  149. 'format': ':>7.2f', 'unit': 'ms'})
  150. dllogger.metadata(
  151. 'eval_wer', {'name': 'WER', 'format': ':>3.2f', 'unit': '%'})
  152. if args.cpu:
  153. device = torch.device('cpu')
  154. else:
  155. assert torch.cuda.is_available()
  156. device = torch.device('cuda')
  157. torch.backends.cudnn.benchmark = args.cudnn_benchmark
  158. if args.seed is not None:
  159. torch.manual_seed(args.seed + args.local_rank)
  160. np.random.seed(args.seed + args.local_rank)
  161. random.seed(args.seed + args.local_rank)
  162. # set up distributed training
  163. multi_gpu = not args.cpu and int(os.environ.get('WORLD_SIZE', 1)) > 1
  164. if multi_gpu:
  165. torch.cuda.set_device(args.local_rank)
  166. distrib.init_process_group(backend='nccl', init_method='env://')
  167. print_once(f'Inference with {distrib.get_world_size()} GPUs')
  168. cfg = config.load(args.model_config)
  169. config.apply_config_overrides(cfg, args)
  170. symbols = helpers.add_ctc_blank(cfg['labels'])
  171. use_dali = args.dali_device in ('cpu', 'gpu')
  172. dataset_kw, features_kw = config.input(cfg, 'val')
  173. measure_perf = args.steps > 0
  174. # dataset
  175. if args.transcribe_wav or args.transcribe_filelist:
  176. if use_dali:
  177. print("DALI supported only with input .json files; disabling")
  178. use_dali = False
  179. assert not (args.transcribe_wav and args.transcribe_filelist)
  180. if args.transcribe_wav:
  181. dataset = SingleAudioDataset(args.transcribe_wav)
  182. else:
  183. dataset = FilelistDataset(args.transcribe_filelist)
  184. data_loader = get_data_loader(dataset,
  185. batch_size=1,
  186. multi_gpu=multi_gpu,
  187. shuffle=False,
  188. num_workers=0,
  189. drop_last=(True if measure_perf else False))
  190. _, features_kw = config.input(cfg, 'val')
  191. assert not features_kw['pad_to_max_duration']
  192. feat_proc = FilterbankFeatures(**features_kw)
  193. elif use_dali:
  194. # pad_to_max_duration is not supported by DALI - have simple padders
  195. if features_kw['pad_to_max_duration']:
  196. feat_proc = BaseFeatures(
  197. pad_align=features_kw['pad_align'],
  198. pad_to_max_duration=True,
  199. max_duration=features_kw['max_duration'],
  200. sample_rate=features_kw['sample_rate'],
  201. window_size=features_kw['window_size'],
  202. window_stride=features_kw['window_stride'])
  203. features_kw['pad_to_max_duration'] = False
  204. else:
  205. feat_proc = None
  206. data_loader = DaliDataLoader(
  207. gpu_id=args.local_rank or 0,
  208. dataset_path=args.dataset_dir,
  209. config_data=dataset_kw,
  210. config_features=features_kw,
  211. json_names=args.val_manifests,
  212. batch_size=args.batch_size,
  213. pipeline_type=("train" if measure_perf else "val"), # no drop_last
  214. device_type=args.dali_device,
  215. symbols=symbols)
  216. else:
  217. dataset = AudioDataset(args.dataset_dir,
  218. args.val_manifests,
  219. symbols,
  220. **dataset_kw)
  221. data_loader = get_data_loader(dataset,
  222. args.batch_size,
  223. multi_gpu=multi_gpu,
  224. shuffle=False,
  225. num_workers=4,
  226. drop_last=False)
  227. feat_proc = FilterbankFeatures(**features_kw)
  228. model = Jasper(encoder_kw=config.encoder(cfg),
  229. decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
  230. if args.ckpt is not None:
  231. print(f'Loading the model from {args.ckpt} ...')
  232. checkpoint = torch.load(args.ckpt, map_location="cpu")
  233. key = 'ema_state_dict' if args.ema else 'state_dict'
  234. state_dict = helpers.convert_v1_state_dict(checkpoint[key])
  235. model.load_state_dict(state_dict, strict=True)
  236. model.to(device)
  237. model.eval()
  238. if feat_proc is not None:
  239. feat_proc.to(device)
  240. feat_proc.eval()
  241. if args.amp:
  242. model = model.half()
  243. if args.torchscript:
  244. greedy_decoder = GreedyCTCDecoder()
  245. feat_proc, model, greedy_decoder = torchscript_export(
  246. data_loader, feat_proc, model, greedy_decoder, args.output_dir,
  247. use_amp=args.amp, use_conv_masks=True, model_toml=args.model_toml,
  248. device=device, save=args.torchscript_export)
  249. if multi_gpu:
  250. model = DistributedDataParallel(model)
  251. agg = {'txts': [], 'preds': [], 'logits': []}
  252. dur = {'data': [], 'dnn': [], 'data+dnn': []}
  253. looped_loader = chain.from_iterable(repeat(data_loader))
  254. greedy_decoder = GreedyCTCDecoder()
  255. sync = lambda: torch.cuda.synchronize() if device.type == 'cuda' else None
  256. steps = args.steps + args.warmup_steps or len(data_loader)
  257. with torch.no_grad():
  258. for it, batch in enumerate(tqdm(looped_loader, initial=1, total=steps)):
  259. if use_dali:
  260. feats, feat_lens, txt, txt_lens = batch
  261. if feat_proc is not None:
  262. feats, feat_lens = feat_proc(feats, feat_lens)
  263. else:
  264. batch = [t.to(device, non_blocking=True) for t in batch]
  265. audio, audio_lens, txt, txt_lens = batch
  266. feats, feat_lens = feat_proc(audio, audio_lens)
  267. sync()
  268. t1 = time.time()
  269. if args.amp:
  270. feats = feats.half()
  271. feats = F.pad(feats, (args.pad_leading, 0))
  272. feat_lens += args.pad_leading
  273. if model.encoder.use_conv_masks:
  274. log_probs, log_prob_lens = model(feats, feat_lens)
  275. else:
  276. log_probs = model(feats, feat_lens)
  277. preds = greedy_decoder(log_probs)
  278. sync()
  279. t2 = time.time()
  280. # burn-in period; wait for a new loader due to num_workers
  281. if it >= 1 and (args.steps == 0 or it >= args.warmup_steps):
  282. dur['data'].append(t1 - t0)
  283. dur['dnn'].append(t2 - t1)
  284. dur['data+dnn'].append(t2 - t0)
  285. if txt is not None:
  286. agg['txts'] += helpers.gather_transcripts([txt], [txt_lens],
  287. symbols)
  288. agg['preds'] += helpers.gather_predictions([preds], symbols)
  289. agg['logits'].append(log_probs)
  290. if it + 1 == steps:
  291. break
  292. sync()
  293. t0 = time.time()
  294. # communicate the results
  295. if args.transcribe_wav:
  296. for idx, p in enumerate(agg['preds']):
  297. print_once(f'Prediction {idx+1: >3}: {p}')
  298. elif args.transcribe_filelist:
  299. pass
  300. elif not multi_gpu or distrib.get_rank() == 0:
  301. wer, _ = process_evaluation_epoch(agg)
  302. dllogger.log(step=(), data={'eval_wer': 100 * wer})
  303. if args.save_predictions:
  304. with open(args.save_predictions, 'w') as f:
  305. f.write('\n'.join(agg['preds']))
  306. if args.save_logits:
  307. logits = torch.cat(agg['logits'], dim=0).cpu()
  308. torch.save(logits, args.save_logits)
  309. # report timings
  310. if len(dur['data']) >= 20:
  311. ratios = [0.9, 0.95, 0.99]
  312. for stage in dur:
  313. lat = durs_to_percentiles(dur[stage], ratios)
  314. for k in [0.99, 0.95, 0.9, 0.5]:
  315. kk = str(k).replace('.', '_')
  316. dllogger.log(step=(), data={f'{stage.lower()}_latency_{kk}': lat[k]})
  317. else:
  318. print_once('Not enough samples to measure latencies.')
# Script entry point: run inference only when executed directly.
if __name__ == "__main__":
    main()