| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060 |
- # coding: utf-8
- # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import functools
- import itertools
- import logging
- import math
- import os
- import shutil
- import sys
- import time
- import warnings
- import dllogger
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import yaml
- try:
- from apex import amp
- except ModuleNotFoundError:
- warnings.warn('APEX AMP is unavailable')
- from torch.nn.parallel import DistributedDataParallel
- import lamb
- import utils
- from data_utils import get_lm_corpus
- from mem_transformer import MemTransformerLM
- from utils.data_parallel import BalancedDataParallel
- from utils.exp_utils import AverageMeter
- from utils.exp_utils import TimeoutHandler
- from utils.exp_utils import benchmark
- from utils.exp_utils import create_exp_dir
- from utils.exp_utils import l2_promote
- from utils.exp_utils import log_env_info
- from utils.exp_utils import register_ignoring_timeout_handler
def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Option values are resolved in three layers, lowest priority first:

    1. the argparse defaults declared below;
    2. the ``train`` section of entry ``--config`` in the YAML file given by
       ``--config_file`` (only when ``--config_file`` is passed);
    3. flags given explicitly on the command line.

    Unknown command-line flags are ignored (``parse_known_args``).

    Derived fields:
        * ``args.tied`` is ``not args.not_tied``;
        * a negative ``--d_embed`` is replaced by ``--d_model``.

    Returns:
        argparse.Namespace with all options.

    Raises:
        RuntimeError: if ``--ext_len`` is negative, ``--batch_size`` is not
            divisible by ``--batch_chunk``, or ``--fp16`` with apex AMP is
            requested while apex could not be imported.
    """
    description = 'PyTorch Transformer-XL Language Model'
    formatter_class = argparse.ArgumentDefaultsHelpFormatter

    parent_parser = argparse.ArgumentParser(
        description=description,
        formatter_class=formatter_class,
        add_help=False,
    )
    # ``parents=`` copies only argument definitions; ``description`` and
    # ``formatter_class`` are NOT inherited, so they are passed again here.
    # Without this, ``--help`` loses the description and the
    # "(default: ...)" hints.
    parser = argparse.ArgumentParser(
        parents=[parent_parser],
        description=description,
        formatter_class=formatter_class,
        add_help=True,
    )

    # First pass: only pick up the YAML config selection; all other flags are
    # left for the full parser below.
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)
    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        with open(config_args.config_file) as f:
            # NOTE: FullLoader is more permissive than safe_load; only point
            # --config_file at trusted files.
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--apex_amp_opt_level', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')
    general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
                         help='Implementation of automatic mixed precision')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al, '
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Projection parameters initialized by '
                       'N(0, proj_init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting '
                          'overrides global --batch_size and sets batch_size '
                          'to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')
    training.add_argument('--swap_mem', action='store_true',
                          help='Swap memory tensors to cpu')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    # A string default (from the environment) is converted by ``type=int``.
    dist.add_argument('--local_rank', type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    # YAML config values replace the argparse defaults; flags given
    # explicitly on the command line still take precedence over them.
    parser.set_defaults(**config)

    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    if args.ext_len < 0:
        raise RuntimeError('Extended context length must be non-negative')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by batch chunk')

    if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
        raise RuntimeError(
            'APEX AMP unavailable, install APEX or switch to pytorch AMP'
        )

    return args
def save_checkpoint(args, model, model_config, optimizer, scheduler, scaler,
                    vocab, epoch, batch, last_iter, train_step, best_val_loss,
                    is_best, work_dir):
    """Serialize the full training state into ``work_dir``.

    All ranks enter ``utils.distributed.sync_workers()``. Only rank 0 writes
    files:

    * ``checkpoint_last.pt`` on every call;
    * a copy named ``checkpoint_best.pt`` when ``is_best`` is true;
    * a copy named ``checkpoint_<train_step>.pt`` when ``args.save_all`` is set.

    ``checkpoint_last.pt`` is first written to a temporary file and then
    swapped in with ``os.replace``. A crash or full disk during
    ``torch.save`` therefore leaves the previous ``checkpoint_last.pt`` intact
    instead of truncated.

    Args:
        args: parsed command-line arguments (stored in the checkpoint).
        model, optimizer: objects whose ``state_dict()`` is stored.
        model_config: model constructor kwargs (stored as-is).
        scheduler: LR scheduler whose ``state_dict()`` is stored. May be
            ``None`` (no scheduler is created for ``--scheduler constant``),
            in which case ``None`` is stored.
        scaler: PyTorch ``GradScaler``; only read when ``args.fp16`` and
            ``args.amp == 'pytorch'``.
        vocab: vocabulary object (stored as-is).
        epoch, batch, last_iter, train_step: training position to resume from.
        best_val_loss: best validation loss so far.
        is_best: whether to also update ``checkpoint_best.pt``.
        work_dir: output directory.
    """
    # Loss-scaling state of the active AMP implementation, if any.
    amp_state = None
    if args.fp16:
        if args.amp == 'pytorch':
            amp_state = scaler.state_dict()
        elif args.amp == 'apex':
            amp_state = amp.state_dict()

    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict() if scheduler is not None else None,
        'vocab': vocab,
        'amp_state': amp_state,
        'epoch': epoch,
        'batch': batch,
        'last_iter': last_iter,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
    }

    last_chkpt_path = os.path.join(work_dir, 'checkpoint_last.pt')

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            # always save last checkpoint (write-then-rename so a failed save
            # cannot clobber the previous one)
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            tmp_chkpt_path = f'{last_chkpt_path}.tmp'
            torch.save(state, tmp_chkpt_path)
            os.replace(tmp_chkpt_path, last_chkpt_path)

            # save best checkpoint if better than previous best
            if is_best:
                best_chkpt_path = os.path.join(work_dir, 'checkpoint_best.pt')
                logging.info(f'Saving checkpoint to {best_chkpt_path}')
                shutil.copy(last_chkpt_path, best_chkpt_path)

            # save every checkpoint if save_all is true
            if args.save_all:
                step_chkpt_path = os.path.join(work_dir,
                                               f'checkpoint_{train_step}.pt')
                logging.info(f'Saving checkpoint to {step_chkpt_path}')
                shutil.copy(last_chkpt_path, step_chkpt_path)
def load_checkpoint(path):
    """Load a training checkpoint produced by ``save_checkpoint``.

    Args:
        path: a checkpoint file, or a directory, in which case
            ``<path>/checkpoint_last.pt`` is loaded.

    Returns:
        The deserialized checkpoint dict. Tensors are mapped to the current
        CUDA device when CUDA is available, otherwise to the CPU.
    """
    import inspect

    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    if torch.cuda.is_available():
        dst = f'cuda:{torch.cuda.current_device()}'
    else:
        # torch.cuda.current_device() raises without a usable CUDA device.
        dst = 'cpu'

    logging.info(f'Loading checkpoint from {path}')

    load_kwargs = {'map_location': dst}
    # Checkpoints pickle non-tensor objects (argparse.Namespace, vocab, ...),
    # which newer PyTorch rejects under its ``weights_only=True`` default.
    # SECURITY: this unpickles arbitrary objects; only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)
    return checkpoint
def init_weight(weight, args):
    """Initialise ``weight`` in place according to ``args.init``.

    ``'normal'`` draws from N(0, args.init_std) and ``'uniform'`` from
    U(-args.init_range, args.init_range). Any other scheme name leaves the
    tensor untouched.
    """
    scheme = args.init
    if scheme == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)
        return
    if scheme == 'uniform':
        bound = args.init_range
        nn.init.uniform_(weight, -bound, bound)
def init_bias(bias):
    """Reset every element of ``bias`` to zero, in place (no autograd tracking)."""
    nn.init.zeros_(bias)
def weights_init(m, args):
    """Initialise the parameters of module ``m`` in place.

    Intended for ``model.apply(functools.partial(weights_init, args=args))``.
    The scheme is selected by substring match on the class name of ``m``; the
    first matching rule wins:

    * ``Linear``: ``weight`` via ``init_weight``; ``bias`` zeroed.
    * ``AdaptiveEmbedding``: each ``emb_projs`` entry ~ N(0, args.proj_init_std).
    * ``Embedding``: ``weight`` via ``init_weight``.
    * ``ProjectedAdaptiveLogSoftmax``: ``cluster_weight`` and each
      ``out_layers_weights`` entry via ``init_weight``; ``cluster_bias``
      zeroed; each ``out_projs`` entry ~ N(0, args.proj_init_std).
    * ``LayerNorm``: ``weight`` ~ N(1, args.init_std); ``bias`` zeroed.
    * ``TransformerLM``: ``r_emb``, ``r_w_bias``, ``r_r_bias`` via
      ``init_weight``; ``r_bias`` zeroed.

    Attributes that are missing *or* registered as ``None`` are skipped. An
    example is the weight/bias of ``LayerNorm(elementwise_affine=False)``;
    previously such attributes passed the ``hasattr`` check and crashed
    inside ``nn.init``. Modules matching no rule are left untouched.
    """
    classname = m.__class__.__name__

    def _get(name):
        # hasattr() is not enough: optional parameters are often registered
        # as None, which nn.init cannot handle.
        return getattr(m, name, None)

    def _each(name):
        # Non-None entries of an optional list-like attribute (ParameterList).
        return [p for p in (_get(name) or ()) if p is not None]

    if 'Linear' in classname:
        if _get('weight') is not None:
            init_weight(m.weight, args)
        if _get('bias') is not None:
            init_bias(m.bias)
    elif 'AdaptiveEmbedding' in classname:
        for proj in _each('emb_projs'):
            nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif 'Embedding' in classname:
        if _get('weight') is not None:
            init_weight(m.weight, args)
    elif 'ProjectedAdaptiveLogSoftmax' in classname:
        if _get('cluster_weight') is not None:
            init_weight(m.cluster_weight, args)
        if _get('cluster_bias') is not None:
            init_bias(m.cluster_bias)
        for proj in _each('out_projs'):
            nn.init.normal_(proj, 0.0, args.proj_init_std)
        for weight in _each('out_layers_weights'):
            init_weight(weight, args)
    elif 'LayerNorm' in classname:
        if _get('weight') is not None:
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if _get('bias') is not None:
            init_bias(m.bias)
    elif 'TransformerLM' in classname:
        for name in ('r_emb', 'r_w_bias', 'r_r_bias'):
            value = _get(name)
            if value is not None:
                init_weight(value, args)
        if _get('r_bias') is not None:
            init_bias(m.r_bias)
def update_dropout(m, args):
    """Overwrite the drop probability of dropout modules with ``args.dropout``.

    Meant for ``model.apply``. A module qualifies when its class name contains
    ``'Dropout'`` and it has a ``p`` attribute; everything else is untouched.
    """
    is_dropout = 'Dropout' in m.__class__.__name__
    if is_dropout and hasattr(m, 'p'):
        m.p = args.dropout
def update_dropatt(m, args):
    """Set ``m.dropatt.p`` to ``args.dropatt`` when ``m`` has a ``dropatt`` attribute.

    Meant for ``model.apply``; modules without ``dropatt`` are left as-is.
    """
    try:
        attn_dropout = m.dropatt
    except AttributeError:
        return
    attn_dropout.p = args.dropatt
def evaluate(eval_iter, model, args):
    """Return the ``seq_len``-weighted mean loss of ``model`` over ``eval_iter``.

    The model is put in eval mode and its segment lengths are changed for
    evaluation. The target length becomes ``args.eval_tgt_len``, and the
    difference ``args.tgt_len - args.eval_tgt_len`` is added to the memory
    length, or to the extended context when the model uses no memory
    (``args.mem_len == 0``). This keeps target + context length equal to
    training. Training lengths and training mode are restored before
    returning, also when evaluation raises.

    Args:
        eval_iter: iterable of ``(data, target, seq_len, warm)`` tuples; only
            batches with a truthy ``warm`` contribute to the result.
        model: object with ``eval()``/``train()``, ``reset_length(tgt_len,
            ext_len, mem_len)``, a ``mem_len`` attribute and
            ``model(data, target, mems) -> (loss, mems)``.
        args: parsed command-line arguments (``tgt_len``, ``eval_tgt_len``,
            ``ext_len``, ``mem_len``, ``eval_max_steps``).

    Returns:
        ``sum(seq_len * mean_loss) / sum(seq_len)`` over the counted batches.

    Raises:
        RuntimeError: if no batch was counted (empty iterator, or no warm
            batch within ``args.eval_max_steps``).
    """
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    len_delta = args.tgt_len - args.eval_tgt_len
    if args.mem_len == 0:
        eval_ext_len, eval_mem_len = args.ext_len + len_delta, args.mem_len
    else:
        eval_ext_len, eval_mem_len = args.ext_len, args.mem_len + len_delta

    # Turn on evaluation mode which disables dropout.
    model.eval()
    model.reset_length(tgt_len=args.eval_tgt_len,
                       ext_len=eval_ext_len,
                       mem_len=eval_mem_len,
                       )

    total_len, total_loss = 0, 0.
    try:
        with torch.no_grad():
            mems = None
            for i, (data, target, seq_len, warm) in enumerate(eval_iter):
                if args.eval_max_steps > 0 and i >= args.eval_max_steps:
                    break
                loss, mems = model(data, target, mems)
                loss = loss.float().mean()
                if warm:
                    assert (mems is None) or mems.size(1) == model.mem_len
                    total_loss += seq_len * loss.item()
                    total_len += seq_len
    finally:
        # Switch back to the training configuration even if evaluation failed.
        model.reset_length(tgt_len=args.tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len
                           )
        model.train()

    if total_len == 0:
        raise RuntimeError(
            'Evaluation scored no tokens: the iterator was empty or no batch '
            'was marked warm within --eval_max_steps'
        )
    return total_loss / total_len
def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, args):
    """Run forward and backward for gradient-accumulation chunk ``i``.

    ``mems[i]`` is passed to the model and replaced in place by the memory it
    returns. When ``args.swap_mem`` is set, that memory is moved to ``device``
    before the forward pass and back to the CPU afterwards. The loss is
    divided by ``args.batch_chunk``, so gradients accumulated over all chunks
    average rather than sum. With ``args.fp16`` the backward pass goes through
    the PyTorch ``scaler`` or apex ``amp.scale_loss``, depending on
    ``args.amp``.

    Returns:
        The chunk's (already ``1 / batch_chunk``-scaled) loss as a float.
    """
    swap = args.swap_mem
    inputs = data_chunks[i].contiguous()
    targets = target_chunks[i].contiguous()

    if swap and mems[i] is not None:
        mems[i] = mems[i].to(device, non_blocking=True)

    with torch.cuda.amp.autocast(enabled=args.fp16 and args.amp == 'pytorch'):
        loss, mems[i] = model(inputs, targets, mems[i])
        loss = loss.float().mean().type_as(loss) / args.batch_chunk

    if swap and mems[i] is not None:
        mems[i] = mems[i].to(torch.device('cpu'), non_blocking=True)

    if not args.fp16:
        loss.backward()
    elif args.amp == 'pytorch':
        scaler.scale(loss).backward()
    elif args.amp == 'apex':
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

    return loss.float().item()
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, device, args):
    """Train over the batches of ``tr_iter`` once, resuming at ``last_iter``.

    For every batch:

    * gradients are dropped;
    * the batch is split into ``args.batch_chunk`` chunks along dim 1 and
      processed with gradient accumulation via ``train_iteration``. Under
      ``DistributedDataParallel`` gradient sync is skipped for all but the
      last chunk;
    * gradients are clipped to ``args.clip``;
    * the optimizer(s) are stepped and the learning rate is updated.

    Every ``args.log_interval`` steps the loss, learning rate, time per batch
    and throughput are logged, reduced across ranks.

    Unless ``args.no_eval`` is set, the model is evaluated on ``va_iter``
    every ``args.eval_interval`` steps, at ``args.max_step``, and when
    ``timeout_handler.interrupted`` is set. After each evaluation a
    checkpoint is saved unless ``args.debug`` is set.

    An interruption ends the process with ``sys.exit(0)``; reaching
    ``args.max_step`` ends the loop.

    Args:
        tr_iter, va_iter: training / validation data iterators.
        model: the unwrapped model (parameters, evaluation, checkpointing).
        para_model: model used for forward/backward; may be a
            ``DistributedDataParallel``/``DataParallel`` wrapper or ``model``.
        model_config, vocab: stored in checkpoints.
        optimizer, optimizer_sparse: dense optimizer and optional sparse one.
        scheduler, scheduler_sparse: LR schedulers (sparse one optional).
        scaler: PyTorch ``GradScaler`` used with ``--amp pytorch``.
        epoch: current epoch, used for logging and checkpoints.
        last_batch: batch numbering resumes at ``last_batch + 1``.
        last_iter: position passed as ``start=`` to the training iterator.
        train_step: optimizer-step counter before this call.
        best_val_loss: best validation loss so far (falsy when none yet).
        meters: dict of meters; ``meters['train_throughput']`` is updated.
        timeout_handler: object whose ``interrupted`` flag requests an exit.
        device: device forwarded to ``train_iteration``.
        args: parsed command-line arguments.

    Returns:
        ``(train_step, best_val_loss)`` after the last processed batch.
    """
    # Turn on training mode which enables dropout.
    model.train()

    # Running statistics, reset after every log line.
    train_loss = 0
    target_tokens = 0
    log_step = 0
    log_start_time = time.time()

    # One memory slot per accumulation chunk: train_iteration reads and
    # writes mems[i] for chunk i.
    mems = [None for _ in range(args.batch_chunk)]
    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()

        # Clear gradients by dropping them instead of zero-filling.
        for param in model.parameters():
            param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            # With DDP, skip the gradient all-reduce for every chunk except
            # the last, so gradients are synchronized once per batch.
            if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
                with para_model.no_sync():
                    train_loss_chunk = train_iteration(
                        para_model, i, mems, data_chunks, target_chunks, scaler,
                        optimizer, device, args
                    )
            else:
                train_loss_chunk = train_iteration(
                    para_model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, args
                )

            train_loss += train_loss_chunk

        # Gradient clipping. With PyTorch AMP the gradients are unscaled first
        # so the threshold applies to the true gradients; with apex the
        # optimizer's master params (amp.master_params) are clipped.
        if args.fp16:
            if args.amp == 'pytorch':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            elif args.amp == 'apex':
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        if args.fp16 and args.amp == 'pytorch':
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
            # NOTE(review): the sparse optimizer is only stepped on this
            # branch, i.e. never with --amp pytorch; confirm intended.
            if optimizer_sparse:
                optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    # the sparse optimizer runs at twice the dense LR
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                # After warmup only cosine is stepped here; 'constant' keeps
                # its LR and 'dev_perf' is stepped after each evaluation.
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)

        if train_step % args.log_interval == 0:
            # Mean per-batch loss since the last log line, averaged over ranks.
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            # Time per batch is max-reduced over ranks; throughput is summed.
            elapsed = time.time() - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = time.time()
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )

            # NOTE(review): dllogger records batch+1 while the text log shows
            # batch; confirm which numbering is intended.
            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
            }

            # Character-level datasets report bits per character, the others
            # perplexity.
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            eval_start_time = time.time()
            val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          (time.time() - eval_start_time),
                          val_loss,
                          )

            dllogger_data = {
                'valid_elapsed': (time.time() - eval_start_time),
                'valid_loss': val_loss,
            }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            # Position to resume the training iterator from (stored in the
            # checkpoint and used as ``start=`` on restart).
            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, model_config, optimizer, scheduler,
                                scaler, vocab, epoch, batch, last_iter,
                                train_step, best_val_loss, is_best,
                                args.work_dir)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            log_start_time += time.time() - eval_start_time

        # Interruption is honored even with --no_eval (then without saving).
        if interrupted:
            logging.info(f'Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break
    return train_step, best_val_loss
- def main():
- args = parse_args()
- utils.gpu_affinity.set_affinity(args.local_rank)
- # Initialize device and distributed backend
- torch.cuda.set_device(args.local_rank)
- l2_promote()
- device = torch.device('cuda' if args.cuda else 'cpu')
- utils.distributed.init_distributed(args.cuda)
- args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
- args.dataset,
- args.append_dataset,
- args.append_time,
- )
- with utils.distributed.sync_workers() as rank:
- if rank == 0:
- create_exp_dir(args.work_dir,
- scripts_to_save=['train.py', 'mem_transformer.py'],
- debug=args.debug)
- # Setup logging
- if args.log_all_ranks:
- log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
- else:
- log_file = args.txtlog_file
- dllog_file = args.dllog_file
- log_file = os.path.join(args.work_dir, log_file)
- dllog_file = os.path.join(args.work_dir, dllog_file)
- if args.debug:
- log_file = os.devnull
- dllog_file = os.devnull
- utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
- filename=log_file,
- )
- utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)
- if args.local_batch_size is not None:
- world_size = utils.distributed.get_world_size()
- args.batch_size = world_size * args.local_batch_size
- logging.info(f'--local_batch_size was set, adjusting global batch size'
- f' to {args.batch_size} (local_batch_size * world_size)')
- logging.info(args)
- dllogger.log(step='PARAMETER', data=vars(args))
- logging.info(f'world size: {utils.distributed.get_world_size()}')
- if not args.no_env:
- log_env_info()
- register_ignoring_timeout_handler()
- # Set the random seed manually for reproducibility.
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
- ###########################################################################
- # Load data
- ###########################################################################
- corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
- ntokens = len(corpus.vocab)
- vocab = corpus.vocab
- args.n_token = ntokens
- if args.mem_len == 0:
- eval_mem_len = 0
- else:
- eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len
- tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
- device=device, ext_len=args.ext_len)
- va_iter = corpus.get_iterator('valid', args.eval_batch_size,
- args.eval_tgt_len, device=device,
- mem_len=eval_mem_len, ext_len=args.ext_len)
- te_iter = corpus.get_iterator('test', args.eval_batch_size,
- args.eval_tgt_len, device=device,
- mem_len=eval_mem_len, ext_len=args.ext_len)
- # adaptive softmax / embedding
- cutoffs, tie_projs = [], [False]
- if args.adaptive:
- assert args.dataset in ['wt103', 'lm1b']
- if args.dataset == 'wt103':
- cutoffs = [19997, 39997, 199997]
- tie_projs += [True] * len(cutoffs)
- elif args.dataset == 'lm1b':
- cutoffs = [59997, 99997, 639997]
- tie_projs += [False] * len(cutoffs)
- ###########################################################################
- # Build the model
- ###########################################################################
- model_config = {
- 'n_token': ntokens,
- 'n_layer': args.n_layer,
- 'n_head': args.n_head,
- 'd_model': args.d_model,
- 'd_head': args.d_head,
- 'd_inner': args.d_inner,
- 'dropout': args.dropout,
- 'dropatt': args.dropatt,
- 'dtype': None,
- 'tie_weight': args.tied,
- 'd_embed': args.d_embed,
- 'div_val': args.div_val,
- 'tie_projs': tie_projs,
- 'pre_lnorm': args.pre_lnorm,
- 'tgt_len': args.tgt_len,
- 'ext_len': args.ext_len,
- 'mem_len': args.mem_len,
- 'cutoffs': cutoffs,
- 'same_length': args.same_length,
- 'attn_type': args.attn_type,
- 'clamp_len': args.clamp_len,
- 'sample_softmax': args.sample_softmax,
- }
- model = MemTransformerLM(**model_config)
- model.apply(functools.partial(weights_init, args=args))
- # ensure embedding init is not overridden by out_layer in case of weight sharing
- model.word_emb.apply(functools.partial(weights_init, args=args))
- args.n_all_param = sum([p.nelement() for p in model.parameters()])
- args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
- # optimizer
- if args.optim.lower() == 'sgd':
- if args.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in model.parameters():
- if param.size() == model.word_emb.weight.size():
- sparse_params.append(param)
- else:
- dense_params.append(param)
- optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
- optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
- else:
- optimizer = optim.SGD(model.parameters(), lr=args.lr,
- momentum=args.mom)
- optimizer_sparse = None
- elif args.optim.lower() == 'adam':
- if args.sample_softmax > 0:
- dense_params, sparse_params = [], []
- for param in model.parameters():
- if param.size() == model.word_emb.weight.size():
- sparse_params.append(param)
- else:
- dense_params.append(param)
- optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
- optimizer = optim.Adam(dense_params, lr=args.lr,
- weight_decay=args.weight_decay)
- else:
- optimizer = optim.Adam(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- elif args.optim.lower() == 'adagrad':
- optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
- optimizer_sparse = None
- elif args.optim.lower() == 'lamb':
- optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- elif args.optim.lower() == 'jitlamb':
- optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
- weight_decay=args.weight_decay)
- optimizer_sparse = None
- model = model.to(device)
- scaler = None
- if args.fp16:
- if args.amp == 'pytorch':
- scaler = torch.cuda.amp.GradScaler()
- elif args.amp == 'apex':
- model, optimizer = amp.initialize(
- model,
- optimizer,
- opt_level=args.apex_amp_opt_level,
- )
- if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
- para_model = DistributedDataParallel(model,
- device_ids=[args.local_rank],
- output_device=args.local_rank,
- broadcast_buffers=False,
- find_unused_parameters=True,
- )
- elif args.multi_gpu == 'dp':
- if args.gpu0_bsz >= 0:
- para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
- model, dim=1).to(device)
- else:
- para_model = nn.DataParallel(model, dim=1).to(device)
- else:
- para_model = model
- # scheduler
- if args.scheduler == 'cosine':
- if args.max_step_scheduler:
- max_step = args.max_step_scheduler
- else:
- max_step = args.max_step
- scheduler = optim.lr_scheduler.CosineAnnealingLR(
- optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
- optimizer_sparse, max_step - args.warmup_step,
- eta_min=args.eta_min)
- else:
- scheduler_sparse = None
- elif args.scheduler == 'inv_sqrt':
- # originally used for Transformer (in Attention is all you need)
- def lr_lambda(step):
- # return a multiplier instead of a learning rate
- if step == 0 and args.warmup_step == 0:
- return 1.
- else:
- return 1. / (step ** 0.5) if step > args.warmup_step \
- else step / (args.warmup_step ** 1.5)
- scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.LambdaLR(
- optimizer_sparse,
- lr_lambda=lr_lambda
- )
- else:
- scheduler_sparse = None
- elif args.scheduler == 'dev_perf':
- scheduler = optim.lr_scheduler.ReduceLROnPlateau(
- optimizer, factor=args.decay_rate, patience=args.patience,
- min_lr=args.lr_min,
- )
- if args.sample_softmax > 0 and optimizer_sparse is not None:
- scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
- optimizer_sparse, factor=args.decay_rate, patience=args.patience,
- min_lr=args.lr_min,
- )
- else:
- scheduler_sparse = None
- elif args.scheduler == 'constant':
- pass
- logging.info('=' * 100)
- for k, v in args.__dict__.items():
- logging.info(' - {} : {}'.format(k, v))
- logging.info('=' * 100)
- logging.info('#params = {}'.format(args.n_all_param))
- logging.info('#non emb params = {}'.format(args.n_nonemb_param))
- train_step = 0
- start_epoch = 1
- last_batch = 0
- last_iter = 0
- best_val_loss = None
- if args.restart:
- try:
- checkpoint = load_checkpoint(args.restart)
- model.load_state_dict(checkpoint['model_state'])
- optimizer.load_state_dict(checkpoint['optimizer_state'])
- scheduler.load_state_dict(checkpoint['scheduler_state'])
- if args.fp16:
- if args.amp == 'pytorch':
- scaler.load_state_dict(checkpoint['amp_state'])
- elif args.amp == 'apex':
- amp.load_state_dict(checkpoint['amp_state'])
- train_step = checkpoint['train_step']
- start_epoch = checkpoint['epoch']
- last_batch = checkpoint['batch']
- last_iter = checkpoint['last_iter']
- best_val_loss = checkpoint['best_val_loss']
- if train_step >= args.max_step:
- logging.info(f'Loaded checkpoint after {train_step} steps, but '
- f'this run was scheduled for a total of '
- f'{args.max_step} steps, exiting')
- sys.exit(1)
- model.apply(functools.partial(update_dropout, args=args))
- model.apply(functools.partial(update_dropatt, args=args))
- except FileNotFoundError:
- logging.info(f'Could not load checkpoint from {args.restart}, '
- f'starting training from random init')
- meters = {}
- warmup = args.mem_len // args.tgt_len + 2
- meters['train_throughput'] = AverageMeter(warmup=warmup)
- ###########################################################################
- # Train
- ###########################################################################
- # Loop over epochs.
- # At any point you can hit Ctrl + C to break out of training early.
- start_time = time.time()
- with TimeoutHandler() as timeout_handler:
- try:
- for epoch in itertools.count(start=start_epoch):
- if args.roll:
- tr_iter.roll(seed=args.seed + epoch)
- train_step, best_val_loss = train(
- tr_iter, va_iter, model, para_model, model_config,
- optimizer, optimizer_sparse, scheduler, scheduler_sparse,
- scaler, vocab, epoch, last_batch, last_iter, train_step,
- best_val_loss, meters, timeout_handler, device, args
- )
- last_batch = 0
- last_iter = 0
- if train_step == args.max_step:
- logging.info('-' * 100)
- logging.info('End of training')
- break
- except KeyboardInterrupt:
- logging.info('-' * 100)
- logging.info('Exiting from training early')
- elapsed = time.time() - start_time
- ###########################################################################
- # Test
- ###########################################################################
- summary = {}
- test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
- if not args.debug and not args.no_eval and os.path.exists(test_path):
- # Load the best saved model.
- checkpoint = load_checkpoint(test_path)
- model.load_state_dict(checkpoint['model_state'])
- # Run on test data.
- test_start_time = time.time()
- test_loss = evaluate(te_iter, model, args)
- test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')
- test_elapsed = time.time() - test_start_time
- logging.info('=' * 100)
- if args.dataset in ['enwik8', 'text8']:
- logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
- test_elapsed, test_loss, test_loss / math.log(2)))
- else:
- logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
- test_elapsed, test_loss, math.exp(test_loss)))
- logging.info('=' * 100)
- summary.update({
- 'test_elapsed': test_elapsed,
- 'test_loss': test_loss,
- })
- if args.dataset in ['enwik8', 'text8']:
- summary['test_bits_per_character'] = test_loss / math.log(2)
- else:
- summary['test_perplexity'] = math.exp(test_loss)
- logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
- logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')
- if best_val_loss:
- val_perplexity = math.exp(best_val_loss)
- else:
- val_perplexity = None
- summary.update({
- 'train_throughput': meters['train_throughput'].avg,
- 'train_elapsed': elapsed / 60,
- 'valid_loss': best_val_loss,
- 'valid_perplexity': val_perplexity,
- })
- dllogger.log(step=tuple(), data=summary)
- passed = benchmark(
- target_perplexity=args.target_perplexity,
- test_perplexity=val_perplexity,
- target_throughput=args.target_throughput,
- test_throughput=meters['train_throughput'].avg
- )
- if not passed:
- sys.exit(1)
if __name__ == "__main__":
    # Switch off TorchScript's profiling executor and profiling mode.
    # Both setters are private torch._C hooks that some torch builds do not
    # expose; if a lookup (or call) raises AttributeError, stop silently and
    # carry on without them.
    try:
        for jit_toggle in ('_jit_set_profiling_executor',
                           '_jit_set_profiling_mode'):
            getattr(torch._C, jit_toggle)(False)
    except AttributeError:
        pass

    # This must happen before any model is built: when APEX has been
    # imported, register torch.einsum as a half-precision function so APEX
    # AMP executes it in fp16. Without this, it falls back to "promote" mode
    # and the einsum ops run in fp32. Passing `--apex_amp_opt_level O2` makes
    # this registration unnecessary, but it stays harmless.
    apex_loaded = 'apex' in sys.modules
    if apex_loaded:
        amp.register_half_function(torch, 'einsum')

    main()
|