utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # MIT License
  15. #
  16. # Copyright (c) 2020 Jungil Kong
  17. #
  18. # Permission is hereby granted, free of charge, to any person obtaining a copy
  19. # of this software and associated documentation files (the "Software"), to deal
  20. # in the Software without restriction, including without limitation the rights
  21. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  22. # copies of the Software, and to permit persons to whom the Software is
  23. # furnished to do so, subject to the following conditions:
  24. #
  25. # The above copyright notice and this permission notice shall be included in all
  26. # copies or substantial portions of the Software.
  27. #
  28. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  29. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  30. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  31. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  32. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  33. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34. # SOFTWARE.
  35. # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
  36. # plot_spectrogram, init_weights, get_padding, AttrDict
  37. import ctypes
  38. import glob
  39. import os
  40. import re
  41. import shutil
  42. import warnings
  43. from collections import defaultdict, OrderedDict
  44. from pathlib import Path
  45. from typing import Optional
  46. import soundfile # flac
  47. import amp_C
  48. import matplotlib
  49. matplotlib.use("Agg")
  50. import matplotlib.pylab as plt
  51. import numpy as np
  52. import torch
  53. import torch.distributed as dist
  def mask_from_lens(lens, max_len: Optional[int] = None):
      """Build a boolean mask that marks the valid positions of each sequence.

      Args:
          lens: tensor of sequence lengths (expected to be 1-D; it is
              unsqueezed on dim 1 before the comparison).
          max_len: number of positions to generate; when None the largest
              value in ``lens`` is used.

      Returns:
          Boolean tensor whose entry ``[i, j]`` is True iff ``j < lens[i]``.
      """
      limit = lens.max() if max_len is None else max_len
      # Position indices share device and dtype with ``lens`` so that the
      # comparison below needs no implicit transfer or cast.
      positions = torch.arange(0, limit, device=lens.device, dtype=lens.dtype)
      return torch.lt(positions, lens.unsqueeze(1))
  def freeze(model):
      """Disable gradient tracking for every parameter of ``model``."""
      for param in model.parameters():
          param.requires_grad = False
  def unfreeze(model):
      """Enable gradient tracking for every parameter of ``model``."""
      for param in model.parameters():
          param.requires_grad = True
  def reduce_tensor(tensor, world_size):
      """Average ``tensor`` over all processes of the default process group.

      With a single process the input tensor itself is returned unchanged.
      Otherwise a detached copy is sum-reduced across ranks and divided by
      ``world_size``; the input tensor is not modified.
      """
      if world_size != 1:
          summed = tensor.detach().clone()
          dist.all_reduce(summed, op=dist.ReduceOp.SUM)
          return summed.true_divide(world_size)
      return tensor
  def adjust_fine_tuning_lr(args, ckpt_d):
      """Scale the learning rates stored in a checkpoint for fine-tuning.

      Multiplies the ``lr`` of every param group under ``ckpt_d['optim_d']``
      and ``ckpt_d['optim_g']`` by ``args.fine_tune_lr_factor``, in place, and
      prints each change. A factor of exactly 1 leaves the checkpoint alone.

      Args:
          args: object with ``fine_tuning`` (must be truthy) and
              ``fine_tune_lr_factor`` attributes.
          ckpt_d: checkpoint dict holding the two optimizer state dicts.
      """
      assert args.fine_tuning
      if args.fine_tune_lr_factor != 1.:
          for k in ('optim_d', 'optim_g'):
              groups = ckpt_d[k]['param_groups']
              for group in groups:
                  old_v = group['lr']
                  new_v = old_v * args.fine_tune_lr_factor
                  print(f'Init fine-tuning: changing {k} lr: {old_v} --> {new_v}')
                  group['lr'] = new_v
  def apply_ema_decay(model, ema_model, decay):
      """Update ``ema_model`` in place as an exponential moving average.

      Every entry of ``ema_model.state_dict()`` becomes
      ``decay * ema + (1 - decay) * current``. A falsy ``decay`` (``None``,
      ``0``) turns the call into a no-op.

      Args:
          model: source of the current values; may expose a ``module``
              attribute, in which case its state-dict keys are looked up
              with a ``module.`` prefix.
          ema_model: model whose state-dict tensors are overwritten.
          decay: EMA coefficient.
      """
      if decay:
          current = model.state_dict()
          # Only ``model`` is wrapped: EMA keys need the wrapper prefix to be
          # found in ``current``.
          needs_prefix = (hasattr(model, 'module')
                          and not hasattr(ema_model, 'module'))
          for name, ema_value in ema_model.state_dict().items():
              key = name
              if needs_prefix and not name.startswith('module.'):
                  key = 'module.' + name
              ema_value.copy_(decay * ema_value + (1 - decay) * current[key])
  def init_multi_tensor_ema(model, ema_model):
      """Collect the buffers needed by ``apply_multi_tensor_ema``.

      Returns:
          A tuple ``(model_weights, ema_model_weights, ema_overflow_buf)``:
          the state-dict tensors of both models as lists, plus a one-element
          CUDA int tensor initialised to zero. Requires a CUDA device.
      """
      ema_overflow_buf = torch.cuda.IntTensor([0])
      model_weights = [*model.state_dict().values()]
      ema_model_weights = [*ema_model.state_dict().values()]
      return model_weights, ema_model_weights, ema_overflow_buf
  def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
      """Run the fused ``amp_C.multi_tensor_axpby`` kernel for the EMA update.

      The kernel receives ``ema_weights`` as first input and as output and
      ``model_weights`` as second input, with coefficients ``decay`` and
      ``1 - decay``.
      NOTE(review): presumably this yields
      ``ema = decay * ema + (1 - decay) * model`` in place; confirm against
      the apex ``multi_tensor_axpby`` documentation.
      """
      chunk_size = 65536
      tensor_lists = [ema_weights, model_weights, ema_weights]
      amp_C.multi_tensor_axpby(
          chunk_size, overflow_buf, tensor_lists, decay, 1-decay, -1)
  def init_distributed(args, world_size, rank):
      """Bind this process to a GPU and join the default process group.

      Args:
          args: object with ``local_rank`` (used in log messages only) and
              ``cuda`` (selects the ``nccl`` backend when truthy, ``gloo``
              otherwise).
          world_size: accepted but not used by this function.
          rank: global rank; mapped onto a local device with a modulo over
              the visible device count.

      Raises:
          AssertionError: if CUDA is not available.
      """
      assert torch.cuda.is_available(), "Distributed mode requires CUDA."
      print(f"{args.local_rank}: Initializing distributed training")

      # Set cuda device so everything is done on the right GPU.
      device_index = rank % torch.cuda.device_count()
      torch.cuda.set_device(device_index)

      # Initialize distributed communication
      backend = 'nccl' if args.cuda else 'gloo'
      dist.init_process_group(backend=backend, init_method='env://')
      print(f"{args.local_rank}: Done initializing distributed training")
  def load_wav(full_path, torch_tensor=False):
      """Read an audio file as 16-bit integer samples via ``soundfile``.

      Args:
          full_path: path of the audio file.
          torch_tensor: when True, the samples are converted to float32 and
              returned as a ``torch.FloatTensor`` (values are not rescaled).

      Returns:
          A pair ``(samples, sampling_rate)``.
      """
      samples, sampling_rate = soundfile.read(full_path, dtype='int16')
      if not torch_tensor:
          return samples, sampling_rate
      as_float = samples.astype(np.float32)
      return torch.FloatTensor(as_float), sampling_rate
  def load_wav_to_torch(full_path, force_sampling_rate=None):
      """Load an audio file as a float tensor together with its sampling rate.

      Resampling is not supported: any ``force_sampling_rate`` other than
      None raises ``NotImplementedError``.
      """
      if force_sampling_rate is None:
          return load_wav(full_path, True)
      raise NotImplementedError
  def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
      """Parse file lists into tuples of resolved paths followed by text fields.

      Each line of each file in ``fnames`` is stripped and split on ``split``.
      A line with a single field is treated as one path. Otherwise the last
      field (or the last two when ``has_speakers`` is set) is kept verbatim
      and every preceding field is joined onto ``dataset_path``.

      Returns:
          A list with one tuple per input line, in file order.
      """
      trailing = 2 if has_speakers else 1
      entries = []
      for fname in fnames:
          with open(fname, encoding='utf-8') as listing:
              for line in listing:
                  fields = line.strip().split(split)
                  # A lone field carries no text part at all.
                  n_text = 0 if len(fields) == 1 else trailing
                  cut = len(fields) - n_text
                  resolved = tuple(
                      str(Path(dataset_path, p)) for p in fields[:cut])
                  entries.append(resolved + tuple(fields[cut:]))
      return entries
  def to_gpu(x):
      """Return a contiguous copy of ``x``, moved to the GPU when CUDA is available.

      The transfer is requested as non-blocking; without CUDA the contiguous
      tensor is returned on its current device.
      """
      contiguous = x.contiguous()
      if torch.cuda.is_available():
          return contiguous.cuda(non_blocking=True)
      return contiguous
  def l2_promote():
      """Set CUDA device limit ``0x05`` to 128 on the current device and verify it.

      Loads ``libcudart.so`` through ctypes, sets the limit, reads it back
      and asserts that the device reports the requested value.
      """
      cudart = ctypes.CDLL('libcudart.so')
      # Set device limit on the current device
      # cudaLimitMaxL2FetchGranularity = 0x05
      limit_id = 0x05
      granularity = 128
      readback = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
      cudart.cudaDeviceSetLimit(ctypes.c_int(limit_id), ctypes.c_int(granularity))
      cudart.cudaDeviceGetLimit(readback, ctypes.c_int(limit_id))
      assert readback.contents.value == granularity
  def prepare_tmp(path):
      """Create a fresh, empty directory at ``path``.

      An existing directory is removed first (with a warning; removal errors
      are ignored). The parent directory must already exist. Passing None
      does nothing.
      """
      if path is not None:
          p = Path(path)
          if p.is_dir():
              warnings.warn(f'{p} exists. Removing...')
              shutil.rmtree(p, ignore_errors=True)
          p.mkdir(parents=False, exist_ok=False)
  def print_once(*msg):
      """Print ``msg`` on rank 0 only, or always when torch.distributed is not initialized."""
      if dist.is_initialized() and dist.get_rank() != 0:
          return
      print(*msg)
  def plot_spectrogram(spectrogram):
      """Render ``spectrogram`` as an image with a colorbar and return the figure.

      The figure is drawn on its canvas and pyplot's current figure is closed
      before returning.
      """
      imshow_opts = dict(aspect="auto", origin="lower", interpolation='none')
      fig, ax = plt.subplots(figsize=(10, 2))
      image = ax.imshow(spectrogram, **imshow_opts)
      plt.colorbar(image, ax=ax)

      fig.canvas.draw()
      plt.close()
      return fig
  def init_weights(m, mean=0.0, std=0.01):
      """Re-initialise ``m.weight`` from a normal distribution for conv-like modules.

      A module is affected only if its class name contains ``"Conv"``; any
      other module is left untouched.
      """
      if "Conv" in m.__class__.__name__:
          m.weight.data.normal_(mean, std)
  def get_padding(kernel_size, dilation=1):
      """Return ``int((kernel_size*dilation - dilation) / 2)``, the padding for a dilated kernel."""
      span = kernel_size*dilation - dilation
      return int(span/2)
  class AttrDict(dict):
      """Dictionary whose items are also readable and writable as attributes."""

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # The instance doubles as its own attribute namespace.
          self.__dict__ = self
  class DefaultAttrDict(defaultdict):
      """``defaultdict`` whose items are also accessible as attributes.

      Reading an attribute that is not present goes through item access, so
      the default factory creates and stores the missing entry.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # The instance doubles as its own attribute namespace.
          self.__dict__ = self

      def __getattr__(self, item):
          # Reached only when normal attribute lookup fails.
          return self[item]
  class Checkpointer:
      """Save, restore and prune paired generator/discriminator checkpoints.

      Checkpoints are written as ``hifigan_gen_checkpoint_<epoch>.pt`` and
      ``hifigan_discrim_checkpoint_<epoch>.pt`` inside ``save_dir``.
      ``self.tracked`` maps an epoch to its ``(gen_path, discrim_path)`` pair,
      ordered by epoch, and drives both resuming and pruning.
      """

      _DEFAULT_MILESTONES = (1000, 2000, 3000, 4000, 5000, 6000)
      # Anchored at the end of the path so that digits in directory names can
      # never be mistaken for the epoch number.
      _EPOCH_RE = re.compile(r'_checkpoint_(\d+)\.pt$')

      def __init__(self, save_dir, keep_milestones=None):
          """Scan ``save_dir`` for checkpoint pairs that already exist.

          Args:
              save_dir: directory that holds (or will hold) the checkpoints.
              keep_milestones: epochs whose checkpoints are never pruned.
                  Defaults to 1000, 2000, ..., 6000.
          """
          self.save_dir = save_dir
          if keep_milestones is None:
              # Fresh list per instance; a list default would be shared.
              keep_milestones = list(self._DEFAULT_MILESTONES)
          self.keep_milestones = keep_milestones

          saved_g = self._find_saved(save_dir, 'hifigan_gen')
          saved_d = self._find_saved(save_dir, 'hifigan_discrim')

          # Only epochs for which both halves exist can be resumed from.
          common_epochs = sorted(set(saved_g) & set(saved_d))
          self.tracked = OrderedDict([(ep, (saved_g[ep], saved_d[ep]))
                                      for ep in common_epochs])

      @classmethod
      def _find_saved(cls, save_dir, name):
          """Map epoch -> path for ``<name>_checkpoint_<epoch>.pt`` files in ``save_dir``."""
          found = {}
          for fn in glob.glob(f'{save_dir}/{name}_checkpoint_*.pt'):
              match = cls._EPOCH_RE.search(fn)
              # The glob also accepts non-numeric suffixes; skip those.
              if match is not None:
                  found[int(match.group(1))] = fn
          return found

      @staticmethod
      def _unwrap(m):
          """Return ``m.module`` when ``m`` is a wrapper exposing it, else ``m``."""
          return getattr(m, 'module', m)

      def maybe_load(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
                     train_state, args, gen_ema=None, mpd_ema=None, msd_ema=None):
          """Restore training state from explicit paths or the latest checkpoints.

          If ``args.checkpoint_path_gen`` / ``args.checkpoint_path_discrim``
          are set, exactly that pair is loaded and tracking of earlier
          checkpoints is reset. Otherwise, with ``args.resume`` set, the two
          most recent tracked pairs are tried, newest first. If nothing is
          requested or nothing can be read, the method returns without
          touching any argument.

          ``train_state`` is updated in place; models, optimizers, grad
          scalers and optional EMA models receive their saved state dicts.

          Raises:
              AssertionError: if only one of the two explicit paths is given,
                  or if the two checkpoints report different epochs.
          """
          fpath_g = args.checkpoint_path_gen
          fpath_d = args.checkpoint_path_discrim

          assert (fpath_g is None) == (fpath_d is None)

          if fpath_g is not None:
              ckpt_paths = [(fpath_g, fpath_d)]
              self.tracked = OrderedDict()  # Do not track/delete prev ckpts
          elif args.resume:
              ckpt_paths = list(reversed(self.tracked.values()))[:2]
          else:
              return

          ckpt_g = None
          ckpt_d = None
          for fpath_g, fpath_d in ckpt_paths:
              if args.local_rank == 0:
                  print(f'Loading models from {fpath_g} {fpath_d}')
              try:
                  # NOTE: torch.load unpickles the file; only load checkpoints
                  # from trusted sources.
                  ckpt_g = torch.load(fpath_g, map_location='cpu')
                  ckpt_d = torch.load(fpath_d, map_location='cpu')
                  break
              except Exception:
                  # Unreadable or truncated file: fall back to the next pair.
                  # KeyboardInterrupt/SystemExit still propagate.
                  print(f'WARNING: Cannot load {fpath_g} and {fpath_d}')

          if ckpt_g is None or ckpt_d is None:
              return

          ep_g = ckpt_g.get('train_state', ckpt_g).get('epoch', None)
          ep_d = ckpt_d.get('train_state', ckpt_d).get('epoch', None)
          assert ep_g == ep_d, \
              f'Mismatched epochs of gen and discrim ({ep_g} != {ep_d})'

          train_state.update(ckpt_g['train_state'])

          fine_tune_epoch_start = train_state.get('fine_tune_epoch_start')
          if args.fine_tuning and fine_tune_epoch_start is None:
              # Fine-tuning just began
              train_state['fine_tune_epoch_start'] = train_state['epoch'] + 1
              train_state['fine_tune_lr_factor'] = args.fine_tune_lr_factor
              adjust_fine_tuning_lr(args, ckpt_d)

          # Pick the key lazily: a ``.get('gen', ckpt_g['generator'])`` would
          # evaluate the fallback first and fail on checkpoints that only
          # store 'gen'.
          gen_state = ckpt_g['gen'] if 'gen' in ckpt_g else ckpt_g['generator']
          self._unwrap(gen).load_state_dict(gen_state)
          self._unwrap(mpd).load_state_dict(ckpt_d['mpd'])
          self._unwrap(msd).load_state_dict(ckpt_d['msd'])

          # Both optimizer states live in the discriminator checkpoint.
          optim_g.load_state_dict(ckpt_d['optim_g'])
          optim_d.load_state_dict(ckpt_d['optim_d'])
          if 'scaler_g' in ckpt_d:
              scaler_g.load_state_dict(ckpt_d['scaler_g'])
              scaler_d.load_state_dict(ckpt_d['scaler_d'])
          else:
              warnings.warn('No grad scaler state found in the checkpoint.')

          if gen_ema is not None:
              gen_ema.load_state_dict(ckpt_g['gen_ema'])
          if mpd_ema is not None:
              mpd_ema.load_state_dict(ckpt_d['mpd_ema'])
          if msd_ema is not None:
              msd_ema.load_state_dict(ckpt_d['msd_ema'])

      def maybe_save(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
                     epoch, train_state, args, gen_config, train_setup,
                     gen_ema=None, mpd_ema=None, msd_ema=None):
          """Write a checkpoint pair for ``epoch`` when one is due, then prune.

          With torch.distributed initialized, all ranks synchronise on a
          barrier and only rank 0 continues. Nothing is saved for epoch 0.
          Before ``args.epochs`` is reached, a checkpoint is written only when
          ``args.checkpoint_interval`` is non-zero and divides ``epoch``; from
          ``args.epochs`` on it is always written.

          After saving, every tracked checkpoint except the two most recent
          and those listed in ``keep_milestones`` is deleted from disk.
          """
          rank = 0
          if dist.is_initialized():
              dist.barrier()
              rank = dist.get_rank()

          if rank != 0:
              return

          if epoch == 0:
              return

          if epoch < args.epochs and (args.checkpoint_interval == 0
                  or epoch % args.checkpoint_interval > 0):
              return

          fpath_g = Path(self.save_dir, f'hifigan_gen_checkpoint_{epoch}.pt')
          ckpt_g = {
              'generator': self._unwrap(gen).state_dict(),
              'gen_ema': gen_ema.state_dict() if gen_ema is not None else None,
              'config': gen_config,
              'train_setup': train_setup,
              'train_state': train_state,
          }

          fpath_d = Path(self.save_dir, f'hifigan_discrim_checkpoint_{epoch}.pt')
          ckpt_d = {
              'mpd': self._unwrap(mpd).state_dict(),
              'msd': self._unwrap(msd).state_dict(),
              'mpd_ema': mpd_ema.state_dict() if mpd_ema is not None else None,
              'msd_ema': msd_ema.state_dict() if msd_ema is not None else None,
              'optim_g': optim_g.state_dict(),
              'optim_d': optim_d.state_dict(),
              'scaler_g': scaler_g.state_dict(),
              'scaler_d': scaler_d.state_dict(),
              'train_state': train_state,
              # compat with original code
              'steps': train_state['iters_all'],
              'epoch': epoch,
          }

          print(f"Saving model and optimizer state to {fpath_g} and {fpath_d}")
          torch.save(ckpt_g, fpath_g)
          torch.save(ckpt_d, fpath_d)

          # Remove old checkpoints; keep milestones and the last two
          self.tracked[epoch] = (fpath_g, fpath_d)
          stale = set(list(self.tracked)[:-2]) - set(self.keep_milestones)
          for old_epoch in stale:
              removed = True
              for fpath in self.tracked[old_epoch]:
                  try:
                      os.remove(fpath)
                  except FileNotFoundError:
                      # Already gone; nothing left to clean up.
                      pass
                  except OSError as err:
                      # Best effort: keep the entry tracked and retry on the
                      # next save instead of aborting training.
                      removed = False
                      warnings.warn(
                          f'Could not remove old checkpoint {fpath}: {err}')
              if removed:
                  del self.tracked[old_epoch]