| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992 |
- # coding: utf-8
- # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import functools
- import itertools
- import logging
- import math
- import os
- import shutil
- import sys
- import time
- import dllogger
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import yaml
- from apex import amp
- from torch.nn.parallel import DistributedDataParallel
- import lamb
- import utils
- from data_utils import get_lm_corpus
- from mem_transformer import MemTransformerLM
- from utils.data_parallel import BalancedDataParallel
- from utils.exp_utils import AverageMeter
- from utils.exp_utils import TimeoutHandler
- from utils.exp_utils import benchmark
- from utils.exp_utils import create_exp_dir
- from utils.exp_utils import l2_promote
- from utils.exp_utils import log_env_info
- from utils.exp_utils import register_ignoring_timeout_handler
def parse_args():
    """Build the command-line parser and return the parsed training options.

    When ``--config_file`` is given, the ``train`` section of entry
    ``--config`` in that YAML file is installed with ``parser.set_defaults``
    before the real command line is parsed, so explicit flags still override
    YAML values. Unknown command-line arguments are ignored
    (``parse_known_args``).

    Derived fields:
      * ``args.tied`` is ``not args.not_tied``;
      * ``args.d_embed`` falls back to ``args.d_model`` when negative.

    Returns:
        argparse.Namespace with all training options.

    Exits via ``parser.error`` (status 2) when ``--ext_len`` is negative,
    ``--batch_chunk`` is not positive, or ``--batch_size`` is not divisible
    by ``--batch_chunk``.
    """
    description = 'PyTorch Transformer-XL Language Model'
    formatter_class = argparse.ArgumentDefaultsHelpFormatter

    parent_parser = argparse.ArgumentParser(
        description=description,
        formatter_class=formatter_class,
        add_help=False,
    )
    # ``parents=`` only copies arguments: description and formatter_class of
    # the parent are NOT inherited, so pass them explicitly for ``--help``.
    parser = argparse.ArgumentParser(
        parents=[parent_parser],
        description=description,
        formatter_class=formatter_class,
        add_help=True,
    )

    # Pre-parse only the config selectors so YAML values can become defaults.
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)
    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        # NOTE(review): FullLoader can construct some Python objects; the
        # config file is assumed to be trusted. Use yaml.safe_load if configs
        # may come from untrusted sources.
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--amp_mode', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al, '
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    # NOTE(review): --emb_init / --emb_init_range are not read by
    # weights_init() in this file; embeddings use --init/--init_range/--init_std.
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, proj_init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting '
                          'overrides global --batch_size and sets batch_size '
                          'to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    # A string default taken from the environment is converted by ``type=int``.
    dist.add_argument('--local_rank', type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    # Validate with parser.error() rather than assert: asserts disappear under
    # ``python -O`` and produce a raw traceback instead of a usage message.
    # NOTE(review): main() may later overwrite batch_size from
    # --local_batch_size; this check sees the value before that override.
    if args.ext_len < 0:
        parser.error('extended context length must be non-negative')
    if args.batch_chunk < 1:
        parser.error('--batch_chunk must be a positive integer')
    if args.batch_size % args.batch_chunk != 0:
        parser.error('--batch_size must be divisible by --batch_chunk')

    return args
def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
                    epoch, batch, last_iter, train_step, best_val_loss,
                    is_best, work_dir):
    """Serialize the training state into ``work_dir``.

    The state dict holds args, model config/state, optimizer and scheduler
    state, vocab, apex amp state (only when ``args.fp16``, else None) and the
    progress counters (epoch, batch, last_iter, train_step, best_val_loss).

    Only the worker for which ``utils.distributed.sync_workers()`` yields
    rank 0 writes files:
      * ``checkpoint_last.pt`` is always written;
      * it is copied to ``checkpoint_best.pt`` when ``is_best``;
      * it is copied to ``checkpoint_{train_step}.pt`` when ``args.save_all``.

    ``checkpoint_last.pt`` is first written to a temporary file and then
    moved into place with ``os.replace``. An interruption during
    ``torch.save`` therefore cannot leave a truncated last checkpoint, which
    would otherwise destroy the only resume point.
    """
    amp_state = amp.state_dict() if args.fp16 else None

    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'vocab': vocab,
        'amp_state': amp_state,
        'epoch': epoch,
        'batch': batch,
        'last_iter': last_iter,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
    }

    last_chkpt_fname = 'checkpoint_last.pt'

    with utils.distributed.sync_workers() as rank:
        last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
        if rank == 0:
            # always save last checkpoint (write-then-rename for atomicity;
            # the temp file lives in the same directory so os.replace is a
            # same-filesystem rename)
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            tmp_chkpt_path = f'{last_chkpt_path}.tmp'
            torch.save(state, tmp_chkpt_path)
            os.replace(tmp_chkpt_path, last_chkpt_path)

            # save best checkpoint if better than previous best
            if is_best:
                best_chkpt_fname = 'checkpoint_best.pt'
                best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
                logging.info(f'Saving checkpoint to {best_chkpt_path}')
                shutil.copy(last_chkpt_path, best_chkpt_path)

            # save every checkpoint if save_all is true
            if args.save_all:
                step_chkpt_fname = f'checkpoint_{train_step}.pt'
                step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
                logging.info(f'Saving checkpoint to {step_chkpt_path}')
                shutil.copy(last_chkpt_path, step_chkpt_path)
