data_functions.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import torch
  28. from tacotron2.data_function import TextMelCollate
  29. from tacotron2.data_function import TextMelLoader
  30. from waveglow.data_function import MelAudioLoader
  31. from tacotron2.data_function import batch_to_gpu as batch_to_gpu_tacotron2
  32. from waveglow.data_function import batch_to_gpu as batch_to_gpu_waveglow
  33. def get_collate_function(model_name, n_frames_per_step=1):
  34. if model_name == 'Tacotron2':
  35. collate_fn = TextMelCollate(n_frames_per_step)
  36. elif model_name == 'WaveGlow':
  37. collate_fn = torch.utils.data.dataloader.default_collate
  38. else:
  39. raise NotImplementedError(
  40. "unknown collate function requested: {}".format(model_name))
  41. return collate_fn
  42. def get_data_loader(model_name, dataset_path, audiopaths_and_text, args):
  43. if model_name == 'Tacotron2':
  44. data_loader = TextMelLoader(dataset_path, audiopaths_and_text, args)
  45. elif model_name == 'WaveGlow':
  46. data_loader = MelAudioLoader(dataset_path, audiopaths_and_text, args)
  47. else:
  48. raise NotImplementedError(
  49. "unknown data loader requested: {}".format(model_name))
  50. return data_loader
  51. def get_batch_to_gpu(model_name):
  52. if model_name == 'Tacotron2':
  53. batch_to_gpu = batch_to_gpu_tacotron2
  54. elif model_name == 'WaveGlow':
  55. batch_to_gpu = batch_to_gpu_waveglow
  56. else:
  57. raise NotImplementedError(
  58. "unknown batch_to_gpu requested: {}".format(model_name))
  59. return batch_to_gpu