| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374 |
- # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # MIT License
- #
- # Copyright (c) 2020 Jungil Kong
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
- # The following functions/classes were based on code from https://github.com/jik876/hifi-gan:
- # plot_spectrogram, init_weights, get_padding, AttrDict
- import ctypes
- import glob
- import os
- import re
- import shutil
- import warnings
- from collections import defaultdict, OrderedDict
- from pathlib import Path
- from typing import Optional
- import soundfile # flac
- import amp_C
- import matplotlib
- matplotlib.use("Agg")
- import matplotlib.pylab as plt
- import numpy as np
- import torch
- import torch.distributed as dist
def mask_from_lens(lens, max_len: Optional[int] = None):
    """Build a boolean mask that is True for positions below each length.

    Args:
        lens: tensor of lengths; a trailing axis is added to it so that it
            broadcasts against the position indices.
        max_len: number of positions to generate. Defaults to ``lens.max()``.

    Returns:
        Boolean tensor where entry ``[i, j]`` is ``j < lens[i]``.
    """
    limit = lens.max() if max_len is None else max_len
    # Positions live on the same device and use the same dtype as `lens`.
    positions = torch.arange(0, limit, device=lens.device, dtype=lens.dtype)
    return positions < lens.unsqueeze(1)
def freeze(model):
    """Disable gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(False)
def unfreeze(model):
    """Enable gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(True)
def reduce_tensor(tensor, world_size):
    """Average ``tensor`` across the process group.

    With ``world_size == 1`` the input tensor itself is returned untouched.
    Otherwise a detached copy is sum-reduced over all ranks and divided by
    ``world_size``.
    """
    if world_size != 1:
        summed = tensor.detach().clone()
        dist.all_reduce(summed, op=dist.ReduceOp.SUM)
        return summed.true_divide(world_size)
    return tensor
def adjust_fine_tuning_lr(args, ckpt_d):
    """Scale the learning rates stored in a checkpoint, in place.

    Multiplies ``lr`` of every param group under ``ckpt_d['optim_d']`` and
    ``ckpt_d['optim_g']`` by ``args.fine_tune_lr_factor`` and prints each
    change. Does nothing when the factor equals 1.

    Args:
        args: namespace providing ``fine_tuning`` (must be truthy) and
            ``fine_tune_lr_factor``.
        ckpt_d: checkpoint dict holding the two optimizer state dicts.
    """
    assert args.fine_tuning
    factor = args.fine_tune_lr_factor
    if factor == 1.:
        return
    for optim_key in ('optim_d', 'optim_g'):
        for group in ckpt_d[optim_key]['param_groups']:
            old_v = group['lr']
            new_v = old_v * factor
            print(f'Init fine-tuning: changing {optim_key} lr: {old_v} --> {new_v}')
            group['lr'] = new_v
