|
|
@@ -1,6 +1,6 @@
|
|
|
# coding: utf-8
|
|
|
|
|
|
-# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
|
|
|
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
|
|
|
#
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
@@ -23,6 +23,7 @@ import os
|
|
|
import shutil
|
|
|
import sys
|
|
|
import time
|
|
|
+import warnings
|
|
|
|
|
|
import dllogger
|
|
|
import numpy as np
|
|
|
@@ -30,7 +31,11 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import yaml
|
|
|
-from apex import amp
|
|
|
+try:
|
|
|
+ from apex import amp
|
|
|
+except ModuleNotFoundError:
|
|
|
+ warnings.warn('APEX AMP is unavailable')
|
|
|
+
|
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
|
|
import lamb
|
|
|
@@ -101,9 +106,11 @@ def parse_args():
|
|
|
help='Target training throughput (for benchmarking)')
|
|
|
general.add_argument('--target_perplexity', type=float, default=None,
|
|
|
help='Target validation perplexity (for benchmarking)')
|
|
|
- general.add_argument('--amp_mode', type=str, default='O2',
|
|
|
+ general.add_argument('--apex_amp_opt_level', type=str, default='O2',
|
|
|
choices=['O0', 'O1', 'O2', 'O3'],
|
|
|
help='Optimization level for apex amp')
|
|
|
+ general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
|
|
|
+ help='Implementation of automatic mixed precision')
|
|
|
|
|
|
dataset = parser.add_argument_group('dataset setup')
|
|
|
dataset.add_argument('--data', type=str, default='../data/wikitext-103',
|
|
|
@@ -220,6 +227,8 @@ def parse_args():
|
|
|
help='Use the same attn length for all tokens')
|
|
|
training.add_argument('--varlen', action='store_true',
|
|
|
help='Use variable length')
|
|
|
+ training.add_argument('--swap_mem', action='store_true',
|
|
|
+ help='Swap memory tensors to cpu')
|
|
|
|
|
|
val = parser.add_argument_group('validation setup')
|
|
|
val.add_argument('--eval_tgt_len', type=int, default=192,
|
|
|
@@ -244,17 +253,28 @@ def parse_args():
|
|
|
if args.d_embed < 0:
|
|
|
args.d_embed = args.d_model
|
|
|
|
|
|
- assert args.ext_len >= 0, 'extended context length must be non-negative'
|
|
|
- assert args.batch_size % args.batch_chunk == 0
|
|
|
+ if args.ext_len < 0:
|
|
|
+ raise RuntimeError('Extended context length must be non-negative')
|
|
|
+
|
|
|
+ if args.batch_size % args.batch_chunk != 0:
|
|
|
+ raise RuntimeError('Batch size needs to be divisible by batch chunk')
|
|
|
+
|
|
|
+ if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
|
|
|
+ raise RuntimeError(
|
|
|
+ 'APEX AMP unavailable, install APEX or switch to pytorch AMP'
|
|
|
+ )
|
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
-def save_checkpoint(args, model, model_config, optimizer, scheduler, vocab,
|
|
|
- epoch, batch, last_iter, train_step, best_val_loss,
|
|
|
+def save_checkpoint(args, model, model_config, optimizer, scheduler, scaler,
|
|
|
+ vocab, epoch, batch, last_iter, train_step, best_val_loss,
|
|
|
is_best, work_dir):
|
|
|
if args.fp16:
|
|
|
- amp_state = amp.state_dict()
|
|
|
+ if args.amp == 'pytorch':
|
|
|
+ amp_state = scaler.state_dict()
|
|
|
+ elif args.amp == 'apex':
|
|
|
+ amp_state = amp.state_dict()
|
|
|
else:
|
|
|
amp_state = None
|
|
|
|
|
|
@@ -415,10 +435,40 @@ def evaluate(eval_iter, model, args):
|
|
|
return total_loss / total_len
|
|
|
|
|
|
|
|
|
+def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
+ optimizer, device, args):
|
|
|
+ cpu = torch.device('cpu')
|
|
|
+ data_i = data_chunks[i].contiguous()
|
|
|
+ target_i = target_chunks[i].contiguous()
|
|
|
+
|
|
|
+ if args.swap_mem and mems[i] is not None:
|
|
|
+ mems[i] = mems[i].to(device, non_blocking=True)
|
|
|
+
|
|
|
+ enable_autocast = args.fp16 and args.amp == 'pytorch'
|
|
|
+ with torch.cuda.amp.autocast(enable_autocast):
|
|
|
+ loss, mems[i] = model(data_i, target_i, mems[i])
|
|
|
+ loss = loss.float().mean().type_as(loss) / args.batch_chunk
|
|
|
+
|
|
|
+ if args.swap_mem and mems[i] is not None:
|
|
|
+ mems[i] = mems[i].to(cpu, non_blocking=True)
|
|
|
+
|
|
|
+ if args.fp16:
|
|
|
+ if args.amp == 'pytorch':
|
|
|
+ scaler.scale(loss).backward()
|
|
|
+ elif args.amp == 'apex':
|
|
|
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
|
+ scaled_loss.backward()
|
|
|
+ else:
|
|
|
+ loss.backward()
|
|
|
+
|
|
|
+ train_loss = loss.float().item()
|
|
|
+ return train_loss
|
|
|
+
|
|
|
+
|
|
|
def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
|
|
|
- optimizer_sparse, scheduler, scheduler_sparse, vocab, epoch,
|
|
|
+ optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
|
|
|
last_batch, last_iter, train_step, best_val_loss, meters,
|
|
|
- timeout_handler, args):
|
|
|
+ timeout_handler, device, args):
|
|
|
# Turn on training mode which enables dropout.
|
|
|
model.train()
|
|
|
|
|
|
@@ -444,27 +494,36 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
|
|
|
target_chunks = torch.chunk(target, args.batch_chunk, 1)
|
|
|
|
|
|
for i in range(args.batch_chunk):
|
|
|
- data_i = data_chunks[i].contiguous()
|
|
|
- target_i = target_chunks[i].contiguous()
|
|
|
- loss, mems[i] = para_model(data_i, target_i, mems[i])
|
|
|
- loss = loss.float().mean().type_as(loss) / args.batch_chunk
|
|
|
-
|
|
|
- if args.fp16:
|
|
|
- with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
|
- scaled_loss.backward()
|
|
|
+ if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
|
|
|
+ with para_model.no_sync():
|
|
|
+ train_loss_chunk = train_iteration(
|
|
|
+ para_model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
+ optimizer, device, args
|
|
|
+ )
|
|
|
else:
|
|
|
- loss.backward()
|
|
|
+ train_loss_chunk = train_iteration(
|
|
|
+ para_model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
+ optimizer, device, args
|
|
|
+ )
|
|
|
|
|
|
- train_loss += loss.float().item()
|
|
|
+ train_loss += train_loss_chunk
|
|
|
|
|
|
if args.fp16:
|
|
|
- torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
|
|
|
+ if args.amp == 'pytorch':
|
|
|
+ scaler.unscale_(optimizer)
|
|
|
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
|
|
+ elif args.amp == 'apex':
|
|
|
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
|
|
|
else:
|
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
|
|
|
|
|
- optimizer.step()
|
|
|
- if optimizer_sparse:
|
|
|
- optimizer_sparse.step()
|
|
|
+ if args.fp16 and args.amp == 'pytorch':
|
|
|
+ scaler.step(optimizer)
|
|
|
+ scaler.update()
|
|
|
+ else:
|
|
|
+ optimizer.step()
|
|
|
+ if optimizer_sparse:
|
|
|
+ optimizer_sparse.step()
|
|
|
|
|
|
# step-wise learning rate annealing
|
|
|
train_step += 1
|
|
|
@@ -575,8 +634,8 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
|
|
|
is_best = True
|
|
|
|
|
|
if not args.debug:
|
|
|
- save_checkpoint(args, model, model_config, optimizer,
|
|
|
- scheduler, vocab, epoch, batch, last_iter,
|
|
|
+ save_checkpoint(args, model, model_config, optimizer, scheduler,
|
|
|
+ scaler, vocab, epoch, batch, last_iter,
|
|
|
train_step, best_val_loss, is_best,
|
|
|
args.work_dir)
|
|
|
|
|
|
@@ -772,12 +831,16 @@ def main():
|
|
|
|
|
|
model = model.to(device)
|
|
|
|
|
|
+ scaler = None
|
|
|
if args.fp16:
|
|
|
- model, optimizer = amp.initialize(
|
|
|
- model,
|
|
|
- optimizer,
|
|
|
- opt_level=args.amp_mode,
|
|
|
- )
|
|
|
+ if args.amp == 'pytorch':
|
|
|
+ scaler = torch.cuda.amp.GradScaler()
|
|
|
+ elif args.amp == 'apex':
|
|
|
+ model, optimizer = amp.initialize(
|
|
|
+ model,
|
|
|
+ optimizer,
|
|
|
+ opt_level=args.apex_amp_opt_level,
|
|
|
+ )
|
|
|
|
|
|
if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
|
|
|
para_model = DistributedDataParallel(model,
|
|
|
@@ -862,7 +925,10 @@ def main():
|
|
|
optimizer.load_state_dict(checkpoint['optimizer_state'])
|
|
|
scheduler.load_state_dict(checkpoint['scheduler_state'])
|
|
|
if args.fp16:
|
|
|
- amp.load_state_dict(checkpoint['amp_state'])
|
|
|
+ if args.amp == 'pytorch':
|
|
|
+ scaler.load_state_dict(checkpoint['amp_state'])
|
|
|
+ elif args.amp == 'apex':
|
|
|
+ amp.load_state_dict(checkpoint['amp_state'])
|
|
|
train_step = checkpoint['train_step']
|
|
|
start_epoch = checkpoint['epoch']
|
|
|
last_batch = checkpoint['batch']
|
|
|
@@ -871,8 +937,8 @@ def main():
|
|
|
|
|
|
if train_step >= args.max_step:
|
|
|
logging.info(f'Loaded checkpoint after {train_step} steps, but '
|
|
|
- f'this run was scheduled for a total of '
|
|
|
- f'{args.max_step} steps, exiting')
|
|
|
+ f'this run was scheduled for a total of '
|
|
|
+ f'{args.max_step} steps, exiting')
|
|
|
sys.exit(1)
|
|
|
|
|
|
model.apply(functools.partial(update_dropout, args=args))
|
|
|
@@ -898,8 +964,8 @@ def main():
|
|
|
train_step, best_val_loss = train(
|
|
|
tr_iter, va_iter, model, para_model, model_config,
|
|
|
optimizer, optimizer_sparse, scheduler, scheduler_sparse,
|
|
|
- vocab, epoch, last_batch, last_iter, train_step,
|
|
|
- best_val_loss, meters, timeout_handler, args
|
|
|
+ scaler, vocab, epoch, last_batch, last_iter, train_step,
|
|
|
+ best_val_loss, meters, timeout_handler, device, args
|
|
|
)
|
|
|
|
|
|
last_batch = 0
|
|
|
@@ -984,9 +1050,11 @@ if __name__ == "__main__":
|
|
|
pass
|
|
|
|
|
|
# Before we do anything with models, we want to ensure that we get fp16
|
|
|
- # execution of torch.einsum.
|
|
|
+ # execution of torch.einsum in APEX AMP.
|
|
|
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations.
|
|
|
- # Note that running `--amp_mode O2` will remove the need for this
|
|
|
+ # Note that running `--apex_amp_opt_level O2` will remove the need for this
|
|
|
# code, but it is still valid.
|
|
|
- amp.register_half_function(torch, 'einsum')
|
|
|
+ if 'apex' in sys.modules:
|
|
|
+ amp.register_half_function(torch, 'einsum')
|
|
|
+
|
|
|
main()
|