preprocessing_utils.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #!/usr/bin/env python
  15. import os
  16. import multiprocessing
  17. import functools
  18. import sox
  19. from tqdm import tqdm
  def preprocess(data, input_dir, dest_dir, target_sr=None, speed=None,
                 overwrite=True):
      """Convert one utterance to .wav with sox, once per speed factor.

      Args:
          data: dict with keys 'input_relpath', 'input_fname' and 'transcript'
              describing a single source audio file.
          input_dir: root directory that ``data['input_relpath']`` is
              relative to.
          dest_dir: root directory the converted files are written under,
              mirroring ``data['input_relpath']``.
          target_sr: output sample rate; falls back to the input file's rate
              when falsy.
          speed: iterable of speed-perturbation factors. Factor 1 (the
              unperturbed copy) is always produced. The caller's object is
              never modified.
          overwrite: when False, an already existing output file is reused
              instead of being regenerated.

      Returns:
          dict with 'transcript' (lower-cased and stripped), 'files' (one
          ``sox.file_info.info`` dict per speed, extended with 'fname' and
          'speed'), and 'original_duration' / 'original_num_samples' read
          from the speed-1 output.
      """
      # Fresh, de-duplicated list in first-seen order. The previous code
      # appended 1 to the caller's list in place, so that list grew on every
      # call; it also failed outright when a tuple was passed.
      speeds = list(dict.fromkeys(list(speed or []) + [1]))

      input_fname = os.path.join(input_dir,
                                 data['input_relpath'],
                                 data['input_fname'])
      input_sr = sox.file_info.sample_rate(input_fname)
      target_sr = target_sr or input_sr

      out_dir = os.path.join(dest_dir, data['input_relpath'])
      os.makedirs(out_dir, exist_ok=True)

      # normpath strips a trailing separator. Without it, basename("x/") is
      # '' and the manifest paths would lose the destination dir name.
      dest_name = os.path.basename(os.path.normpath(dest_dir))

      output_dict = {
          'transcript': data['transcript'].lower().strip(),
          'files': [],
      }

      fname = os.path.splitext(data['input_fname'])[0]
      for s in speeds:
          # The unperturbed copy keeps the plain name; others get "-<factor>".
          output_fname = fname + ('' if s == 1 else '-{}'.format(s)) + '.wav'
          output_fpath = os.path.join(out_dir, output_fname)

          if overwrite or not os.path.exists(output_fpath):
              cbn = sox.Transformer().speed(factor=s).convert(target_sr)
              cbn.build(input_fname, output_fpath)

          file_info = sox.file_info.info(output_fpath)
          # Manifest path relative to the parent of dest_dir.
          file_info['fname'] = os.path.join(dest_name,
                                            data['input_relpath'],
                                            output_fname)
          file_info['speed'] = s
          output_dict['files'].append(file_info)

          if s == 1:
              # Reuse the info just read rather than invoking soxi again.
              output_dict['original_duration'] = file_info['duration']
              output_dict['original_num_samples'] = file_info['num_samples']

      return output_dict
  def parallel_preprocess(dataset, input_dir, dest_dir, target_sr, speed, overwrite, parallel):
      """Apply ``preprocess`` to every item of ``dataset`` using a process pool.

      Args:
          dataset: sized iterable (``len`` is taken) of per-utterance dicts
              accepted by ``preprocess``.
          input_dir, dest_dir, target_sr, speed, overwrite: forwarded
              unchanged, as keyword arguments, to ``preprocess`` for each item.
          parallel: worker-process count handed to ``multiprocessing.Pool``.

      Returns:
          list of ``preprocess`` results in the same order as ``dataset``.
          A tqdm progress bar is shown while they are collected.
      """
      worker = functools.partial(
          preprocess,
          input_dir=input_dir,
          dest_dir=dest_dir,
          target_sr=target_sr,
          speed=speed,
          overwrite=overwrite,
      )
      with multiprocessing.Pool(parallel) as pool:
          # Results must be drained inside the ``with``: leaving it
          # terminates the pool.
          results = pool.imap(worker, dataset)
          return list(tqdm(results, total=len(dataset)))