def apply_ema_decay(model, ema_model, decay):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Every entry of ``ema_model.state_dict()`` becomes
    ``decay * ema + (1 - decay) * model``. A falsy ``decay`` (``0``/``None``)
    makes the call a no-op.

    When ``model`` has a ``module`` attribute and ``ema_model`` does not, the
    EMA keys are looked up in ``model`` under a ``'module.'`` prefix.
    """
    if not decay:
        return
    source = model.state_dict()
    needs_prefix = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for name, ema_value in ema_model.state_dict().items():
        key = name
        if needs_prefix and not name.startswith('module.'):
            key = 'module.' + name
        ema_value.copy_(decay * ema_value + (1 - decay) * source[key])
def init_multi_tensor_ema(model, ema_model):
    """Collect the tensors needed by ``apply_multi_tensor_ema``.

    Returns:
        A tuple ``(model_weights, ema_model_weights, ema_overflow_buf)``: the
        state-dict values of both models as lists, plus a one-element CUDA
        int tensor initialised to 0.
    """
    overflow_buf = torch.cuda.IntTensor([0])
    weights = [*model.state_dict().values()]
    ema_weights = [*ema_model.state_dict().values()]
    return weights, ema_weights, overflow_buf
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    """Run the fused ``amp_C.multi_tensor_axpby`` kernel over all EMA weights.

    The kernel receives ``[ema_weights, model_weights, ema_weights]`` with
    coefficients ``decay`` and ``1 - decay``.
    NOTE(review): presumably this writes ``decay * ema + (1 - decay) * model``
    back into ``ema_weights``; confirm against the apex ``amp_C`` sources,
    including the meaning of the trailing ``-1`` argument.
    """
    chunk_size = 65536
    tensor_lists = [ema_weights, model_weights, ema_weights]
    amp_C.multi_tensor_axpby(
        chunk_size, overflow_buf, tensor_lists, decay, 1-decay, -1)
def init_distributed(args, world_size, rank):
    """Select this process's GPU and initialise ``torch.distributed``.

    Args:
        args: namespace providing ``local_rank`` (used in log messages) and
            ``cuda`` (selects the ``nccl`` backend, otherwise ``gloo``).
        world_size: unused here; the process group is configured through
            the environment (``init_method='env://'``).
        rank: global rank, mapped onto a local device index by modulo the
            number of visible CUDA devices.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print(f"{args.local_rank}: Initializing distributed training")

    # Pin this process to its GPU before any communication happens.
    device_index = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)

    backend = 'nccl' if args.cuda else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    print(f"{args.local_rank}: Done initializing distributed training")
def load_wav(full_path, torch_tensor=False):
    """Read an audio file with ``soundfile`` as int16 samples.

    Args:
        full_path: path of the audio file.
        torch_tensor: when true, return the samples as a ``torch.FloatTensor``
            (values converted to float32, not rescaled) instead of the
            array returned by ``soundfile``.

    Returns:
        ``(samples, sampling_rate)``.
    """
    samples, sampling_rate = soundfile.read(full_path, dtype='int16')
    if not torch_tensor:
        return samples, sampling_rate
    return torch.FloatTensor(samples.astype(np.float32)), sampling_rate
def load_wav_to_torch(full_path, force_sampling_rate=None):
    """Load an audio file as a float tensor via ``load_wav``.

    Raises:
        NotImplementedError: if ``force_sampling_rate`` is given; resampling
            is not supported.
    """
    if force_sampling_rate is None:
        return load_wav(full_path, True)
    raise NotImplementedError
def load_filepaths_and_text(dataset_path, fnames, has_speakers=False, split="|"):
    """Parse file lists into tuples of resolved paths and trailing fields.

    Each non-blank line is split on ``split``. The trailing field (or the
    trailing two fields when ``has_speakers`` is true) is kept verbatim; all
    preceding fields are treated as paths and joined onto ``dataset_path``.
    A line with a single field is treated as one path.

    Blank and whitespace-only lines are skipped. Previously such a line
    produced a bogus ``(dataset_path,)`` entry.

    Args:
        dataset_path: root directory the listed paths are relative to.
        fnames: iterable of file-list paths, read as UTF-8.
        has_speakers: whether lines carry two trailing non-path fields
            instead of one.
        split: field separator.

    Returns:
        List of tuples ``(path_1, ..., path_k, non_path_1, ...)`` in file
        order, concatenated over all ``fnames``.
    """
    n_trailing = 2 if has_speakers else 1

    def split_line(root, line):
        parts = line.strip().split(split)
        if len(parts) == 1:
            paths, non_paths = parts, []
        else:
            paths, non_paths = parts[:-n_trailing], parts[-n_trailing:]
        return tuple(str(Path(root, p)) for p in paths) + tuple(non_paths)

    fpaths_and_text = []
    for fname in fnames:
        with open(fname, encoding='utf-8') as f:
            fpaths_and_text.extend(
                split_line(dataset_path, line) for line in f if line.strip())
    return fpaths_and_text
def to_gpu(x):
    """Return a contiguous copy/view of ``x``, moved to CUDA when available.

    The transfer is issued with ``non_blocking=True``. Without CUDA the
    contiguous tensor is returned on its current device.
    """
    x = x.contiguous()
    if not torch.cuda.is_available():
        return x
    return x.cuda(non_blocking=True)
