# utils.py
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # MIT License
  15. #
  16. # Copyright (c) 2020 Jungil Kong
  17. #
  18. # Permission is hereby granted, free of charge, to any person obtaining a copy
  19. # of this software and associated documentation files (the "Software"), to deal
  20. # in the Software without restriction, including without limitation the rights
  21. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  22. # copies of the Software, and to permit persons to whom the Software is
  23. # furnished to do so, subject to the following conditions:
  24. #
  25. # The above copyright notice and this permission notice shall be included in all
  26. # copies or substantial portions of the Software.
  27. #
  28. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  29. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  30. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  31. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  32. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  33. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34. # SOFTWARE.
  35. # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
  36. # plot_spectrogram, init_weights, get_padding, AttrDict
  37. import ctypes
  38. import glob
  39. import os
  40. import re
  41. import shutil
  42. import warnings
  43. from collections import defaultdict, OrderedDict
  44. from pathlib import Path
  45. from typing import Optional
  46. import soundfile # flac
  47. import matplotlib
  48. import numpy as np
  49. import torch
  50. import torch.distributed as dist
  51. def mask_from_lens(lens, max_len: Optional[int] = None):
  52. if max_len is None:
  53. max_len = lens.max()
  54. ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
  55. mask = torch.lt(ids, lens.unsqueeze(1))
  56. return mask
  57. def freeze(model):
  58. for p in model.parameters():
  59. p.requires_grad = False
  60. def unfreeze(model):
  61. for p in model.parameters():
  62. p.requires_grad = True
  63. def reduce_tensor(tensor, world_size):
  64. if world_size == 1:
  65. return tensor
  66. rt = tensor.detach().clone()
  67. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  68. return rt.true_divide(world_size)
  69. def adjust_fine_tuning_lr(args, ckpt_d):
  70. assert args.fine_tuning
  71. if args.fine_tune_lr_factor == 1.:
  72. return
  73. for k in ['optim_d', 'optim_g']:
  74. for param_group in ckpt_d[k]['param_groups']:
  75. old_v = param_group['lr']
  76. new_v = old_v * args.fine_tune_lr_factor
  77. print(f'Init fine-tuning: changing {k} lr: {old_v} --> {new_v}')
  78. param_group['lr'] = new_v
  79. def init_distributed(args, world_size, rank):
  80. assert torch.cuda.is_available(), "Distributed mode requires CUDA."
  81. print(f"{args.local_rank}: Initializing distributed training")
  82. # Set cuda device so everything is done on the right GPU.
  83. torch.cuda.set_device(rank % torch.cuda.device_count())
  84. # Initialize distributed communication
  85. dist.init_process_group(backend=('nccl' if args.cuda else 'gloo'),
  86. init_method='env://')
  87. print(f"{args.local_rank}: Done initializing distributed training")
  88. def load_wav(full_path, torch_tensor=False):
  89. data, sampling_rate = soundfile.read(full_path, dtype='int16')
  90. if torch_tensor:
  91. return torch.FloatTensor(data.astype(np.float32)), sampling_rate
  92. else:
  93. return data, sampling_rate
  94. def load_wav_to_torch(full_path, force_sampling_rate=None):
  95. if force_sampling_rate is not None:
  96. raise NotImplementedError
  97. return load_wav(full_path, True)
  98. def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
  99. def split_line(root, line):
  100. parts = line.strip().split(split)
  101. if len(parts) == 1:
  102. paths, non_paths = parts, []
  103. else:
  104. if has_speakers:
  105. paths, non_paths = parts[:-2], parts[-2:]
  106. else:
  107. paths, non_paths = parts[:-1], parts[-1:]
  108. return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)
  109. fpaths_and_text = []
  110. for fname in fnames:
  111. with open(fname, encoding='utf-8') as f:
  112. fpaths_and_text += [split_line(dataset_path, line) for line in f]
  113. return fpaths_and_text
  114. def to_gpu(x):
  115. x = x.contiguous()
  116. return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
  117. def l2_promote():
  118. _libcudart = ctypes.CDLL('libcudart.so')
  119. # Set device limit on the current device
  120. # cudaLimitMaxL2FetchGranularity = 0x05
  121. pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
  122. _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
  123. _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
  124. assert pValue.contents.value == 128
  125. def prepare_tmp(path):
  126. if path is None:
  127. return
  128. p = Path(path)
  129. if p.is_dir():
  130. warnings.warn(f'{p} exists. Removing...')
  131. shutil.rmtree(p, ignore_errors=True)
  132. p.mkdir(parents=False, exist_ok=False)
  133. def print_once(*msg):
  134. if not dist.is_initialized() or dist.get_rank() == 0:
  135. print(*msg)
  136. def plot_spectrogram(spectrogram):
  137. matplotlib.use("Agg")
  138. import matplotlib.pylab as plt
  139. fig, ax = plt.subplots(figsize=(10, 2))
  140. im = ax.imshow(spectrogram, aspect="auto", origin="lower",
  141. interpolation='none')
  142. plt.colorbar(im, ax=ax)
  143. fig.canvas.draw()
  144. plt.close()
  145. return fig
  146. def init_weights(m, mean=0.0, std=0.01):
  147. classname = m.__class__.__name__
  148. if classname.find("Conv") != -1:
  149. m.weight.data.normal_(mean, std)
  150. def get_padding(kernel_size, dilation=1):
  151. return int((kernel_size*dilation - dilation)/2)
  152. class AttrDict(dict):
  153. def __init__(self, *args, **kwargs):
  154. super(AttrDict, self).__init__(*args, **kwargs)
  155. self.__dict__ = self
  156. class DefaultAttrDict(defaultdict):
  157. def __init__(self, *args, **kwargs):
  158. super(DefaultAttrDict, self).__init__(*args, **kwargs)
  159. self.__dict__ = self
  160. def __getattr__(self, item):
  161. return self[item]
  162. class Checkpointer:
  163. def __init__(self, save_dir,
  164. keep_milestones=[1000, 2000, 3000, 4000, 5000, 6000]):
  165. self.save_dir = save_dir
  166. self.keep_milestones = keep_milestones
  167. find = lambda name: {int(re.search('_(\d+).pt', fn).group(1)): fn
  168. for fn in glob.glob(f'{save_dir}/{name}_checkpoint_*.pt')}
  169. saved_g = find('hifigan_gen')
  170. saved_d = find('hifigan_discrim')
  171. common_epochs = sorted(set(saved_g.keys()) & set(saved_d.keys()))
  172. self.tracked = OrderedDict([(ep, (saved_g[ep], saved_d[ep]))
  173. for ep in common_epochs])
  174. def maybe_load(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
  175. train_state, args, gen_ema=None, mpd_ema=None, msd_ema=None):
  176. fpath_g = args.checkpoint_path_gen
  177. fpath_d = args.checkpoint_path_discrim
  178. assert (fpath_g is None) == (fpath_d is None)
  179. if fpath_g is not None:
  180. ckpt_paths = [(fpath_g, fpath_d)]
  181. self.tracked = OrderedDict() # Do not track/delete prev ckpts
  182. elif args.resume:
  183. ckpt_paths = list(reversed(self.tracked.values()))[:2]
  184. else:
  185. return
  186. ckpt_g = None
  187. ckpt_d = None
  188. for fpath_g, fpath_d in ckpt_paths:
  189. if args.local_rank == 0:
  190. print(f'Loading models from {fpath_g} {fpath_d}')
  191. try:
  192. ckpt_g = torch.load(fpath_g, map_location='cpu')
  193. ckpt_d = torch.load(fpath_d, map_location='cpu')
  194. break
  195. except:
  196. print(f'WARNING: Cannot load {fpath_g} and {fpath_d}')
  197. if ckpt_g is None or ckpt_d is None:
  198. return
  199. ep_g = ckpt_g.get('train_state', ckpt_g).get('epoch', None)
  200. ep_d = ckpt_d.get('train_state', ckpt_d).get('epoch', None)
  201. assert ep_g == ep_d, \
  202. f'Mismatched epochs of gen and discrim ({ep_g} != {ep_d})'
  203. train_state.update(ckpt_g['train_state'])
  204. fine_tune_epoch_start = train_state.get('fine_tune_epoch_start')
  205. if args.fine_tuning and fine_tune_epoch_start is None:
  206. # Fine-tuning just began
  207. train_state['fine_tune_epoch_start'] = train_state['epoch'] + 1
  208. train_state['fine_tune_lr_factor'] = args.fine_tune_lr_factor
  209. adjust_fine_tuning_lr(args, ckpt_d)
  210. unwrap = lambda m: getattr(m, 'module', m)
  211. unwrap(gen).load_state_dict(ckpt_g.get('gen', ckpt_g['generator']))
  212. unwrap(mpd).load_state_dict(ckpt_d['mpd'])
  213. unwrap(msd).load_state_dict(ckpt_d['msd'])
  214. optim_g.load_state_dict(ckpt_d['optim_g'])
  215. optim_d.load_state_dict(ckpt_d['optim_d'])
  216. if 'scaler_g' in ckpt_d:
  217. scaler_g.load_state_dict(ckpt_d['scaler_g'])
  218. scaler_d.load_state_dict(ckpt_d['scaler_d'])
  219. else:
  220. warnings.warn('No grad scaler state found in the checkpoint.')
  221. if gen_ema is not None:
  222. gen_ema.load_state_dict(ckpt_g['gen_ema'])
  223. if mpd_ema is not None:
  224. mpd_ema.load_state_dict(ckpt_d['mpd_ema'])
  225. if msd_ema is not None:
  226. msd_ema.load_state_dict(ckpt_d['msd_ema'])
  227. def maybe_save(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
  228. epoch, train_state, args, gen_config, train_setup,
  229. gen_ema=None, mpd_ema=None, msd_ema=None):
  230. rank = 0
  231. if dist.is_initialized():
  232. dist.barrier()
  233. rank = dist.get_rank()
  234. if rank != 0:
  235. return
  236. if epoch == 0:
  237. return
  238. if epoch < args.epochs and (args.checkpoint_interval == 0
  239. or epoch % args.checkpoint_interval > 0):
  240. return
  241. unwrap = lambda m: getattr(m, 'module', m)
  242. fpath_g = Path(self.save_dir, f'hifigan_gen_checkpoint_{epoch}.pt')
  243. ckpt_g = {
  244. 'generator': unwrap(gen).state_dict(),
  245. 'gen_ema': gen_ema.state_dict() if gen_ema is not None else None,
  246. 'config': gen_config,
  247. 'train_setup': train_setup,
  248. 'train_state': train_state,
  249. }
  250. fpath_d = Path(self.save_dir, f'hifigan_discrim_checkpoint_{epoch}.pt')
  251. ckpt_d = {
  252. 'mpd': unwrap(mpd).state_dict(),
  253. 'msd': unwrap(msd).state_dict(),
  254. 'mpd_ema': mpd_ema.state_dict() if mpd_ema is not None else None,
  255. 'msd_ema': msd_ema.state_dict() if msd_ema is not None else None,
  256. 'optim_g': optim_g.state_dict(),
  257. 'optim_d': optim_d.state_dict(),
  258. 'scaler_g': scaler_g.state_dict(),
  259. 'scaler_d': scaler_d.state_dict(),
  260. 'train_state': train_state,
  261. # compat with original code
  262. 'steps': train_state['iters_all'],
  263. 'epoch': epoch,
  264. }
  265. print(f"Saving model and optimizer state to {fpath_g} and {fpath_d}")
  266. torch.save(ckpt_g, fpath_g)
  267. torch.save(ckpt_d, fpath_d)
  268. # Remove old checkpoints; keep milestones and the last two
  269. self.tracked[epoch] = (fpath_g, fpath_d)
  270. for epoch in set(list(self.tracked)[:-2]) - set(self.keep_milestones):
  271. try:
  272. os.remove(self.tracked[epoch][0])
  273. os.remove(self.tracked[epoch][1])
  274. del self.tracked[epoch]
  275. except:
  276. pass