# *****************************************************************************
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * Neither the name of the NVIDIA CORPORATION nor the
#       names of its contributors may be used to endorse or promote products
#       derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
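"""FastPitch training entry point.

Sets up single- or multi-GPU training with optional AMP, runs the epoch loop
with gradient accumulation and EMA weights, validates after every epoch, and
writes periodic checkpoints.
"""
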
import argparse
import copy
import glob
import json
import os
import re
import time
from collections import defaultdict, OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import dllogger as DLLogger
from apex import amp
from apex.optimizers import FusedAdam, FusedLAMB
from apex.parallel import DistributedDataParallel as DDP

import common
import data_functions
import loss_functions
import models
from common.log_helper import init_dllogger, TBLogger


def parse_args(parser):
    """
    Parse commandline arguments.
    """
    parser.add_argument('-o', '--output', type=str, required=True,
                        help='Directory to save checkpoints')
    parser.add_argument('-d', '--dataset-path', type=str, default='./',
                        help='Path to dataset')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, required=True,
                          help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=50,
                          help='Number of epochs per checkpoint')
    training.add_argument('--checkpoint-path', type=str, default=None,
                          help='Checkpoint path to resume training')
    training.add_argument('--checkpoint-resume', action='store_true',
                          help='Resume training from the last available '
                               'checkpoint')
    training.add_argument('--seed', type=int, default=1234,
                          help='Seed for PyTorch random number generators')
    training.add_argument('--amp-run', action='store_true',
                          help='Enable AMP')
    training.add_argument('--cuda', action='store_true',
                          help='Run on GPU using CUDA')
    training.add_argument('--cudnn-enabled', action='store_true',
                          help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', action='store_true',
                          help='Run cudnn benchmark')
    training.add_argument('--ema-decay', type=float, default=0,
                          help='Discounting factor for training weights EMA')
    training.add_argument('--gradient-accumulation-steps', type=int, default=1,
                          help='Training steps to accumulate gradients for')

    optimization = parser.add_argument_group('optimization setup')
    optimization.add_argument('--optimizer', type=str, default='lamb',
                              help='Optimization algorithm')
    optimization.add_argument('-lr', '--learning-rate', type=float,
                              required=True, help='Learning rate')
    optimization.add_argument('--weight-decay', default=1e-6, type=float,
                              help='Weight decay')
    optimization.add_argument('--grad-clip-thresh', default=1000.0, type=float,
                              help='Clip threshold for gradients')
    optimization.add_argument('-bs', '--batch-size', type=int, required=True,
                              help='Batch size per GPU')
    optimization.add_argument('--warmup-steps', type=int, default=1000,
                              help='Number of steps for lr warmup')
    optimization.add_argument('--dur-predictor-loss-scale', type=float,
                              default=1.0,
                              help='Rescale duration predictor loss')
    optimization.add_argument('--pitch-predictor-loss-scale', type=float,
                              default=1.0,
                              help='Rescale pitch predictor loss')

    dataset = parser.add_argument_group('dataset parameters')
    dataset.add_argument('--training-files', type=str, required=True,
                         help='Path to training filelist')
    dataset.add_argument('--validation-files', type=str, required=True,
                         help='Path to validation filelist')
    dataset.add_argument('--pitch-mean-std-file', type=str, default=None,
                         help='Path to pitch stats to be stored in the model')
    dataset.add_argument('--text-cleaners', nargs='*',
                         default=['english_cleaners'], type=str,
                         help='Type of text cleaners for input text')

    distributed = parser.add_argument_group('distributed setup')
    distributed.add_argument('--rank', default=0, type=int,
                             help='Rank of the process for multiproc. '
                                  'Do not set manually.')
    distributed.add_argument('--world-size', default=1, type=int,
                             help='Number of processes for multiproc. '
                                  'Do not set manually.')
    distributed.add_argument('--dist-url', type=str,
                             default='tcp://localhost:23456',
                             help='Url used to set up distributed training')
    distributed.add_argument('--group-name', type=str, default='group_name',
                             required=False, help='Distributed group name')
    distributed.add_argument('--dist-backend', default='nccl', type=str,
                             choices={'nccl'}, help='Distributed run backend')
    return parser


def reduce_tensor(tensor, num_gpus):
    """Sum a tensor across all workers and divide by the world size."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt


def init_distributed(args, world_size, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed training")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=world_size, rank=rank, group_name=group_name)

    print("Done initializing distributed training")


def last_checkpoint(output):
    """Return the most recent loadable checkpoint in `output`, or None."""

    def corrupted(fpath):
        try:
            torch.load(fpath, map_location='cpu')
            return False
        except Exception:
            print(f'WARNING: Cannot load {fpath}')
            return True

    saved = sorted(
        glob.glob(f'{output}/FastPitch_checkpoint_*.pt'),
        key=lambda f: int(re.search(r'_(\d+)\.pt', f).group(1)))

    # Fall back to the previous checkpoint if the latest one cannot be
    # loaded, e.g. after an interrupted save.
    if len(saved) >= 1 and not corrupted(saved[-1]):
        return saved[-1]
    elif len(saved) >= 2:
        return saved[-2]
    else:
        return None


def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, config,
                    amp_run, filepath):
    if local_rank != 0:
        return

    print(f"Saving model and optimizer state at epoch {epoch} to {filepath}")
    ema_dict = None if ema_model is None else ema_model.state_dict()
    checkpoint = {'epoch': epoch,
                  'config': config,
                  'state_dict': model.state_dict(),
                  'ema_state_dict': ema_dict,
                  'optimizer': optimizer.state_dict()}
    if amp_run:
        checkpoint['amp'] = amp.state_dict()
    torch.save(checkpoint, filepath)


def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, config,
                    amp_run, filepath, world_size):
    if local_rank == 0:
        print(f'Loading model and optimizer state from {filepath}')
    checkpoint = torch.load(filepath, map_location='cpu')
    epoch[0] = checkpoint['epoch'] + 1
    config = checkpoint['config']

    # Strip the 'module.' prefix added by DistributedDataParallel, so the
    # checkpoint loads into both wrapped and unwrapped models.
    sd = {k.replace('module.', ''): v
          for k, v in checkpoint['state_dict'].items()}
    getattr(model, 'module', model).load_state_dict(sd)

    optimizer.load_state_dict(checkpoint['optimizer'])
    if amp_run:
        amp.load_state_dict(checkpoint['amp'])
    if ema_model is not None:
        ema_model.load_state_dict(checkpoint['ema_state_dict'])


def validate(model, criterion, valset, batch_size, world_size, collate_fn,
             distributed_run, rank, batch_to_gpu, use_gt_durations=False):
    """Handles all the validation scoring and printing"""
    was_training = model.training
    model.eval()

    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, num_workers=8, shuffle=False,
                                sampler=val_sampler,
                                batch_size=batch_size, pin_memory=False,
                                collate_fn=collate_fn)
        val_meta = defaultdict(float)
        val_num_frames = 0
        for batch in val_loader:
            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=use_gt_durations)
            loss, meta = criterion(y_pred, y, is_training=False,
                                   meta_agg='sum')

            if distributed_run:
                for k, v in meta.items():
                    val_meta[k] += reduce_tensor(v, 1)
                val_num_frames += reduce_tensor(num_frames.data, 1).item()
            else:
                for k, v in meta.items():
                    val_meta[k] += v
                val_num_frames += num_frames.item()

        val_meta = {k: v / len(valset) for k, v in val_meta.items()}

    val_loss = val_meta['loss']
    if was_training:
        model.train()
    return val_loss.item(), val_meta, val_num_frames

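# The schedule below is the Noam / "Attention Is All You Need" recipe: the
# scale ramps linearly for warmup_iters steps, then decays as
# total_iter ** -0.5. With the default --warmup-steps 1000, the two branches
# meet at step 1000 (scale = 1000 ** -0.5 ~= 0.0316), so the peak effective
# learning rate is roughly learning_rate / sqrt(warmup_steps).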
def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    """Noam-style schedule: linear warmup, then inverse square-root decay."""
    if warmup_iters is None or warmup_iters == 0:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1. / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)

    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale

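# Weight EMA below follows ema <- decay * ema + (1 - decay) * online weights;
# with, say, --ema-decay 0.999 the EMA model averages roughly the last 1000
# optimizer steps, which tends to give smoother validation curves.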
def apply_ema_decay(model, ema_model, decay):
    """In-place EMA update of ema_model towards model's current weights."""
    if not decay:
        return
    st = model.state_dict()
    # When `model` is wrapped in DDP but `ema_model` is not, the model's keys
    # carry a 'module.' prefix that the EMA keys lack.
    add_module = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for k, v in ema_model.state_dict().items():
        if add_module and not k.startswith('module.'):
            k = 'module.' + k
        v.copy_(decay * v + (1 - decay) * st[k])

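# main() accumulates gradients over --gradient-accumulation-steps
# micro-batches before each optimizer step, so the effective batch size is
# batch_size * gradient_accumulation_steps * world_size (batch size per GPU).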
def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        init_dllogger(args.log_file)
    else:
        init_dllogger(dummy=True)

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)
    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference.
    # Fall back to the stats file computed for the training filelist when
    # --pitch-mean-std-file is not given.
    fpath = common.utils.stats_filename(
        args.dataset_path, args.training_files, 'pitch_char')
    with open(args.pitch_mean_std_file
              if args.pitch_mean_std_file is not None else fpath, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate, betas=(0.9, 0.98), eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DDP(model)

    start_epoch = [1]

    assert args.checkpoint_path is None or args.checkpoint_resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.checkpoint_resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        model_config, args.amp_run, ch_fpath, world_size)

    start_epoch = start_epoch[0]
    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset, num_workers=16, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False, drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    total_iter = 0

    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()
        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:

            # Start a new optimizer step every gradient_accumulation_steps
            # micro-batches.
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer,
                                     args.learning_rate, args.warmup_steps)
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

            loss /= args.gradient_accumulation_steps
            meta = {k: v / args.gradient_accumulation_steps
                    for k, v in meta.items()}

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {k: reduce_tensor(v, world_size)
                        for k, v in meta.items()}
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise RuntimeError('Loss is NaN')

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0)
                         for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp_run:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames

                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log((epoch, epoch_iter, num_iters), OrderedDict([
                    ('train_loss', iter_loss),
                    ('train_mel_loss', iter_mel_loss),
                    ('train_frames/s', frames_per_sec),
                    ('took', iter_time),
                    ('lrate_change', dllog_lrate_change)
                ]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}
        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch,), data=OrderedDict([
            ('avg_train_loss', epoch_loss / epoch_iter),
            ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
            ('avg_train_frames/s', epoch_num_frames / epoch_time),
            ('took', epoch_time)
        ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(
            model, criterion, valset, args.batch_size, world_size, collate_fn,
            distributed_run, local_rank, batch_to_gpu, use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch,), data=OrderedDict([
            ('val_loss', val_loss),
            ('val_mel_loss', meta['mel_loss'].item()),
            ('val_frames/s', num_frames / (tok - tik)),
            ('took', tok - tik),
        ]))
        val_tblogger.log_meta(total_iter, meta)

        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(
                ema_model, criterion, valset, args.batch_size, world_size,
                collate_fn, distributed_run, local_rank, batch_to_gpu,
                use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log((epoch,), data=OrderedDict([
                ('val_ema_loss', val_loss_e),
                ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                ('took', tok_e - tik_e),
            ]))
            val_ema_tblogger.log_meta(total_iter, meta_e)

        if (epoch > 0 and args.epochs_per_checkpoint > 0 and
                (epoch % args.epochs_per_checkpoint == 0) and
                local_rank == 0):
            checkpoint_path = os.path.join(
                args.output, f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            model_config, args.amp_run, checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()
    # Finished training
    DLLogger.log((), data=OrderedDict([
        ('avg_train_loss', epoch_loss / epoch_iter),
        ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
        ('avg_train_frames/s', epoch_num_frames / epoch_time),
    ]))
    DLLogger.log((), data=OrderedDict([
        ('val_loss', val_loss),
        ('val_mel_loss', meta['mel_loss'].item()),
        ('val_frames/s', num_frames / (tok - tik)),
    ]))
    if local_rank == 0:
        DLLogger.flush()


if __name__ == '__main__':
    main()
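
# Example launch (illustrative only; the filelist paths and hyperparameter
# values below are placeholders, not defaults shipped with this script):
#
#   python train.py --cuda --cudnn-enabled -o ./output \
#       --dataset-path ./LJSpeech-1.1 \
#       --training-files filelists/train.txt \
#       --validation-files filelists/val.txt \
#       --epochs 1500 --optimizer lamb -lr 0.1 -bs 32 --amp-run
#
# For multi-GPU runs, use a launcher that sets the LOCAL_RANK and WORLD_SIZE
# environment variables for each process (or pass --rank/--world-size
# per process).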
|