def l2_promote():
    """Set CUDA device limit 0x05 to 128 on the current device and verify it.

    Loads ``libcudart.so`` through ctypes, calls ``cudaDeviceSetLimit`` and
    reads the value back with ``cudaDeviceGetLimit``.

    Raises:
        AssertionError: if the value read back is not 128.
    """
    cudart = ctypes.CDLL('libcudart.so')
    # cudaLimitMaxL2FetchGranularity = 0x05
    limit_id = 0x05
    granularity = 128
    cudart.cudaDeviceSetLimit(ctypes.c_int(limit_id), ctypes.c_int(granularity))
    # One-element int buffer that receives the current limit.
    readback = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))
    cudart.cudaDeviceGetLimit(readback, ctypes.c_int(limit_id))
    assert readback.contents.value == granularity
def prepare_tmp(path):
    """Create a fresh, empty directory at ``path``.

    An existing directory is removed first (with a warning, ignoring removal
    errors). The parent directory must already exist. ``None`` is a no-op.
    """
    if path is not None:
        tmp_dir = Path(path)
        if tmp_dir.is_dir():
            warnings.warn(f'{tmp_dir} exists. Removing...')
            shutil.rmtree(tmp_dir, ignore_errors=True)
        tmp_dir.mkdir(parents=False, exist_ok=False)
def print_once(*msg):
    """Print ``msg`` only on rank 0, or always when not running distributed."""
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(*msg)
def plot_spectrogram(spectrogram):
    """Render ``spectrogram`` as an image with a colorbar.

    The figure is drawn on its canvas and then closed in pyplot, so only
    the returned object keeps it alive.

    Returns:
        The matplotlib figure.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(spectrogram, aspect="auto", origin="lower",
                        interpolation='none')
    plt.colorbar(image, ax=axes)

    figure.canvas.draw()
    plt.close()
    return figure
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) for conv-like modules.

    A module is treated as conv-like when its class name contains "Conv";
    all other modules are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to int."""
    span = kernel_size*dilation - dilation
    return int(span/2)
class AttrDict(dict):
    """Dictionary whose items are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the dict itself as attribute storage, so d.key and d['key']
        # always refer to the same entry.
        self.__dict__ = self
class DefaultAttrDict(defaultdict):
    """``defaultdict`` whose items are also accessible as attributes.

    Reading a missing attribute goes through item access, so it creates the
    entry via the default factory just like ``d[name]`` would.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the dict itself as attribute storage.
        self.__dict__ = self

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.
        return self[item]
