# generate_1h_10h_datasets.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. from itertools import chain
  16. from pathlib import Path
  17. def load_lines(fpath):
  18. with open(fpath) as f:
  19. return [line for line in f]
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument('ls_ft', type=Path,
  22. help='Libri-light librispeech_finetuning dir')
  23. parser.add_argument('ls_filelists', type=Path,
  24. help='Directory with .tsv .wrd etc files for LibriSpeech full 960')
  25. parser.add_argument('out', type=Path, help='Output directory')
  26. args = parser.parse_args()
  27. # Load LS
  28. tsv = load_lines(args.ls_filelists / "train-full-960.tsv")
  29. wrd = load_lines(args.ls_filelists / "train-full-960.wrd")
  30. ltr = load_lines(args.ls_filelists / "train-full-960.ltr")
  31. assert len(tsv) == len(wrd) + 1
  32. assert len(ltr) == len(wrd)
  33. files = {}
  34. for path_frames, w, l in zip(tsv[1:], wrd, ltr):
  35. path, _ = path_frames.split("\t")
  36. key = Path(path).stem
  37. files[key] = (path_frames, w, l)
  38. print(f"Loaded {len(files)} entries from {args.ls_filelists}/train-full-960")
  39. # Load LL-LS
  40. files_1h = list((args.ls_ft / "1h").rglob("*.flac"))
  41. files_9h = list((args.ls_ft / "9h").rglob("*.flac"))
  42. print(f"Found {len(files_1h)} files in the 1h dataset")
  43. print(f"Found {len(files_9h)} files in the 9h dataset")
  44. for name, file_iter in [("train-1h", files_1h),
  45. ("train-10h", chain(files_1h, files_9h))]:
  46. with open(args.out / f"{name}.tsv", "w") as ftsv, \
  47. open(args.out / f"{name}.wrd", "w") as fwrd, \
  48. open(args.out / f"{name}.ltr", "w") as fltr:
  49. nframes = 0
  50. ftsv.write(tsv[0])
  51. for fpath in file_iter:
  52. key = fpath.stem
  53. t, w, l = files[key]
  54. ftsv.write(t)
  55. fwrd.write(w)
  56. fltr.write(l)
  57. nframes += int(t.split()[1])
  58. print(f"Written {nframes} frames ({nframes / 16000 / 60 / 60:.2f} h at 16kHz)")