preprocessing_utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import multiprocessing
  17. import functools
  18. import sox
  19. from tqdm import tqdm
  20. def preprocess(data, input_dir, dest_dir, target_sr=None, speed=None,
  21. overwrite=True):
  22. speed = speed or []
  23. speed.append(1)
  24. speed = list(set(speed)) # Make uniqe
  25. input_fname = os.path.join(input_dir,
  26. data['input_relpath'],
  27. data['input_fname'])
  28. input_sr = sox.file_info.sample_rate(input_fname)
  29. target_sr = target_sr or input_sr
  30. os.makedirs(os.path.join(dest_dir, data['input_relpath']), exist_ok=True)
  31. output_dict = {}
  32. output_dict['transcript'] = data['transcript'].lower().strip()
  33. output_dict['files'] = []
  34. fname = os.path.splitext(data['input_fname'])[0]
  35. for s in speed:
  36. output_fname = fname + '{}.wav'.format('' if s == 1 else '-{}'.format(s))
  37. output_fpath = os.path.join(dest_dir,
  38. data['input_relpath'],
  39. output_fname)
  40. if not os.path.exists(output_fpath) or overwrite:
  41. cbn = sox.Transformer().speed(factor=s).convert(target_sr)
  42. cbn.build(input_fname, output_fpath)
  43. file_info = sox.file_info.info(output_fpath)
  44. file_info['fname'] = os.path.join(os.path.basename(dest_dir),
  45. data['input_relpath'],
  46. output_fname)
  47. file_info['speed'] = s
  48. output_dict['files'].append(file_info)
  49. if s == 1:
  50. file_info = sox.file_info.info(output_fpath)
  51. output_dict['original_duration'] = file_info['duration']
  52. output_dict['original_num_samples'] = file_info['num_samples']
  53. return output_dict
  54. def parallel_preprocess(dataset, input_dir, dest_dir, target_sr, speed,
  55. overwrite, parallel):
  56. with multiprocessing.Pool(parallel) as p:
  57. func = functools.partial(preprocess, input_dir=input_dir,
  58. dest_dir=dest_dir, target_sr=target_sr,
  59. speed=speed, overwrite=overwrite)
  60. dataset = list(tqdm(p.imap(func, dataset), total=len(dataset)))
  61. return dataset