class Checkpointer:
    """Saves, tracks, prunes and restores generator/discriminator checkpoints.

    Checkpoints are stored in ``save_dir`` as pairs of files named
    ``hifigan_gen_checkpoint_<epoch>.pt`` and
    ``hifigan_discrim_checkpoint_<epoch>.pt``. ``self.tracked`` maps an epoch
    to its ``(generator_path, discriminator_path)`` pair, ordered by epoch,
    and only contains epochs for which both files were found.
    """

    # Epochs whose checkpoints are never pruned unless the caller overrides.
    DEFAULT_MILESTONES = (1000, 2000, 3000, 4000, 5000, 6000)

    # Matches the epoch number at the end of a checkpoint *file name*.
    _EPOCH_RE = re.compile(r'_checkpoint_(\d+)\.pt$')

    def __init__(self, save_dir, keep_milestones=None):
        """Scan ``save_dir`` for existing checkpoint pairs.

        Args:
            save_dir: directory checkpoints are read from and written to.
            keep_milestones: epochs that must survive pruning. Defaults to
                a fresh list of ``DEFAULT_MILESTONES``.
        """
        self.save_dir = save_dir
        if keep_milestones is None:
            # A new list per instance avoids a shared mutable default.
            keep_milestones = list(self.DEFAULT_MILESTONES)
        self.keep_milestones = keep_milestones

        saved_g = self._find_checkpoints('hifigan_gen')
        saved_d = self._find_checkpoints('hifigan_discrim')
        common_epochs = sorted(set(saved_g) & set(saved_d))
        self.tracked = OrderedDict(
            (ep, (saved_g[ep], saved_d[ep])) for ep in common_epochs)

    def _find_checkpoints(self, name):
        """Return ``{epoch: path}`` for ``<name>_checkpoint_<epoch>.pt`` files."""
        # Escape the directory so glob metacharacters in it are taken literally.
        pattern = os.path.join(glob.escape(str(self.save_dir)),
                               f'{name}_checkpoint_*.pt')
        found = {}
        for fn in glob.glob(pattern):
            # Match on the file name only; digits in directory names must not
            # be mistaken for the epoch.
            match = self._EPOCH_RE.search(os.path.basename(fn))
            if match is not None:  # skip e.g. "<name>_checkpoint_best.pt"
                found[int(match.group(1))] = fn
        return found

    @staticmethod
    def _unwrap(model):
        """Return ``model.module`` when present, else ``model`` itself."""
        return getattr(model, 'module', model)

    def maybe_load(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
                   train_state, args, gen_ema=None, mpd_ema=None, msd_ema=None):
        """Restore models, optimizers, scalers and ``train_state`` if requested.

        Loads from ``args.checkpoint_path_gen``/``args.checkpoint_path_discrim``
        when given (previously tracked checkpoints are then forgotten).
        Otherwise, when ``args.resume`` is set, tries the two most recent
        tracked pairs, newest first. Returns without touching anything when
        neither applies or no pair could be loaded.

        ``train_state`` is updated in place. When ``args.fine_tuning`` is set
        and the checkpoint carries no ``fine_tune_epoch_start``, fine-tuning
        is considered to start now and the stored learning rates are rescaled
        through ``adjust_fine_tuning_lr``.

        Raises:
            AssertionError: if only one of the two explicit paths is given,
                or the generator and discriminator epochs differ.
        """
        fpath_g = args.checkpoint_path_gen
        fpath_d = args.checkpoint_path_discrim

        assert (fpath_g is None) == (fpath_d is None)

        if fpath_g is not None:
            ckpt_paths = [(fpath_g, fpath_d)]
            self.tracked = OrderedDict()  # Do not track/delete prev ckpts
        elif args.resume:
            ckpt_paths = list(reversed(self.tracked.values()))[:2]
        else:
            return

        ckpt_g = None
        ckpt_d = None
        for fpath_g, fpath_d in ckpt_paths:
            if args.local_rank == 0:
                print(f'Loading models from {fpath_g} {fpath_d}')
            try:
                # NOTE: torch.load unpickles data; only load trusted files.
                ckpt_g = torch.load(fpath_g, map_location='cpu')
                ckpt_d = torch.load(fpath_d, map_location='cpu')
                break
            except Exception:
                # Never keep half of a pair around; fall back to the next one.
                ckpt_g = None
                ckpt_d = None
                print(f'WARNING: Cannot load {fpath_g} and {fpath_d}')

        if ckpt_g is None or ckpt_d is None:
            return

        ep_g = ckpt_g.get('train_state', ckpt_g).get('epoch', None)
        ep_d = ckpt_d.get('train_state', ckpt_d).get('epoch', None)
        assert ep_g == ep_d, \
            f'Mismatched epochs of gen and discrim ({ep_g} != {ep_d})'

        train_state.update(ckpt_g['train_state'])

        fine_tune_epoch_start = train_state.get('fine_tune_epoch_start')
        if args.fine_tuning and fine_tune_epoch_start is None:
            # Fine-tuning just began
            train_state['fine_tune_epoch_start'] = train_state['epoch'] + 1
            train_state['fine_tune_lr_factor'] = args.fine_tune_lr_factor
            adjust_fine_tuning_lr(args, ckpt_d)

        # Accept either key; look up 'generator' only when 'gen' is absent so
        # a checkpoint that has just 'gen' does not raise KeyError.
        gen_state = ckpt_g['gen'] if 'gen' in ckpt_g else ckpt_g['generator']
        self._unwrap(gen).load_state_dict(gen_state)
        self._unwrap(mpd).load_state_dict(ckpt_d['mpd'])
        self._unwrap(msd).load_state_dict(ckpt_d['msd'])

        optim_g.load_state_dict(ckpt_d['optim_g'])
        optim_d.load_state_dict(ckpt_d['optim_d'])

        if 'scaler_g' in ckpt_d:
            scaler_g.load_state_dict(ckpt_d['scaler_g'])
            scaler_d.load_state_dict(ckpt_d['scaler_d'])
        else:
            warnings.warn('No grad scaler state found in the checkpoint.')

        if gen_ema is not None:
            gen_ema.load_state_dict(ckpt_g['gen_ema'])
        if mpd_ema is not None:
            mpd_ema.load_state_dict(ckpt_d['mpd_ema'])
        if msd_ema is not None:
            msd_ema.load_state_dict(ckpt_d['msd_ema'])

    def maybe_save(self, gen, mpd, msd, optim_g, optim_d, scaler_g, scaler_d,
                   epoch, train_state, args, gen_config, train_setup,
                   gen_ema=None, mpd_ema=None, msd_ema=None):
        """Write a checkpoint pair for ``epoch`` when one is due, then prune.

        In distributed runs every rank synchronises on a barrier and only
        rank 0 continues. Nothing is saved at epoch 0. Before
        ``args.epochs`` is reached, a checkpoint is written only when
        ``args.checkpoint_interval`` is non-zero and divides ``epoch``; from
        ``args.epochs`` on it is always written.

        After saving, tracked checkpoints other than the last two and those
        listed in ``keep_milestones`` are deleted from disk.
        """
        rank = 0
        if dist.is_initialized():
            dist.barrier()
            rank = dist.get_rank()

        if rank != 0:
            return

        if epoch == 0:
            return

        if epoch < args.epochs and (args.checkpoint_interval == 0
                                    or epoch % args.checkpoint_interval > 0):
            return

        fpath_g = Path(self.save_dir, f'hifigan_gen_checkpoint_{epoch}.pt')
        ckpt_g = {
            'generator': self._unwrap(gen).state_dict(),
            'gen_ema': gen_ema.state_dict() if gen_ema is not None else None,
            'config': gen_config,
            'train_setup': train_setup,
            'train_state': train_state,
        }

        fpath_d = Path(self.save_dir, f'hifigan_discrim_checkpoint_{epoch}.pt')
        ckpt_d = {
            'mpd': self._unwrap(mpd).state_dict(),
            'msd': self._unwrap(msd).state_dict(),
            'mpd_ema': mpd_ema.state_dict() if mpd_ema is not None else None,
            'msd_ema': msd_ema.state_dict() if msd_ema is not None else None,
            'optim_g': optim_g.state_dict(),
            'optim_d': optim_d.state_dict(),
            'scaler_g': scaler_g.state_dict(),
            'scaler_d': scaler_d.state_dict(),
            'train_state': train_state,
            # compat with original code
            'steps': train_state['iters_all'],
            'epoch': epoch,
        }

        print(f"Saving model and optimizer state to {fpath_g} and {fpath_d}")
        torch.save(ckpt_g, fpath_g)
        torch.save(ckpt_d, fpath_d)

        # Remove old checkpoints; keep milestones and the last two
        self.tracked[epoch] = (fpath_g, fpath_d)
        milestones = set(self.keep_milestones)
        stale_epochs = [ep for ep in list(self.tracked)[:-2]
                        if ep not in milestones]
        for old_epoch in stale_epochs:
            removed = True
            for fpath in self.tracked[old_epoch]:
                try:
                    os.remove(fpath)
                except FileNotFoundError:
                    pass  # already gone, nothing left to clean up
                except OSError as err:
                    # Best effort: keep tracking the pair and retry next time.
                    removed = False
                    warnings.warn(
                        f'Cannot remove old checkpoint {fpath}: {err}')
            if removed:
                del self.tracked[old_epoch]
|