@@ -45,7 +45,7 @@ from torch.nn.parameter import Parameter
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-import dllogger as DLLogger
+import common.tb_dllogger as logger
 from apex import amp
 from apex.optimizers import FusedAdam, FusedLAMB
 
@@ -53,7 +53,6 @@ import common
 import data_functions
 import loss_functions
 import models
-from common.log_helper import init_dllogger, TBLogger, unique_dllogger_fpath
 
 
 def parse_args(parser):
@@ -82,10 +81,8 @@ def parse_args(parser):
                           help='Enable AMP')
     training.add_argument('--cuda', action='store_true',
                           help='Run on GPU using CUDA')
-    training.add_argument('--cudnn-enabled', action='store_true',
-                          help='Enable cudnn')
     training.add_argument('--cudnn-benchmark', action='store_true',
-                          help='Run cudnn benchmark')
+                          help='Enable cudnn benchmark mode')
     training.add_argument('--ema-decay', type=float, default=0,
                           help='Discounting factor for training weights EMA')
     training.add_argument('--gradient-accumulation-steps', type=int, default=1,
@@ -131,8 +128,7 @@ def parse_args(parser):
 def reduce_tensor(tensor, num_gpus):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= num_gpus
-    return rt
+    return rt.true_divide(num_gpus)
 
 
 def init_distributed(args, world_size, rank):
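
Note on the `reduce_tensor` change: `rt /= num_gpus` divides in place, which floor-divides integer tensors (such as the summed frame counts) on older PyTorch and raises a dtype error on newer releases, while `true_divide` always returns a floating-point result. A minimal sketch of the difference (illustrative only, not part of the patch):

    import torch
    t = torch.tensor([3])   # integer tensor, e.g. a reduced frame count
    t.true_divide(2)        # tensor([1.5000]) -- float average
    # t /= 2                # floor-divides or raises, depending on PyTorch version
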
@@ -174,6 +170,7 @@ def save_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
                     config, amp_run, filepath):
     if local_rank != 0:
         return
+
     print(f"Saving model and optimizer state at epoch {epoch} to {filepath}")
     ema_dict = None if ema_model is None else ema_model.state_dict()
     checkpoint = {'epoch': epoch,
@@ -208,11 +205,14 @@ def load_checkpoint(local_rank, model, ema_model, optimizer, epoch, total_iter,
         ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
 
-def validate(model, criterion, valset, batch_size, world_size, collate_fn,
-             distributed_run, rank, batch_to_gpu, use_gt_durations=False):
+def validate(model, epoch, total_iter, criterion, valset, batch_size,
+             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=False,
+             ema=False):
     """Handles all the validation scoring and printing"""
     was_training = model.training
     model.eval()
+
+    tik = time.perf_counter()
     with torch.no_grad():
         val_sampler = DistributedSampler(valset) if distributed_run else None
         val_loader = DataLoader(valset, num_workers=8, shuffle=False,
@@ -225,6 +225,7 @@ def validate(model, criterion, valset, batch_size, world_size, collate_fn,
             x, y, num_frames = batch_to_gpu(batch)
             y_pred = model(x, use_gt_durations=use_gt_durations)
             loss, meta = criterion(y_pred, y, is_training=False, meta_agg='sum')
+
             if distributed_run:
                 for k,v in meta.items():
                     val_meta[k] += reduce_tensor(v, 1)
@@ -233,12 +234,24 @@ def validate(model, criterion, valset, batch_size, world_size, collate_fn,
             for k,v in meta.items():
                 val_meta[k] += v
             val_num_frames = num_frames.item()
+
         val_meta = {k: v / len(valset) for k,v in val_meta.items()}
-        val_loss = val_meta['loss']
+
+    val_meta['took'] = time.perf_counter() - tik
+
+    logger.log((epoch,) if epoch is not None else (),
+               tb_total_steps=total_iter,
+               subset='val_ema' if ema else 'val',
+               data=OrderedDict([
+                   ('loss', val_meta['loss'].item()),
+                   ('mel_loss', val_meta['mel_loss'].item()),
+                   ('frames/s', num_frames.item() / val_meta['took']),
+                   ('took', val_meta['took'])]),
+               )
 
     if was_training:
         model.train()
-    return val_loss.item(), val_meta, val_num_frames
+    return val_meta
 
 
 def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
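
With this change `validate` logs its own metrics through `logger.log` and returns only the aggregated `val_meta` dict; callers no longer unpack a `(val_loss, meta, num_frames)` tuple or time the call themselves. A caller that still needs the numbers can read them from the returned dict, e.g.:

    val_meta = validate(model, epoch, total_iter, criterion, valset,
                        args.batch_size, collate_fn, distributed_run,
                        batch_to_gpu, use_gt_durations=True)
    val_loss = val_meta['loss'].item()  # loss entries are averaged tensors
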
@@ -270,39 +283,33 @@ def main():
     parser = parse_args(parser)
     args, _ = parser.parse_known_args()
 
-    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        local_rank = int(os.environ['LOCAL_RANK'])
-        world_size = int(os.environ['WORLD_SIZE'])
-    else:
-        local_rank = args.rank
-        world_size = args.world_size
-    distributed_run = world_size > 1
+    distributed_run = args.world_size > 1
 
-    torch.manual_seed(args.seed + local_rank)
-    np.random.seed(args.seed + local_rank)
+    torch.manual_seed(args.seed + args.local_rank)
+    np.random.seed(args.seed + args.local_rank)
 
-    if local_rank == 0:
+    if args.local_rank == 0:
         if not os.path.exists(args.output):
             os.makedirs(args.output)
 
-        log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
-        log_fpath = unique_dllogger_fpath(log_fpath)
-        init_dllogger(log_fpath)
-    else:
-        init_dllogger(dummy=True)
+    log_fpath = args.log_file or os.path.join(args.output, 'nvlog.json')
+    tb_subsets = ['train', 'val']
+    if args.ema_decay > 0.0:
+        tb_subsets.append('val_ema')
 
-    [DLLogger.log("PARAMETER", {k:v}) for k,v in vars(args).items()]
+    logger.init(log_fpath, args.output, enabled=(args.local_rank == 0),
+                tb_subsets=tb_subsets)
+    logger.parameters(vars(args), tb_subset='train')
 
     parser = models.parse_model_args('FastPitch', parser)
     args, unk_args = parser.parse_known_args()
     if len(unk_args) > 0:
         raise ValueError(f'Invalid options {unk_args}')
 
-    torch.backends.cudnn.enabled = args.cudnn_enabled
     torch.backends.cudnn.benchmark = args.cudnn_benchmark
 
     if distributed_run:
-        init_distributed(args, world_size, local_rank)
+        init_distributed(args, args.world_size, args.local_rank)
 
     device = torch.device('cuda' if args.cuda else 'cpu')
     model_config = models.get_model_config('FastPitch', args)
@@ -351,9 +358,9 @@ def main():
         ch_fpath = None
 
     if ch_fpath is not None:
-        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
+        load_checkpoint(args.local_rank, model, ema_model, optimizer, start_epoch,
                         start_iter, model_config, args.amp, ch_fpath,
-                        world_size)
+                        args.world_size)
 
     start_epoch = start_epoch[0]
     total_iter = start_iter[0]
@@ -381,15 +388,9 @@ def main():
 
     model.train()
 
-    train_tblogger = TBLogger(local_rank, args.output, 'train')
-    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
-    if args.ema_decay > 0:
-        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')
-
-    val_loss = 0.0
     torch.cuda.synchronize()
     for epoch in range(start_epoch, args.epochs + 1):
-        epoch_start_time = time.time()
+        epoch_start_time = time.perf_counter()
 
         epoch_loss = 0.0
         epoch_mel_loss = 0.0
@@ -407,24 +408,16 @@ def main():
         epoch_iter = 0
         num_iters = len(train_loader) // args.gradient_accumulation_steps
         for batch in train_loader:
+
             if accumulated_steps == 0:
                 if epoch_iter == num_iters:
                     break
                 total_iter += 1
                 epoch_iter += 1
-                iter_start_time = time.time()
-                start = time.perf_counter()
+                iter_start_time = time.perf_counter()
 
-                old_lr = optimizer.param_groups[0]['lr']
                 adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                      args.warmup_steps)
-                new_lr = optimizer.param_groups[0]['lr']
-
-                if new_lr != old_lr:
-                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
-                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
-                else:
-                    dllog_lrate_change = None
 
                 model.zero_grad()
 
@@ -443,9 +436,9 @@ def main():
                 loss.backward()
 
             if distributed_run:
-                reduced_loss = reduce_tensor(loss.data, world_size).item()
+                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                 reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
-                meta = {k: reduce_tensor(v, world_size) for k,v in meta.items()}
+                meta = {k: reduce_tensor(v, args.world_size) for k,v in meta.items()}
             else:
                 reduced_loss = loss.item()
                 reduced_num_frames = num_frames.item()
@@ -459,7 +452,7 @@ def main():
 
             if accumulated_steps % args.gradient_accumulation_steps == 0:
 
-                train_tblogger.log_grads(total_iter, model)
+                logger.log_grads_tb(total_iter, model)
                 if args.amp:
                     torch.nn.utils.clip_grad_norm_(
                         amp.master_params(optimizer), args.grad_clip_thresh)
@@ -470,21 +463,23 @@ def main():
                     optimizer.step()
                 apply_ema_decay(model, ema_model, args.ema_decay)
 
-                iter_stop_time = time.time()
-                iter_time = iter_stop_time - iter_start_time
-                frames_per_sec = iter_num_frames / iter_time
-                epoch_frames_per_sec += frames_per_sec
+                iter_time = time.perf_counter() - iter_start_time
+                iter_mel_loss = iter_meta['mel_loss'].item()
+                epoch_frames_per_sec += iter_num_frames / iter_time
                 epoch_loss += iter_loss
                 epoch_num_frames += iter_num_frames
-                iter_mel_loss = iter_meta['mel_loss'].item()
                 epoch_mel_loss += iter_mel_loss
 
-                DLLogger.log((epoch, epoch_iter, num_iters), OrderedDict([
-                    ('train_loss', iter_loss), ('train_mel_loss', iter_mel_loss),
-                    ('train_frames/s', frames_per_sec), ('took', iter_time),
-                    ('lrate_change', dllog_lrate_change)
-                ]))
-                train_tblogger.log_meta(total_iter, iter_meta)
+                logger.log((epoch, epoch_iter, num_iters),
+                           tb_total_steps=total_iter,
+                           subset='train',
+                           data=OrderedDict([
+                               ('loss', iter_loss),
+                               ('mel_loss', iter_mel_loss),
+                               ('frames/s', iter_num_frames / iter_time),
+                               ('took', iter_time),
+                               ('lrate', optimizer.param_groups[0]['lr'])]),
+                           )
 
                 accumulated_steps = 0
                 iter_loss = 0
@@ -492,69 +487,48 @@ def main():
                 iter_meta = {}
 
         # Finished epoch
-        epoch_stop_time = time.time()
-        epoch_time = epoch_stop_time - epoch_start_time
-
-        DLLogger.log((epoch,), data=OrderedDict([
-            ('avg_train_loss', epoch_loss / epoch_iter),
-            ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
-            ('avg_train_frames/s', epoch_num_frames / epoch_time),
-            ('took', epoch_time)
-        ]))
-
-        tik = time.time()
-        val_loss, meta, num_frames = validate(
-            model, criterion, valset, args.batch_size, world_size, collate_fn,
-            distributed_run, local_rank, batch_to_gpu, use_gt_durations=True)
-        tok = time.time()
-
-        DLLogger.log((epoch,), data=OrderedDict([
-            ('val_loss', val_loss),
-            ('val_mel_loss', meta['mel_loss'].item()),
-            ('val_frames/s', num_frames / (tok - tik)),
-            ('took', tok - tik),
-        ]))
-        val_tblogger.log_meta(total_iter, meta)
+        epoch_time = time.perf_counter() - epoch_start_time
+
+        logger.log((epoch,),
+                   tb_total_steps=None,
+                   subset='train_avg',
+                   data=OrderedDict([
+                       ('loss', epoch_loss / epoch_iter),
+                       ('mel_loss', epoch_mel_loss / epoch_iter),
+                       ('frames/s', epoch_num_frames / epoch_time),
+                       ('took', epoch_time)]),
+                   )
+
+        validate(model, epoch, total_iter, criterion, valset, args.batch_size,
+                 collate_fn, distributed_run, batch_to_gpu,
+                 use_gt_durations=True)
 
         if args.ema_decay > 0:
-            tik_e = time.time()
-            val_loss_e, meta_e, num_frames_e = validate(
-                ema_model, criterion, valset, args.batch_size, world_size,
-                collate_fn, distributed_run, local_rank, batch_to_gpu,
-                use_gt_durations=True)
-            tok_e = time.time()
-
-            DLLogger.log((epoch,), data=OrderedDict([
-                ('val_ema_loss', val_loss_e),
-                ('val_ema_mel_loss', meta_e['mel_loss'].item()),
-                ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
-                ('took', tok_e - tik_e),
-            ]))
-            val_ema_tblogger.log_meta(total_iter, meta)
+            validate(ema_model, epoch, total_iter, criterion, valset,
+                     args.batch_size, collate_fn, distributed_run, batch_to_gpu,
+                     use_gt_durations=True, ema=True)
 
         if (epoch > 0 and args.epochs_per_checkpoint > 0 and
-            (epoch % args.epochs_per_checkpoint == 0) and local_rank == 0):
+            (epoch % args.epochs_per_checkpoint == 0) and args.local_rank == 0):
 
             checkpoint_path = os.path.join(
                 args.output, f"FastPitch_checkpoint_{epoch}.pt")
-            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
+            save_checkpoint(args.local_rank, model, ema_model, optimizer, epoch,
                             total_iter, model_config, args.amp, checkpoint_path)
-        if local_rank == 0:
-            DLLogger.flush()
+        logger.flush()
 
     # Finished training
-    DLLogger.log((), data=OrderedDict([
-        ('avg_train_loss', epoch_loss / epoch_iter),
-        ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
-        ('avg_train_frames/s', epoch_num_frames / epoch_time),
-    ]))
-    DLLogger.log((), data=OrderedDict([
-        ('val_loss', val_loss),
-        ('val_mel_loss', meta['mel_loss'].item()),
-        ('val_frames/s', num_frames / (tok - tik)),
-    ]))
-    if local_rank == 0:
-        DLLogger.flush()
+    logger.log((),
+               tb_total_steps=None,
+               subset='train_avg',
+               data=OrderedDict([
+                   ('loss', epoch_loss / epoch_iter),
+                   ('mel_loss', epoch_mel_loss / epoch_iter),
+                   ('frames/s', epoch_num_frames / epoch_time),
+                   ('took', epoch_time)]),
+               )
+    validate(model, None, total_iter, criterion, valset, args.batch_size,
+             collate_fn, distributed_run, batch_to_gpu, use_gt_durations=True)
 
 
 if __name__ == '__main__':
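
For reference, the call sites above pin down the interface this patch expects from `common.tb_dllogger`: `init`, `parameters`, `log`, `log_grads_tb`, and `flush`. A minimal stand-in with the same surface (a sketch to aid reading the diff; the real module in the repository routes these calls to DLLogger and per-subset TensorBoard writers):

    from collections import OrderedDict

    _enabled = False

    def init(log_fpath, output_dir, enabled=True, tb_subsets=()):
        # Configure the log-file/TensorBoard sinks; only rank 0 passes enabled=True.
        global _enabled
        _enabled = enabled

    def parameters(data, tb_subset=None):
        # Record hyperparameters once, optionally mirrored to one TB subset.
        if _enabled:
            print('PARAMETER', dict(data))

    def log(step, tb_total_steps=None, subset=None, data=None):
        # `step` is () for a summary, (epoch,), or (epoch, iter, num_iters).
        if _enabled:
            print(step, subset or '', dict(data or OrderedDict()))

    def log_grads_tb(tb_total_steps, model):
        # The real module writes per-parameter gradient norms to TensorBoard.
        pass

    def flush():
        # Flush the logging backends and TensorBoard writers.
        pass
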