audio.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. import soundfile as sf
  16. import librosa
  17. import torch
  18. import numpy as np
  19. import sox
  20. def audio_from_file(file_path, offset=0, duration=0, trim=False, target_sr=16000):
  21. audio = AudioSegment(file_path, target_sr=target_sr, int_values=False,
  22. offset=offset, duration=duration, trim=trim)
  23. samples = torch.tensor(audio.samples, dtype=torch.float).cuda()
  24. num_samples = torch.tensor(samples.shape[0]).int().cuda()
  25. return (samples.unsqueeze(0), num_samples.unsqueeze(0))
  26. class AudioSegment(object):
  27. """Monaural audio segment abstraction.
  28. :param samples: Audio samples [num_samples x num_channels].
  29. :type samples: ndarray.float32
  30. :param sample_rate: Audio sample rate.
  31. :type sample_rate: int
  32. :raises TypeError: If the sample data type is not float or int.
  33. """
  34. def __init__(self, filename, target_sr=None, int_values=False, offset=0,
  35. duration=0, trim=False, trim_db=60):
  36. """Create audio segment from samples.
  37. Samples are converted to float32 internally, with int scaled to [-1, 1].
  38. Load a file supported by librosa and return as an AudioSegment.
  39. :param filename: path of file to load
  40. :param target_sr: the desired sample rate
  41. :param int_values: if true, load samples as 32-bit integers
  42. :param offset: offset in seconds when loading audio
  43. :param duration: duration in seconds when loading audio
  44. :return: numpy array of samples
  45. """
  46. with sf.SoundFile(filename, 'r') as f:
  47. dtype = 'int32' if int_values else 'float32'
  48. sample_rate = f.samplerate
  49. if offset > 0:
  50. f.seek(int(offset * sample_rate))
  51. if duration > 0:
  52. samples = f.read(int(duration * sample_rate), dtype=dtype)
  53. else:
  54. samples = f.read(dtype=dtype)
  55. samples = samples.transpose()
  56. samples = self._convert_samples_to_float32(samples)
  57. if target_sr is not None and target_sr != sample_rate:
  58. samples = librosa.resample(samples, orig_sr=sample_rate,
  59. target_sr=target_sr)
  60. sample_rate = target_sr
  61. if trim:
  62. samples, _ = librosa.effects.trim(samples, top_db=trim_db)
  63. self._samples = samples
  64. self._sample_rate = sample_rate
  65. if self._samples.ndim >= 2:
  66. self._samples = np.mean(self._samples, 0)
  67. def __eq__(self, other):
  68. """Return whether two objects are equal."""
  69. if type(other) is not type(self):
  70. return False
  71. if self._sample_rate != other._sample_rate:
  72. return False
  73. if self._samples.shape != other._samples.shape:
  74. return False
  75. if np.any(self.samples != other._samples):
  76. return False
  77. return True
  78. def __ne__(self, other):
  79. """Return whether two objects are unequal."""
  80. return not self.__eq__(other)
  81. def __str__(self):
  82. """Return human-readable representation of segment."""
  83. return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
  84. "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
  85. self.duration, self.rms_db))
  86. @staticmethod
  87. def _convert_samples_to_float32(samples):
  88. """Convert sample type to float32.
  89. Audio sample type is usually integer or float-point.
  90. Integers will be scaled to [-1, 1] in float32.
  91. """
  92. float32_samples = samples.astype('float32')
  93. if samples.dtype in np.sctypes['int']:
  94. bits = np.iinfo(samples.dtype).bits
  95. float32_samples *= (1. / 2 ** (bits - 1))
  96. elif samples.dtype in np.sctypes['float']:
  97. pass
  98. else:
  99. raise TypeError("Unsupported sample type: %s." % samples.dtype)
  100. return float32_samples
  101. @property
  102. def samples(self):
  103. return self._samples.copy()
  104. @property
  105. def sample_rate(self):
  106. return self._sample_rate
  107. @property
  108. def num_samples(self):
  109. return self._samples.shape[0]
  110. @property
  111. def duration(self):
  112. return self._samples.shape[0] / float(self._sample_rate)
  113. @property
  114. def rms_db(self):
  115. mean_square = np.mean(self._samples ** 2)
  116. return 10 * np.log10(mean_square)
  117. def gain_db(self, gain):
  118. self._samples *= 10. ** (gain / 20.)
  119. def pad(self, pad_size, symmetric=False):
  120. """Add zero padding to the sample.
  121. The pad size is given in number of samples. If symmetric=True,
  122. `pad_size` will be added to both sides. If false, `pad_size` zeros
  123. will be added only to the end.
  124. """
  125. self._samples = np.pad(self._samples,
  126. (pad_size if symmetric else 0, pad_size),
  127. mode='constant')
  128. def subsegment(self, start_time=None, end_time=None):
  129. """Cut the AudioSegment between given boundaries.
  130. Note that this is an in-place transformation.
  131. :param start_time: Beginning of subsegment in seconds.
  132. :type start_time: float
  133. :param end_time: End of subsegment in seconds.
  134. :type end_time: float
  135. :raise ValueError: If start_time or end_time is incorrectly set, e.g. out
  136. of bounds in time.
  137. """
  138. start_time = 0.0 if start_time is None else start_time
  139. end_time = self.duration if end_time is None else end_time
  140. if start_time < 0.0:
  141. start_time = self.duration + start_time
  142. if end_time < 0.0:
  143. end_time = self.duration + end_time
  144. if start_time < 0.0:
  145. raise ValueError("The slice start position (%f s) is out of "
  146. "bounds." % start_time)
  147. if end_time < 0.0:
  148. raise ValueError("The slice end position (%f s) is out of bounds." %
  149. end_time)
  150. if start_time > end_time:
  151. raise ValueError("The slice start position (%f s) is later than "
  152. "the end position (%f s)." % (start_time, end_time))
  153. if end_time > self.duration:
  154. raise ValueError("The slice end position (%f s) is out of bounds "
  155. "(> %f s)" % (end_time, self.duration))
  156. start_sample = int(round(start_time * self._sample_rate))
  157. end_sample = int(round(end_time * self._sample_rate))
  158. self._samples = self._samples[start_sample:end_sample]
  159. class Perturbation:
  160. def __init__(self, p=0.1, rng=None):
  161. self.p = p
  162. self._rng = random.Random() if rng is None else rng
  163. def maybe_apply(self, segment, sample_rate=None):
  164. if self._rng.random() < self.p:
  165. self(segment, sample_rate)
  166. class SpeedPerturbation(Perturbation):
  167. def __init__(self, min_rate=0.85, max_rate=1.15, discrete=False, p=0.1, rng=None):
  168. super(SpeedPerturbation, self).__init__(p, rng)
  169. assert 0 < min_rate < max_rate
  170. self.min_rate = min_rate
  171. self.max_rate = max_rate
  172. self.discrete = discrete
  173. def __call__(self, data, sample_rate):
  174. if self.discrete:
  175. rate = np.random.choice([self.min_rate, None, self.max_rate])
  176. else:
  177. rate = self._rng.uniform(self.min_rate, self.max_rate)
  178. if rate is not None:
  179. data._samples = sox.Transformer().speed(factor=rate).build_array(
  180. input_array=data._samples, sample_rate_in=sample_rate)
  181. class GainPerturbation(Perturbation):
  182. def __init__(self, min_gain_dbfs=-10, max_gain_dbfs=10, p=0.1, rng=None):
  183. super(GainPerturbation, self).__init__(p, rng)
  184. self._rng = random.Random() if rng is None else rng
  185. self._min_gain_dbfs = min_gain_dbfs
  186. self._max_gain_dbfs = max_gain_dbfs
  187. def __call__(self, data, sample_rate=None):
  188. del sample_rate
  189. gain = self._rng.uniform(self._min_gain_dbfs, self._max_gain_dbfs)
  190. data._samples = data._samples * (10. ** (gain / 20.))
  191. class ShiftPerturbation(Perturbation):
  192. def __init__(self, min_shift_ms=-5.0, max_shift_ms=5.0, p=0.1, rng=None):
  193. super(ShiftPerturbation, self).__init__(p, rng)
  194. self._min_shift_ms = min_shift_ms
  195. self._max_shift_ms = max_shift_ms
  196. def __call__(self, data, sample_rate):
  197. shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
  198. if abs(shift_ms) / 1000 > data.duration:
  199. # TODO: do something smarter than just ignore this condition
  200. return
  201. shift_samples = int(shift_ms * data.sample_rate // 1000)
  202. # print("DEBUG: shift:", shift_samples)
  203. if shift_samples < 0:
  204. data._samples[-shift_samples:] = data._samples[:shift_samples]
  205. data._samples[:-shift_samples] = 0
  206. elif shift_samples > 0:
  207. data._samples[:-shift_samples] = data._samples[shift_samples:]
  208. data._samples[-shift_samples:] = 0