@@ -20,6 +20,7 @@ import itertools
 import logging
 import math
 import os
+import shutil
 import sys
 import time
 
@@ -38,9 +39,12 @@ from data_utils import get_lm_corpus
 from mem_transformer import MemTransformerLM
 from utils.data_parallel import BalancedDataParallel
 from utils.exp_utils import AverageMeter
+from utils.exp_utils import TimeoutHandler
 from utils.exp_utils import benchmark
 from utils.exp_utils import create_exp_dir
+from utils.exp_utils import l2_promote
 from utils.exp_utils import log_env_info
+from utils.exp_utils import register_ignoring_timeout_handler
 
 
 def parse_args():
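Reviewer note: utils/exp_utils.py is not part of this diff, so the three new helpers (TimeoutHandler, l2_promote, register_ignoring_timeout_handler) are only visible through their call sites below. As a reading aid, here is a minimal sketch of what l2_promote plausibly does, assuming it raises the CUDA L2 fetch granularity via the runtime API; the actual implementation in exp_utils.py may differ:

    import ctypes

    def l2_promote():
        # Assumed sketch: raise the maximum L2 cache fetch granularity to 128 bytes.
        # cudaLimitMaxL2FetchGranularity is enum value 0x05 in the CUDA runtime API.
        _libcudart = ctypes.CDLL('libcudart.so')
        _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_size_t(128))
        value = ctypes.c_size_t(0)
        _libcudart.cudaDeviceGetLimit(ctypes.byref(value), ctypes.c_int(0x05))
        assert value.value == 128

The two signal-handling helpers are sketched next to the hunks that call them further down.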
@@ -54,7 +58,7 @@ def parse_args():
     cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
 
     cfg_parser.add_argument('--config', default='default')
-    cfg_parser.add_argument('--config_file', default='config.yaml')
+    cfg_parser.add_argument('--config_file', default=None)
 
     config_args, _ = cfg_parser.parse_known_args()
 
@@ -81,16 +85,25 @@ def parse_args():
                          help='Run in debug mode (do not create exp dir)')
     general.add_argument('--log_all_ranks', action='store_true',
                          help='Enable logging from all distributed ranks')
-    general.add_argument('--save-all', action='store_true',
+    general.add_argument('--dllog_file', type=str, default='train_log.json',
+                         help='Name of the DLLogger output file')
+    general.add_argument('--txtlog_file', type=str, default='train_log.log',
+                         help='Name of the txt log file')
+    general.add_argument('--save_all', action='store_true',
                          help='Save all checkpoints')
     general.add_argument('--no_env', action='store_true',
                          help='Do not print info on execution env')
+    general.add_argument('--no_eval', action='store_true',
+                         help='Disable model evaluation')
     general.add_argument('--log_interval', type=int, default=10,
                          help='Report interval')
     general.add_argument('--target_throughput', type=float, default=None,
                          help='Target training throughput (for benchmarking)')
     general.add_argument('--target_perplexity', type=float, default=None,
                          help='Target validation perplexity (for benchmarking)')
+    general.add_argument('--amp_mode', type=str, default='O2',
+                         choices=['O0', 'O1', 'O2', 'O3'],
+                         help='Optimization level for apex amp')
 
     dataset = parser.add_argument_group('dataset setup')
     dataset.add_argument('--data', type=str, default='../data/wikitext-103',
@@ -238,7 +251,8 @@ def parse_args():
 
 
 def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
-                    train_step, best_val_loss, work_dir, name='checkpoint.pt'):
+                    epoch, batch, last_iter, train_step, best_val_loss,
+                    is_best, work_dir):
     if args.fp16:
         amp_state = amp.state_dict()
     else:
@@ -252,15 +266,35 @@ def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
         'scheduler_state': scheduler.state_dict(),
         'vocab': vocab,
         'amp_state': amp_state,
+        'epoch': epoch,
+        'batch': batch,
+        'last_iter': last_iter,
         'train_step': train_step,
         'best_val_loss': best_val_loss,
         }
 
+    last_chkpt_fname = 'checkpoint_last.pt'
+
     with utils.distributed.sync_workers() as rank:
-        path = os.path.join(work_dir, name)
-        logging.info(f'Saving checkpoint to {path}')
+        last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
         if rank == 0:
-            torch.save(state, path)
+            # always save last checkpoint
+            logging.info(f'Saving checkpoint to {last_chkpt_path}')
+            torch.save(state, last_chkpt_path)
+
+            # save best checkpoint if better than previous best
+            if is_best:
+                best_chkpt_fname = 'checkpoint_best.pt'
+                best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
+                logging.info(f'Saving checkpoint to {best_chkpt_path}')
+                shutil.copy(last_chkpt_path, best_chkpt_path)
+
+            # save every checkpoint if save_all is true
+            if args.save_all:
+                step_chkpt_fname = f'checkpoint_{train_step}.pt'
+                step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
+                logging.info(f'Saving checkpoint to {step_chkpt_path}')
+                shutil.copy(last_chkpt_path, step_chkpt_path)
 
 
 def load_checkpoint(path):
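Note on utils.distributed.sync_workers() (its module is not part of this diff): the new checkpoint logic relies on it yielding the distributed rank and synchronizing all workers on exit, so non-zero ranks cannot race ahead while rank 0 writes and copies the files. A plausible sketch under that assumption:

    from contextlib import contextmanager
    import torch

    @contextmanager
    def sync_workers():
        # Assumed sketch: yield this process's rank, then barrier on exit so
        # every rank waits until rank 0 has finished writing the checkpoint.
        initialized = torch.distributed.is_initialized()
        rank = torch.distributed.get_rank() if initialized else 0
        try:
            yield rank
        finally:
            if initialized:
                torch.distributed.barrier()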
@@ -367,7 +401,7 @@ def evaluate(eval_iter, model, args):
             loss, mems = model(data, target, mems)
             loss = loss.float().mean()
             if warm:
-                assert (not mems) or all([m.size(0) == model.mem_len for m in mems])
+                assert (mems is None) or mems.size(1) == model.mem_len
             total_loss += seq_len * loss.item()
             total_len += seq_len
 
@@ -382,8 +416,9 @@ def evaluate(eval_iter, model, args):
 
 
 def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
-          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch, train_step,
-          best_val_loss, meters, args):
+          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
+          last_batch, last_iter, train_step, best_val_loss, meters,
+          timeout_handler, args):
     # Turn on training mode which enables dropout.
     model.train()
 
@@ -393,13 +428,17 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
     log_start_time = time.time()
 
     mems = [None for _ in range(args.batch_chunk)]
-    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
+    if args.varlen:
+        train_iter = tr_iter.get_varlen_iter(start=last_iter)
+    else:
+        train_iter = tr_iter.get_fixlen_iter(start=last_iter)
 
-    for batch, (data, target, seq_len, _) in enumerate(train_iter):
+    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
         log_step += 1
         target_tokens += target.numel()
 
-        model.zero_grad()
+        for param in model.parameters():
+            param.grad = None
 
         data_chunks = torch.chunk(data, args.batch_chunk, 1)
         target_chunks = torch.chunk(target, args.batch_chunk, 1)
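Replacing model.zero_grad() with an explicit param.grad = None loop skips the memset over every gradient buffer; the first backward pass of the step then allocates and writes the gradients directly. On recent PyTorch releases (1.7+) the same behaviour is available as a flag, which this script spells out manually, presumably to stay compatible with older versions:

    # equivalent idiom on newer PyTorch
    model.zero_grad(set_to_none=True)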
@@ -467,7 +506,7 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
                 '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                     epoch,
                     train_step,
-                    batch+1,
+                    batch,
                     tr_iter.n_batch,
                     lr,
                     avg_elapsed * 1000,
@@ -492,9 +531,13 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
                 dllogger_data['train_perplexity'] = math.exp(cur_loss)
 
             logging.info(log_str)
-            dllogger.log(step=train_step, data=dllogger_data)
+            dllogger.log(step=tuple([train_step]), data=dllogger_data)
+
+        do_periodic_eval = train_step % args.eval_interval == 0
+        is_final_step = train_step == args.max_step
+        interrupted = timeout_handler.interrupted
 
-        if train_step % args.eval_interval == 0:
+        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
             eval_start_time = time.time()
             val_loss = evaluate(va_iter, model, args)
             val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
@@ -521,30 +564,21 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
                 dllogger_data['valid_perplexity'] = math.exp(val_loss)
             logging.info(log_str)
             logging.info('-' * 100)
-            dllogger.log(step=train_step, data=dllogger_data)
+            dllogger.log(step=tuple([train_step]), data=dllogger_data)
+
+            last_iter = tr_iter.last_iter
 
-            # Save the model if the validation loss is the best we've seen so far.
+            # Check if the validation loss is the best we've seen so far.
+            is_best = False
             if not best_val_loss or val_loss < best_val_loss:
                 best_val_loss = val_loss
-                if not args.debug:
-                    name = 'checkpoint_best.pt'
-                    save_checkpoint(args, model, model_config, optimizer,
-                                    scheduler, vocab, train_step,
-                                    best_val_loss, args.work_dir, name)
-
-            # Always save after eval if save_all is true and not debug
-            if not args.debug and args.save_all:
-                name = f'checkpoint_{train_step}.pt'
-                save_checkpoint(args, model, model_config, optimizer,
-                                scheduler, vocab, train_step, best_val_loss,
-                                args.work_dir, name)
+                is_best = True
 
-            # Save last checkpoint if not debug and not save_all
-            if not args.debug and not args.save_all:
-                name = 'checkpoint_last.pt'
+            if not args.debug:
                 save_checkpoint(args, model, model_config, optimizer,
-                                scheduler, vocab, train_step, best_val_loss,
-                                args.work_dir, name)
+                                scheduler, vocab, epoch, batch, last_iter,
+                                train_step, best_val_loss, is_best,
+                                args.work_dir)
 
             # dev-performance based learning rate annealing
             if args.scheduler == 'dev_perf':
@@ -555,16 +589,22 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
             # subtract eval time from timers for training
             log_start_time += time.time() - eval_start_time
 
-        if train_step == args.max_step:
+        if interrupted:
+            logging.info(f'Received SIGTERM, exiting')
+            sys.exit(0)
+
+        if is_final_step:
             break
     return train_step, best_val_loss
 
 
 def main():
     args = parse_args()
+    utils.gpu_affinity.set_affinity(args.local_rank)
 
     # Initialize device and distributed backend
     torch.cuda.set_device(args.local_rank)
+    l2_promote()
     device = torch.device('cuda' if args.cuda else 'cpu')
     utils.distributed.init_distributed(args.cuda)
 
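utils.gpu_affinity.set_affinity(args.local_rank) is another helper this patch only calls; the module is not included in the diff. Presumably it pins the process to the CPU cores that are NUMA-local to the GPU owned by this rank. A rough sketch of that idea (assumed, via NVML; the real module may be more elaborate):

    import os
    import pynvml

    def set_affinity(gpu_id):
        # Assumed sketch: query the ideal CPU mask for this GPU and restrict
        # the current process to those cores.
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
        words = pynvml.nvmlDeviceGetCpuAffinity(handle, 16)  # 16 x 64-bit words
        mask = 0
        for i, word in enumerate(words):
            mask |= int(word) << (64 * i)
        cores = {i for i in range(os.cpu_count()) if (mask >> i) & 1}
        os.sched_setaffinity(0, cores)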
@@ -584,8 +624,8 @@ def main():
     if args.log_all_ranks:
         log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
     else:
-        log_file = f'train_log.log'
-    dllog_file = f'train_log.json'
+        log_file = args.txtlog_file
+    dllog_file = args.dllog_file
     log_file = os.path.join(args.work_dir, log_file)
     dllog_file = os.path.join(args.work_dir, dllog_file)
 
@@ -607,9 +647,13 @@ def main():
     logging.info(args)
     dllogger.log(step='PARAMETER', data=vars(args))
 
+    logging.info(f'world size: {utils.distributed.get_world_size()}')
+
     if not args.no_env:
         log_env_info()
 
+    register_ignoring_timeout_handler()
+
     # Set the random seed manually for reproducibility.
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
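register_ignoring_timeout_handler comes from utils/exp_utils.py, which is outside this diff. Judging from where it is installed (once at startup, before the training loop wraps itself in TimeoutHandler), it presumably registers a SIGTERM handler that only logs the signal instead of letting it kill the process, so a pre-emption notice arriving outside the guarded region is not fatal. A sketch under that assumption:

    import logging
    import signal

    def register_ignoring_timeout_handler(sig=signal.SIGTERM):
        # Assumed sketch: swallow the signal here; the training loop later
        # installs TimeoutHandler, which turns it into a graceful shutdown.
        def handler(signum, frame):
            logging.info(f'Received signal {signum}, ignoring it for now')
        signal.signal(sig, handler)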
@@ -732,7 +776,7 @@ def main():
         model, optimizer = amp.initialize(
             model,
             optimizer,
-            opt_level='O2',
+            opt_level=args.amp_mode,
             )
 
     if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
@@ -806,20 +850,36 @@ def main():
     logging.info('#non emb params = {}'.format(args.n_nonemb_param))
 
     train_step = 0
+    start_epoch = 1
+    last_batch = 0
+    last_iter = 0
     best_val_loss = None
 
     if args.restart:
-        checkpoint = load_checkpoint(args.restart)
-        model.load_state_dict(checkpoint['model_state'])
-        optimizer.load_state_dict(checkpoint['optimizer_state'])
-        scheduler.load_state_dict(checkpoint['scheduler_state'])
-        if args.fp16:
-            amp.load_state_dict(checkpoint['amp_state'])
-        train_step = checkpoint['train_step']
-        best_val_loss = checkpoint['best_val_loss']
-
-        model.apply(functools.partial(update_dropout, args=args))
-        model.apply(functools.partial(update_dropatt, args=args))
+        try:
+            checkpoint = load_checkpoint(args.restart)
+            model.load_state_dict(checkpoint['model_state'])
+            optimizer.load_state_dict(checkpoint['optimizer_state'])
+            scheduler.load_state_dict(checkpoint['scheduler_state'])
+            if args.fp16:
+                amp.load_state_dict(checkpoint['amp_state'])
+            train_step = checkpoint['train_step']
+            start_epoch = checkpoint['epoch']
+            last_batch = checkpoint['batch']
+            last_iter = checkpoint['last_iter']
+            best_val_loss = checkpoint['best_val_loss']
+
+            if train_step >= args.max_step:
+                logging.info(f'Loaded checkpoint after {train_step} steps, but '
+                             f'this run was scheduled for a total of '
+                             f'{args.max_step} steps, exiting')
+                sys.exit(1)
+
+            model.apply(functools.partial(update_dropout, args=args))
+            model.apply(functools.partial(update_dropatt, args=args))
+        except FileNotFoundError:
+            logging.info(f'Could not load checkpoint from {args.restart}, '
+                         f'starting training from random init')
 
     meters = {}
     warmup = args.mem_len // args.tgt_len + 2
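With epoch, batch and last_iter now restored from the checkpoint and threaded into train() as last_batch/last_iter, an interrupted run can resume mid-epoch from checkpoint_last.pt. A hypothetical invocation (the config file name and work dir below are illustrative, not taken from this diff):

    python -m torch.distributed.launch --nproc_per_node=8 train.py \
        --config_file wt103_base.yaml \
        --restart LM-TFM/checkpoint_last.pt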
@@ -830,23 +890,28 @@ def main():
     # Loop over epochs.
     # At any point you can hit Ctrl + C to break out of training early.
     start_time = time.time()
-    try:
-        for epoch in itertools.count(start=1):
-            if args.roll:
-                tr_iter.roll()
-            train_step, best_val_loss = train(
-                tr_iter, va_iter, model, para_model, model_config, optimizer,
-                optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
-                train_step, best_val_loss, meters, args
-                )
+    with TimeoutHandler() as timeout_handler:
+        try:
+            for epoch in itertools.count(start=start_epoch):
+                if args.roll:
+                    tr_iter.roll(seed=args.seed + epoch)
+                train_step, best_val_loss = train(
+                    tr_iter, va_iter, model, para_model, model_config,
+                    optimizer, optimizer_sparse, scheduler, scheduler_sparse,
+                    vocab, epoch, last_batch, last_iter, train_step,
+                    best_val_loss, meters, timeout_handler, args
+                    )
 
-            if train_step == args.max_step:
-                logging.info('-' * 100)
-                logging.info('End of training')
-                break
-    except KeyboardInterrupt:
-        logging.info('-' * 100)
-        logging.info('Exiting from training early')
+                last_batch = 0
+                last_iter = 0
+
+                if train_step == args.max_step:
+                    logging.info('-' * 100)
+                    logging.info('End of training')
+                    break
+        except KeyboardInterrupt:
+            logging.info('-' * 100)
+            logging.info('Exiting from training early')
     elapsed = time.time() - start_time
 
     ###########################################################################
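TimeoutHandler (also from utils/exp_utils.py, not shown here) is used purely as a context manager whose interrupted flag train() polls so it can checkpoint and exit cleanly on SIGTERM. A minimal sketch consistent with that usage; the real class may cover more signals and details:

    import logging
    import signal

    class TimeoutHandler:
        def __init__(self, sig=signal.SIGTERM):
            self.sig = sig
            self.interrupted = False

        def __enter__(self):
            self.original_handler = signal.getsignal(self.sig)

            def handler(signum, frame):
                # Assumed sketch: only record the signal; train() notices the
                # flag, saves a checkpoint, and exits at a safe point.
                self.interrupted = True
                logging.info(f'Received signal {signum}')

            signal.signal(self.sig, handler)
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            signal.signal(self.sig, self.original_handler)
            return False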
@@ -854,7 +919,7 @@ def main():
     ###########################################################################
     summary = {}
     test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
-    if not args.debug and os.path.exists(test_path):
+    if not args.debug and not args.no_eval and os.path.exists(test_path):
         # Load the best saved model.
         checkpoint = load_checkpoint(test_path)
         model.load_state_dict(checkpoint['model_state'])
@@ -911,4 +976,17 @@ def main():
 
 
 if __name__ == "__main__":
+    # Disable profiling executor
+    try:
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_set_profiling_mode(False)
+    except AttributeError:
+        pass
+
+    # Before we do anything with models, we want to ensure that we get fp16
+    # execution of torch.einsum.
+    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
+    # Note that running `--amp_mode O2` will remove the need for this
+    # code, but it is still valid.
+    amp.register_half_function(torch, 'einsum')
     main()