def load_checkpoint(path):
    """Load a training checkpoint written by ``save_checkpoint``.

    Args:
        path: checkpoint file, or a directory, in which case
            ``<path>/checkpoint_last.pt`` is loaded.

    Returns:
        The deserialized checkpoint dict. Tensors are mapped onto the current
        CUDA device when CUDA is available, otherwise onto the CPU.

    Raises:
        FileNotFoundError: if the resolved path does not exist.

    The checkpoint contains non-tensor objects (an ``argparse.Namespace``
    and the vocabulary). It is therefore loaded with full unpickling
    (``weights_only=False``), which newer PyTorch no longer does by default.
    Only load checkpoints from trusted sources.
    """
    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    if torch.cuda.is_available():
        dst = f'cuda:{torch.cuda.current_device()}'
    else:
        dst = 'cpu'

    logging.info(f'Loading checkpoint from {path}')
    try:
        checkpoint = torch.load(path, map_location=dst, weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` keyword (and always unpickles).
        checkpoint = torch.load(path, map_location=dst)
    return checkpoint
def init_weight(weight, args):
    """Initialize ``weight`` in place according to ``args.init``.

    ``'normal'``  -> N(0, args.init_std)
    ``'uniform'`` -> U(-args.init_range, args.init_range)
    Any other value of ``args.init`` leaves ``weight`` untouched.
    """
    scheme = args.init
    if scheme == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)
    elif scheme == 'uniform':
        bound = args.init_range
        nn.init.uniform_(weight, -bound, bound)
def init_bias(bias):
    """Set every element of ``bias`` to zero, in place."""
    nn.init.zeros_(bias)
def weights_init(m, args):
    """Initialize the parameters of module ``m`` in place (use with ``model.apply``).

    The rule is chosen by substring match on the module's class name. The
    first matching rule wins, so ``AdaptiveEmbedding`` is tested before the
    generic ``Embedding`` rule:

    * ``Linear``: weight via ``init_weight``; bias zeroed (each if not None).
    * ``AdaptiveEmbedding``: each non-None ``emb_projs`` entry ~ N(0, args.proj_init_std).
    * ``Embedding``: weight via ``init_weight`` (generic ``args.init`` settings).
    * ``ProjectedAdaptiveLogSoftmax``: ``cluster_weight`` via ``init_weight``,
      ``cluster_bias`` zeroed, non-None ``out_projs`` ~ N(0, args.proj_init_std),
      non-None ``out_layers_weights`` via ``init_weight``.
    * ``LayerNorm``: weight ~ N(1, args.init_std); bias zeroed if not None.
    * ``TransformerLM``: ``r_emb``, ``r_w_bias``, ``r_r_bias`` via
      ``init_weight``; ``r_bias`` zeroed (each only if present).

    Modules matching none of these names are left untouched.
    """
    name = m.__class__.__name__

    def _present(attr):
        # True when ``m`` has ``attr`` and its value is not None.
        return getattr(m, attr, None) is not None

    if 'Linear' in name:
        if _present('weight'):
            init_weight(m.weight, args)
        if _present('bias'):
            init_bias(m.bias)
    elif 'AdaptiveEmbedding' in name:
        if hasattr(m, 'emb_projs'):
            for proj in m.emb_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif 'Embedding' in name:
        if hasattr(m, 'weight'):
            init_weight(m.weight, args)
    elif 'ProjectedAdaptiveLogSoftmax' in name:
        if _present('cluster_weight'):
            init_weight(m.cluster_weight, args)
        if _present('cluster_bias'):
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for proj in m.out_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
        if hasattr(m, 'out_layers_weights'):
            for out_weight in m.out_layers_weights:
                if out_weight is not None:
                    init_weight(out_weight, args)
    elif 'LayerNorm' in name:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if _present('bias'):
            init_bias(m.bias)
    elif 'TransformerLM' in name:
        for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
            if hasattr(m, attr):
                init_weight(getattr(m, attr), args)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)
def update_dropout(m, args):
    """Overwrite the probability ``p`` of a dropout module with ``args.dropout``.

    Applies only when the class name of ``m`` contains ``'Dropout'`` and the
    module has a ``p`` attribute; otherwise ``m`` is left unchanged.
    """
    is_dropout = 'Dropout' in m.__class__.__name__
    if is_dropout and hasattr(m, 'p'):
        m.p = args.dropout
def update_dropatt(m, args):
    """Set ``m.dropatt.p`` to ``args.dropatt`` when ``m`` has a ``dropatt`` attribute.

    Modules without ``dropatt`` are left unchanged.
    """
    if not hasattr(m, 'dropatt'):
        return
    m.dropatt.p = args.dropatt
def evaluate(eval_iter, model, args):
    """Return the token-weighted average loss of ``model`` over ``eval_iter``.

    The model is put in eval mode and its lengths are reset for evaluation.
    When ``args.mem_len == 0`` the extra context
    (``args.tgt_len - args.eval_tgt_len``) is added to ``ext_len``;
    otherwise it is added to ``mem_len``. Training mode and the training
    lengths are restored afterwards, also when evaluation raises.

    Args:
        eval_iter: iterable of ``(data, target, seq_len, warm)`` tuples.
        model: callable as ``model(data, target, mems) -> (loss, mems)``,
            exposing ``eval``/``train``, ``reset_length`` and ``mem_len``.
        args: namespace with ``tgt_len``, ``eval_tgt_len``, ``ext_len``,
            ``mem_len`` and ``eval_max_steps`` (<= 0 means no step limit).

    Returns:
        ``sum(seq_len * mean_loss) / sum(seq_len)`` over the evaluated batches.

    Raises:
        ValueError: if no tokens were evaluated (e.g. empty iterator).
    """
    total_len, total_loss = 0, 0.
    try:
        # Turn on evaluation mode which disables dropout.
        model.eval()

        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        if args.mem_len == 0:
            model.reset_length(tgt_len=args.eval_tgt_len,
                               ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                               mem_len=args.mem_len)
        else:
            model.reset_length(tgt_len=args.eval_tgt_len,
                               ext_len=args.ext_len,
                               mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len)

        with torch.no_grad():
            mems = None
            for i, (data, target, seq_len, warm) in enumerate(eval_iter):
                if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                    break
                loss, mems = model(data, target, mems)
                loss = loss.float().mean()
                if warm:
                    # For "warm" batches the memory is expected to be full.
                    assert (mems is None) or mems.size(1) == model.mem_len
                total_loss += seq_len * loss.item()
                total_len += seq_len
    finally:
        # Switch back to the training configuration even if evaluation failed,
        # so the caller's training loop is not left with eval lengths/mode.
        model.reset_length(tgt_len=args.tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len)
        model.train()

    if total_len == 0:
        raise ValueError('evaluation covered zero tokens (empty eval iterator?)')
    return total_loss / total_len
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, args):
    """Run (the rest of) one training epoch over ``tr_iter``.

    Each batch is split into ``args.batch_chunk`` chunks along dim 1 and the
    gradients are accumulated over the chunks before a single optimizer step.
    Training statistics are logged every ``args.log_interval`` steps.

    Unless ``args.no_eval`` is set, the model is evaluated on ``va_iter`` in
    three cases: every ``args.eval_interval`` steps, on step
    ``args.max_step``, and when ``timeout_handler.interrupted`` is set. After
    each evaluation a checkpoint is saved (unless ``args.debug``).

    Args:
        tr_iter, va_iter: training / validation corpus iterators.
        model: the bare model; used for parameters, evaluation and checkpoints.
        para_model: (possibly parallel) wrapper used for the training forward pass.
        model_config, vocab: stored in checkpoints.
        optimizer, optimizer_sparse: dense optimizer and optional sparse one.
        scheduler, scheduler_sparse: LR schedulers (sparse one optional).
        epoch: current epoch number (logging / checkpoints).
        last_batch: batches already done in this epoch when resuming; batch
            numbering continues at ``last_batch + 1``.
        last_iter: iterator position to resume from.
        train_step: global optimizer-step counter on entry.
        best_val_loss: best validation loss so far, or None.
        meters: dict of meters; ``meters['train_throughput']`` is updated.
        timeout_handler: object whose ``interrupted`` flag requests a stop.
        args: parsed command-line options.

    Returns:
        ``(train_step, best_val_loss)`` after the epoch or the final step.
        Calls ``sys.exit(0)`` when ``timeout_handler.interrupted`` is set.
    """
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    # One recurrent memory slot per gradient-accumulation chunk.
    mems = [None] * args.batch_chunk
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    # ``batch`` is 1-based and continues after ``last_batch`` when resuming.
    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch + 1):
        log_step += 1
        target_tokens += target.numel()

        # Clear gradients before accumulating them over the chunks.
        for param in model.parameters():
            param.grad = None

        # Split along dim 1; each chunk's loss is divided by batch_chunk so
        # the accumulated gradient is the mean over chunks.
        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            data_i = data_chunks[i].contiguous()
            target_i = target_chunks[i].contiguous()
            loss, mems[i] = para_model(data_i, target_i, mems[i])
            loss = loss.float().mean().type_as(loss) / args.batch_chunk

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            train_loss += loss.float().item()

        if args.fp16:
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if optimizer_sparse:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            if train_step < args.warmup_step:
                # linear warmup stage
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            elif args.scheduler == 'cosine':
                # the cosine schedule is indexed from the end of warmup
                scheduler.step(train_step - args.warmup_step)
                if scheduler_sparse:
                    scheduler_sparse.step(train_step - args.warmup_step)
            # 'constant' keeps the LR here; 'dev_perf' steps after evaluation
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = ('| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} '
                       '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}').format(
                epoch,
                train_step,
                batch,
                tr_iter.n_batch,
                lr,
                avg_elapsed * 1000,
                throughput,
                cur_loss,
            )

            dllogger_data = {
                'epoch': epoch,
                # ``batch`` is already 1-based; report the same value as log_str.
                'train_batch': batch,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')
            eval_elapsed = time.time() - eval_start_time

            logging.info('-' * 100)
            log_str = ('| Eval {:3d} at step {:>8d} | time: {:5.2f}s '
                       '| valid loss {:5.2f}').format(
                train_step // args.eval_interval,
                train_step,
                eval_elapsed,
                val_loss,
            )

            dllogger_data = {
                'valid_elapsed': eval_elapsed,
                'valid_loss': val_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            # Compare with None explicitly: a best loss of 0.0 is still valid.
            is_best = best_val_loss is None or val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer,
                                scheduler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        # Checked outside the evaluation branch so that --no_eval cannot
        # suppress a pending stop request.
        if interrupted:
            logging.info('Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break

    return train_step, best_val_loss
- def main():
- args = parse_args()
- utils.gpu_affinity.set_affinity(args.local_rank)
- # Initialize device and distributed backend
- torch.cuda.set_device(args.local_rank)
- l2_promote()
- device = torch.device('cuda' if args.cuda else 'cpu')
- utils.distributed.init_distributed(args.cuda)
- args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
- args.dataset,
- args.append_dataset,
- args.append_time,
- )
- with utils.distributed.sync_workers() as rank:
- if rank == 0:
- create_exp_dir(args.work_dir,
- scripts_to_save=['train.py', 'mem_transformer.py'],
- debug=args.debug)
- # Setup logging
- if args.log_all_ranks:
- log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
- else:
- log_file = args.txtlog_file
- dllog_file = args.dllog_file
- log_file = os.path.join(args.work_dir, log_file)
- dllog_file = os.path.join(args.work_dir, dllog_file)
- if args.debug:
- log_file = os.devnull
- dllog_file = os.devnull
- utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
- filename=log_file,
- )
- utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
- if args.local_batch_size is not None:
- world_size = utils.distributed.get_world_size()
- args.batch_size = world_size * args.local_batch_size
- logging.info(f'--local_batch_size was set, adjusting global batch size'
- f' to {args.batch_size} (local_batch_size * world_size)')
- logging.info(args)
- dllogger.log(step='PARAMETER', data=vars(args))
- logging.info(f'world size: {utils.distributed.get_world_size()}')
- if not args.no_env:
- log_env_info()
- register_ignoring_timeout_handler()
- # Set the random seed manually for reproducibility.
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
- ###########################################################################
- # Load data
- ###########################################################################
- corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
- ntokens = len(corpus.vocab)
- vocab = corpus.vocab
- args.n_token = ntokens
- if args.mem_len == 0:
- eval_mem_len = 0
- else:
- eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len
- tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
- device=device, ext_len=args.ext_len)
- va_iter = corpus.get_iterator('valid', args.eval_batch_size,
- args.eval_tgt_len, device=device,
- mem_len=eval_mem_len, ext_len=args.ext_len)
- te_iter = corpus.get_iterator('test', args.eval_batch_size,
- args.eval_tgt_len, device=device,
- mem_len=eval_mem_len, ext_len=args.ext_len)
- # adaptive softmax / embedding
- cutoffs, tie_projs = [], [False]
- if args.adaptive:
- assert args.dataset in ['wt103', 'lm1b']
- if args.dataset == 'wt103':
- cutoffs = [19997, 39997, 199997]
- tie_projs += [True] * len(cutoffs)
- elif args.dataset == 'lm1b':
- cutoffs = [59997, 99997, 639997]
- tie_projs += [False] * len(cutoffs)
- ###########################################################################
- # Build the model
- ###########################################################################
- model_config = {
- 'n_token': ntokens,
- 'n_layer': args.n_layer,
- 'n_head': args.n_head,
- 'd_model': args.d_model,
- 'd_head': args.d_head,
- 'd_inner': args.d_inner,
- 'dropout': args.dropout,
- 'dropatt': args.dropatt,
- 'dtype': None,
- 'tie_weight': args.tied,
- 'd_embed': args.d_embed,
- 'div_val': args.div_val,
- 'tie_projs': tie_projs,
- 'pre_lnorm': args.pre_lnorm,
- 'tgt_len': args.tgt_len,
- 'ext_len': args.ext_len,
- 'mem_len': args.mem_len,
- 'cutoffs': cutoffs,
- 'same_length': args.same_length,
- 'attn_type': args.attn_type,
- 'clamp_len': args.clamp_len,
- 'sample_softmax': args.sample_softmax,
- }
- model = MemTransformerLM(**model_config)
- model.apply(functools.partial(weights_init, args=args))
- # ensure embedding init is not overridden by out_layer in case of weight sharing
- model.word_emb.apply(functools.partial(weights_init, args=args))
- args.n_all_param = sum([p.nelement() for p in model.parameters()])
- args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
- # optimizer
- if args.optim.lower() == 'sgd':
- if args.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in model.parameters():
- if param.size() == model.word_emb.weight.size():
- sparse_params.append(param)
- else:
- dense_params.append(param)
- optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
- optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
- else:
- optimizer = optim.SGD(model.parameters(), lr=args.lr,
- momentum=args.mom)
- optimizer_sparse = None
- elif args.optim.lower() == 'adam':
- if args.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in model.parameters():
- if param.size() == model.word_emb.weight.size():
- sparse_params.append(param)
- else:
- dense_params.append(param)
- optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
- optimizer = optim.Adam(dense_params, lr=args.lr,
- weight_decay=args.weight_decay)
- else:
- optimizer = optim.Adam(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- elif args.optim.lower() == 'adagrad':
- optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
- optimizer_sparse = None
- elif args.optim.lower() == 'lamb':
- optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- elif args.optim.lower() == 'jitlamb':
- optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- model = model.to(device)
- if args.fp16:
- model, optimizer = amp.initialize(
- model,
- optimizer,
- opt_level=args.amp_mode,
- )
- if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
- para_model = DistributedDataParallel(model,
- device_ids=[args.local_rank],
- output_device=args.local_rank,
- broadcast_buffers=False,
- find_unused_parameters=True,
- )
- elif args.multi_gpu == 'dp':
- if args.gpu0_bsz >= 0:
- para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
- model, dim=1).to(device)
- else:
- para_model = nn.DataParallel(model, dim=1).to(device)
- else:
- para_model = model
- # scheduler
- if args.scheduler == 'cosine':
- if args.max_step_scheduler:
- max_step = args.max_step_scheduler
- else:
- max_step = args.max_step
- scheduler = optim.lr_scheduler.CosineAnnealingLR(
- optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
- optimizer_sparse, max_step - args.warmup_step,
- eta_min=args.eta_min)
- else:
- scheduler_sparse = None
- elif args.scheduler == 'inv_sqrt':
- # originally used for Transformer (in Attention is all you need)
- def lr_lambda(step):
- # return a multiplier instead of a learning rate
- if step == 0 and args.warmup_step == 0:
- return 1.
- else:
- return 1. / (step ** 0.5) if step > args.warmup_step \
- else step / (args.warmup_step ** 1.5)
- scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.LambdaLR(
- optimizer_sparse,
- lr_lambda=lr_lambda
- )
- else:
- scheduler_sparse = None
- elif args.scheduler == 'dev_perf':
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(
- optimizer, factor=args.decay_rate, patience=args.patience,
- min_lr=args.lr_min,
- )
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
- optimizer_sparse, factor=args.decay_rate, patience=args.patience,
- min_lr=args.lr_min,
- )
- else:
- scheduler_sparse = None
- elif args.scheduler == 'constant':
- pass
- logging.info('=' * 100)
- for k, v in args.__dict__.items():
- logging.info(' - {} : {}'.format(k, v))
- logging.info('=' * 100)
- logging.info('#params = {}'.format(args.n_all_param))
- logging.info('#non emb params = {}'.format(args.n_nonemb_param))
- train_step = 0
- start_epoch = 1
- last_batch = 0
- last_iter = 0
- best_val_loss = None
- if args.restart:
- try:
- checkpoint = load_checkpoint(args.restart)
- model.load_state_dict(checkpoint['model_state'])
- optimizer.load_state_dict(checkpoint['optimizer_state'])
- scheduler.load_state_dict(checkpoint['scheduler_state'])
- if args.fp16:
- amp.load_state_dict(checkpoint['amp_state'])
- train_step = checkpoint['train_step']
- start_epoch = checkpoint['epoch']
- last_batch = checkpoint['batch']
- last_iter = checkpoint['last_iter']
- best_val_loss = checkpoint['best_val_loss']
- if train_step >= args.max_step:
- logging.info(f'Loaded checkpoint after {train_step} steps, but '
- f'this run was scheduled for a total of '
- f'{args.max_step} steps, exiting')
- sys.exit(1)
- model.apply(functools.partial(update_dropout, args=args))
- model.apply(functools.partial(update_dropatt, args=args))
- except FileNotFoundError:
- logging.info(f'Could not load checkpoint from {args.restart}, '
- f'starting training from random init')
- meters = {}
- warmup = args.mem_len // args.tgt_len + 2
- meters['train_throughput'] = AverageMeter(warmup=warmup)
- ###########################################################################
- # Train
- ###########################################################################
- # Loop over epochs.
- # At any point you can hit Ctrl + C to break out of training early.
- start_time = time.time()
- with TimeoutHandler() as timeout_handler:
- try:
- for epoch in itertools.count(start=start_epoch):
- if args.roll:
- tr_iter.roll(seed=args.seed + epoch)
- train_step, best_val_loss = train(
- tr_iter, va_iter, model, para_model, model_config,
- optimizer, optimizer_sparse, scheduler, scheduler_sparse,
- vocab, epoch, last_batch, last_iter, train_step,
- best_val_loss, meters, timeout_handler, args
- )
- last_batch = 0
- last_iter = 0
- if train_step == args.max_step:
- logging.info('-' * 100)
- logging.info('End of training')
- break
- except KeyboardInterrupt:
- logging.info('-' * 100)
- logging.info('Exiting from training early')
- elapsed = time.time() - start_time
- ###########################################################################
- # Test
- ###########################################################################
- summary = {}
- test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
- if not args.debug and not args.no_eval and os.path.exists(test_path):
- # Load the best saved model.
- checkpoint = load_checkpoint(test_path)
- model.load_state_dict(checkpoint['model_state'])
- # Run on test data.
- test_start_time = time.time()
- test_loss = evaluate(te_iter, model, args)
- test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
- test_elapsed = time.time() - test_start_time
- logging.info('=' * 100)
- if args.dataset in ['enwik8', 'text8']:
- logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
- test_elapsed, test_loss, test_loss / math.log(2)))
- else:
- logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
- test_elapsed, test_loss, math.exp(test_loss)))
- logging.info('=' * 100)
- summary.update({
- 'test_elapsed': test_elapsed,
- 'test_loss': test_loss,
- })
- if args.dataset in ['enwik8', 'text8']:
- summary['test_bits_per_character'] = test_loss / math.log(2)
- else:
- summary['test_perplexity'] = math.exp(test_loss)
- logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
- logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')
- if best_val_loss:
- val_perplexity = math.exp(best_val_loss)
- else:
- val_perplexity = None
- summary.update({
- 'train_throughput': meters['train_throughput'].avg,
- 'train_elapsed': elapsed / 60,
- 'valid_loss': best_val_loss,
- 'valid_perplexity': val_perplexity,
- })
- dllogger.log(step=tuple(), data=summary)
- passed = benchmark(
- target_perplexity=args.target_perplexity,
- test_perplexity=val_perplexity,
- target_throughput=args.target_throughput,
- test_throughput=meters['train_throughput'].avg
- )
- if not passed:
- sys.exit(1)
def _disable_jit_profiling():
    """Turn off the TorchScript profiling executor and profiling mode.

    Both switches live in the private ``torch._C`` namespace. If the
    installed PyTorch does not expose one of them, the AttributeError is
    swallowed. Any switch already flipped before that point stays flipped,
    and the remaining defaults are kept.
    """
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        pass


if __name__ == "__main__":
    _disable_jit_profiling()

    # Register torch.einsum with AMP as a half-precision function before any
    # model is constructed. Otherwise einsum falls back to "promote" mode and
    # executes in fp32. With `--amp_mode O2` this registration is not needed,
    # but it is still valid there.
    amp.register_half_function(torch, 'einsum')

    main()
|