# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from .audio import (AudioSegment, GainPerturbation, ShiftPerturbation,
                    SpeedPerturbation, audio_from_file)
from .text import _clean_text, punctuation_map
  23. def normalize_string(s, labels, punct_map):
  24. """Normalizes string.
  25. Example:
  26. 'call me at 8:00 pm!' -> 'call me at eight zero pm'
  27. """
  28. labels = set(labels)
  29. try:
  30. text = _clean_text(s, ["english_cleaners"], punct_map).strip()
  31. return ''.join([tok for tok in text if all(t in labels for t in tok)])
  32. except:
  33. print(f"WARNING: Normalizing failed: {s}")
  34. return None
  35. class FilelistDataset(Dataset):
  36. def __init__(self, filelist_fpath):
  37. self.samples = [line.strip() for line in open(filelist_fpath, 'r')]
  38. def __len__(self):
  39. return len(self.samples)
  40. def __getitem__(self, index):
  41. audio, audio_len = audio_from_file(self.samples[index])
  42. return (audio.squeeze(0), audio_len, torch.LongTensor([0]),
  43. torch.LongTensor([0]))
  44. class SingleAudioDataset(FilelistDataset):
  45. def __init__(self, audio_fpath):
  46. self.samples = [audio_fpath]
  47. class AudioDataset(Dataset):
  48. def __init__(self, data_dir, manifest_fpaths, labels,
  49. sample_rate=16000, min_duration=0.1, max_duration=float("inf"),
  50. pad_to_max_duration=False, max_utts=0, normalize_transcripts=True,
  51. sort_by_duration=False, trim_silence=False,
  52. speed_perturbation=None, gain_perturbation=None,
  53. shift_perturbation=None, ignore_offline_speed_perturbation=False):
  54. """Loads audio, transcript and durations listed in a .json file.
  55. Args:
  56. data_dir: absolute path to dataset folder
  57. manifest_filepath: relative path from dataset folder
  58. to manifest json as described above. Can be coma-separated paths.
  59. labels (str): all possible output symbols
  60. min_duration (int): skip audio shorter than threshold
  61. max_duration (int): skip audio longer than threshold
  62. pad_to_max_duration (bool): pad all sequences to max_duration
  63. max_utts (int): limit number of utterances
  64. normalize_transcripts (bool): normalize transcript text
  65. sort_by_duration (bool): sort sequences by increasing duration
  66. trim_silence (bool): trim leading and trailing silence from audio
  67. ignore_offline_speed_perturbation (bool): use precomputed speed perturbation
  68. Returns:
  69. tuple of Tensors
  70. """
  71. self.data_dir = data_dir
  72. self.labels = labels
  73. self.labels_map = dict([(labels[i], i) for i in range(len(labels))])
  74. self.punctuation_map = punctuation_map(labels)
  75. self.blank_index = len(labels)
  76. self.pad_to_max_duration = pad_to_max_duration
  77. self.sort_by_duration = sort_by_duration
  78. self.max_utts = max_utts
  79. self.normalize_transcripts = normalize_transcripts
  80. self.ignore_offline_speed_perturbation = ignore_offline_speed_perturbation
  81. self.min_duration = min_duration
  82. self.max_duration = max_duration
  83. self.trim_silence = trim_silence
  84. self.sample_rate = sample_rate
  85. perturbations = []
  86. if speed_perturbation is not None:
  87. perturbations.append(SpeedPerturbation(**speed_perturbation))
  88. if gain_perturbation is not None:
  89. perturbations.append(GainPerturbation(**gain_perturbation))
  90. if shift_perturbation is not None:
  91. perturbations.append(ShiftPerturbation(**shift_perturbation))
  92. self.perturbations = perturbations
  93. self.max_duration = max_duration
  94. self.samples = []
  95. self.duration = 0.0
  96. self.duration_filtered = 0.0
  97. for fpath in manifest_fpaths:
  98. self._load_json_manifest(fpath)
  99. if sort_by_duration:
  100. self.samples = sorted(self.samples, key=lambda s: s['duration'])
  101. def __getitem__(self, index):
  102. s = self.samples[index]
  103. rn_indx = np.random.randint(len(s['audio_filepath']))
  104. duration = s['audio_duration'][rn_indx] if 'audio_duration' in s else 0
  105. offset = s.get('offset', 0)
  106. segment = AudioSegment(
  107. s['audio_filepath'][rn_indx], target_sr=self.sample_rate,
  108. offset=offset, duration=duration, trim=self.trim_silence)
  109. for p in self.perturbations:
  110. p.maybe_apply(segment, self.sample_rate)
  111. segment = torch.FloatTensor(segment.samples)
  112. return (segment,
  113. torch.tensor(segment.shape[0]).int(),
  114. torch.tensor(s["transcript"]),
  115. torch.tensor(len(s["transcript"])).int())
  116. def __len__(self):
  117. return len(self.samples)
  118. def _load_json_manifest(self, fpath):
  119. for s in json.load(open(fpath, "r", encoding="utf-8")):
  120. if self.pad_to_max_duration and not self.ignore_offline_speed_perturbation:
  121. # require all perturbed samples to be < self.max_duration
  122. s_max_duration = max(f['duration'] for f in s['files'])
  123. else:
  124. # otherwise we allow perturbances to be > self.max_duration
  125. s_max_duration = s['original_duration']
  126. s['duration'] = s.pop('original_duration')
  127. if not (self.min_duration <= s_max_duration <= self.max_duration):
  128. self.duration_filtered += s['duration']
  129. continue
  130. # Prune and normalize according to transcript
  131. tr = (s.get('transcript', None) or
  132. self.load_transcript(s['text_filepath']))
  133. if not isinstance(tr, str):
  134. print(f'WARNING: Skipped sample (transcript not a str): {tr}.')
  135. self.duration_filtered += s['duration']
  136. continue
  137. if self.normalize_transcripts:
  138. tr = normalize_string(tr, self.labels, self.punctuation_map)
  139. s["transcript"] = self.to_vocab_inds(tr)
  140. files = s.pop('files')
  141. if self.ignore_offline_speed_perturbation:
  142. files = [f for f in files if f['speed'] == 1.0]
  143. s['audio_duration'] = [f['duration'] for f in files]
  144. s['audio_filepath'] = [str(Path(self.data_dir, f['fname']))
  145. for f in files]
  146. self.samples.append(s)
  147. self.duration += s['duration']
  148. if self.max_utts > 0 and len(self.samples) >= self.max_utts:
  149. print(f'Reached max_utts={self.max_utts}. Finished parsing {fpath}.')
  150. break
  151. def load_transcript(self, transcript_path):
  152. with open(transcript_path, 'r', encoding="utf-8") as transcript_file:
  153. transcript = transcript_file.read().replace('\n', '')
  154. return transcript
  155. def to_vocab_inds(self, transcript):
  156. chars = [self.labels_map.get(x, self.blank_index) for x in list(transcript)]
  157. transcript = list(filter(lambda x: x != self.blank_index, chars))
  158. return transcript
  159. def collate_fn(batch):
  160. bs = len(batch)
  161. max_len = lambda l, idx: max(el[idx].size(0) for el in l)
  162. audio = torch.zeros(bs, max_len(batch, 0))
  163. audio_lens = torch.zeros(bs, dtype=torch.int32)
  164. transcript = torch.zeros(bs, max_len(batch, 2))
  165. transcript_lens = torch.zeros(bs, dtype=torch.int32)
  166. for i, sample in enumerate(batch):
  167. audio[i].narrow(0, 0, sample[0].size(0)).copy_(sample[0])
  168. audio_lens[i] = sample[1]
  169. transcript[i].narrow(0, 0, sample[2].size(0)).copy_(sample[2])
  170. transcript_lens[i] = sample[3]
  171. return audio, audio_lens, transcript, transcript_lens
  172. def get_data_loader(dataset, batch_size, multi_gpu=True, shuffle=True,
  173. drop_last=True, num_workers=4):
  174. kw = {'dataset': dataset, 'collate_fn': collate_fn,
  175. 'num_workers': num_workers, 'pin_memory': True}
  176. if multi_gpu:
  177. loader_shuffle = False
  178. sampler = DistributedSampler(dataset, shuffle=shuffle)
  179. else:
  180. loader_shuffle = shuffle
  181. sampler = None
  182. return DataLoader(batch_size=batch_size, drop_last=drop_last,
  183. sampler=sampler, shuffle=loader_shuffle, **kw)