utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # MIT License
  15. #
  16. # Copyright (c) 2020 Jungil Kong
  17. #
  18. # Permission is hereby granted, free of charge, to any person obtaining a copy
  19. # of this software and associated documentation files (the "Software"), to deal
  20. # in the Software without restriction, including without limitation the rights
  21. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  22. # copies of the Software, and to permit persons to whom the Software is
  23. # furnished to do so, subject to the following conditions:
  24. #
  25. # The above copyright notice and this permission notice shall be included in all
  26. # copies or substantial portions of the Software.
  27. #
  28. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  29. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  30. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  31. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  32. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  33. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34. # SOFTWARE.
  35. # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
  36. # init_weights, get_padding, AttrDict
  37. import ctypes
  38. import glob
  39. import os
  40. import re
  41. import shutil
  42. import warnings
  43. from collections import defaultdict, OrderedDict
  44. from pathlib import Path
  45. from typing import Optional
  46. import librosa
  47. import numpy as np
  48. import torch
  49. import torch.distributed as dist
  50. from scipy.io.wavfile import read
  51. def mask_from_lens(lens, max_len: Optional[int] = None):
  52. if max_len is None:
  53. max_len = lens.max()
  54. ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
  55. mask = torch.lt(ids, lens.unsqueeze(1))
  56. return mask
  57. def load_wav(full_path, torch_tensor=False):
  58. import soundfile # flac
  59. data, sampling_rate = soundfile.read(full_path, dtype='int16')
  60. if torch_tensor:
  61. return torch.FloatTensor(data.astype(np.float32)), sampling_rate
  62. else:
  63. return data, sampling_rate
  64. def load_wav_to_torch(full_path, force_sampling_rate=None):
  65. if force_sampling_rate is not None:
  66. data, sampling_rate = librosa.load(full_path, sr=force_sampling_rate)
  67. else:
  68. sampling_rate, data = read(full_path)
  69. return torch.FloatTensor(data.astype(np.float32)), sampling_rate
  70. def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
  71. def split_line(root, line):
  72. parts = line.strip().split(split)
  73. if has_speakers:
  74. paths, non_paths = parts[:-2], parts[-2:]
  75. else:
  76. paths, non_paths = parts[:-1], parts[-1:]
  77. return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
  78. fpaths_and_text = []
  79. for fname in fnames:
  80. with open(fname, encoding='utf-8') as f:
  81. fpaths_and_text += [split_line(dataset_path, line) for line in f]
  82. return fpaths_and_text
  83. def to_gpu(x):
  84. x = x.contiguous()
  85. return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
  86. def l2_promote():
  87. _libcudart = ctypes.CDLL('libcudart.so')
  88. # Set device limit on the current device
  89. # cudaLimitMaxL2FetchGranularity = 0x05
  90. pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
  91. _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
  92. _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
  93. assert pValue.contents.value == 128
  94. def prepare_tmp(path):
  95. if path is None:
  96. return
  97. p = Path(path)
  98. if p.is_dir():
  99. warnings.warn(f'{p} exists. Removing...')
  100. shutil.rmtree(p, ignore_errors=True)
  101. p.mkdir(parents=False, exist_ok=False)
  102. def print_once(*msg):
  103. if not dist.is_initialized() or dist.get_rank() == 0:
  104. print(*msg)
  105. def init_weights(m, mean=0.0, std=0.01):
  106. classname = m.__class__.__name__
  107. if classname.find("Conv") != -1:
  108. m.weight.data.normal_(mean, std)
  109. def get_padding(kernel_size, dilation=1):
  110. return int((kernel_size*dilation - dilation)/2)
  111. def load_pretrained_weights(model, ckpt_fpath):
  112. model = getattr(model, "module", model)
  113. weights = torch.load(ckpt_fpath, map_location="cpu")["state_dict"]
  114. weights = {re.sub("^module.", "", k): v for k, v in weights.items()}
  115. ckpt_emb = weights["encoder.word_emb.weight"]
  116. new_emb = model.state_dict()["encoder.word_emb.weight"]
  117. ckpt_vocab_size = ckpt_emb.size(0)
  118. new_vocab_size = new_emb.size(0)
  119. if ckpt_vocab_size != new_vocab_size:
  120. print("WARNING: Resuming from a checkpoint with a different size "
  121. "of embedding table. For best results, extend the vocab "
  122. "and ensure the common symbols' indices match.")
  123. min_len = min(ckpt_vocab_size, new_vocab_size)
  124. weights["encoder.word_emb.weight"] = ckpt_emb if ckpt_vocab_size > new_vocab_size else new_emb
  125. weights["encoder.word_emb.weight"][:min_len] = ckpt_emb[:min_len]
  126. model.load_state_dict(weights)
  127. class AttrDict(dict):
  128. def __init__(self, *args, **kwargs):
  129. super(AttrDict, self).__init__(*args, **kwargs)
  130. self.__dict__ = self
  131. class DefaultAttrDict(defaultdict):
  132. def __init__(self, *args, **kwargs):
  133. super(DefaultAttrDict, self).__init__(*args, **kwargs)
  134. self.__dict__ = self
  135. def __getattr__(self, item):
  136. return self[item]
  137. class BenchmarkStats:
  138. """ Tracks statistics used for benchmarking. """
  139. def __init__(self):
  140. self.num_frames = []
  141. self.losses = []
  142. self.mel_losses = []
  143. self.took = []
  144. def update(self, num_frames, losses, mel_losses, took):
  145. self.num_frames.append(num_frames)
  146. self.losses.append(losses)
  147. self.mel_losses.append(mel_losses)
  148. self.took.append(took)
  149. def get(self, n_epochs):
  150. frames_s = sum(self.num_frames[-n_epochs:]) / sum(self.took[-n_epochs:])
  151. return {'frames/s': frames_s,
  152. 'loss': np.mean(self.losses[-n_epochs:]),
  153. 'mel_loss': np.mean(self.mel_losses[-n_epochs:]),
  154. 'took': np.mean(self.took[-n_epochs:]),
  155. 'benchmark_epochs_num': n_epochs}
  156. def __len__(self):
  157. return len(self.losses)
  158. class Checkpointer:
  159. def __init__(self, save_dir, keep_milestones=[]):
  160. self.save_dir = save_dir
  161. self.keep_milestones = keep_milestones
  162. find = lambda name: [
  163. (int(re.search("_(\d+).pt", fn).group(1)), fn)
  164. for fn in glob.glob(f"{save_dir}/{name}_checkpoint_*.pt")]
  165. tracked = sorted(find("FastPitch"), key=lambda t: t[0])
  166. self.tracked = OrderedDict(tracked)
  167. def last_checkpoint(self, output):
  168. def corrupted(fpath):
  169. try:
  170. torch.load(fpath, map_location="cpu")
  171. return False
  172. except:
  173. warnings.warn(f"Cannot load {fpath}")
  174. return True
  175. saved = sorted(
  176. glob.glob(f"{output}/FastPitch_checkpoint_*.pt"),
  177. key=lambda f: int(re.search("_(\d+).pt", f).group(1)))
  178. if len(saved) >= 1 and not corrupted(saved[-1]):
  179. return saved[-1]
  180. elif len(saved) >= 2:
  181. return saved[-2]
  182. else:
  183. return None
  184. def maybe_load(self, model, optimizer, scaler, train_state, args,
  185. ema_model=None):
  186. assert args.checkpoint_path is None or args.resume is False, (
  187. "Specify a single checkpoint source")
  188. fpath = None
  189. if args.checkpoint_path is not None:
  190. fpath = args.checkpoint_path
  191. self.tracked = OrderedDict() # Do not track/delete prev ckpts
  192. elif args.resume:
  193. fpath = self.last_checkpoint(args.output)
  194. if fpath is None:
  195. return
  196. print_once(f"Loading model and optimizer state from {fpath}")
  197. ckpt = torch.load(fpath, map_location="cpu")
  198. train_state["epoch"] = ckpt["epoch"] + 1
  199. train_state["total_iter"] = ckpt["iteration"]
  200. no_pref = lambda sd: {re.sub("^module.", "", k): v for k, v in sd.items()}
  201. unwrap = lambda m: getattr(m, "module", m)
  202. unwrap(model).load_state_dict(no_pref(ckpt["state_dict"]))
  203. if ema_model is not None:
  204. unwrap(ema_model).load_state_dict(no_pref(ckpt["ema_state_dict"]))
  205. optimizer.load_state_dict(ckpt["optimizer"])
  206. if "scaler" in ckpt:
  207. scaler.load_state_dict(ckpt["scaler"])
  208. else:
  209. warnings.warn("AMP scaler state missing from the checkpoint.")
  210. def maybe_save(self, args, model, ema_model, optimizer, scaler, epoch,
  211. total_iter, config):
  212. intermediate = (args.epochs_per_checkpoint > 0
  213. and epoch % args.epochs_per_checkpoint == 0)
  214. final = epoch == args.epochs
  215. if not intermediate and not final and epoch not in self.keep_milestones:
  216. return
  217. rank = 0
  218. if dist.is_initialized():
  219. dist.barrier()
  220. rank = dist.get_rank()
  221. if rank != 0:
  222. return
  223. unwrap = lambda m: getattr(m, "module", m)
  224. ckpt = {"epoch": epoch,
  225. "iteration": total_iter,
  226. "config": config,
  227. "train_setup": args.__dict__,
  228. "state_dict": unwrap(model).state_dict(),
  229. "optimizer": optimizer.state_dict(),
  230. "scaler": scaler.state_dict()}
  231. if ema_model is not None:
  232. ckpt["ema_state_dict"] = unwrap(ema_model).state_dict()
  233. fpath = Path(args.output, f"FastPitch_checkpoint_{epoch}.pt")
  234. print(f"Saving model and optimizer state at epoch {epoch} to {fpath}")
  235. torch.save(ckpt, fpath)
  236. # Remove old checkpoints; keep milestones and the last two
  237. self.tracked[epoch] = fpath
  238. for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
  239. try:
  240. os.remove(self.tracked[epoch])
  241. except:
  242. pass
  243. del self.tracked[epoch]