|
|
@@ -436,7 +436,7 @@ def evaluate(eval_iter, model, args):
|
|
|
|
|
|
|
|
|
def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
- optimizer, device, args):
|
|
|
+ optimizer, device, delay_unscale, args):
|
|
|
cpu = torch.device('cpu')
|
|
|
data_i = data_chunks[i].contiguous()
|
|
|
target_i = target_chunks[i].contiguous()
|
|
|
@@ -456,7 +456,7 @@ def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
if args.amp == 'pytorch':
|
|
|
scaler.scale(loss).backward()
|
|
|
elif args.amp == 'apex':
|
|
|
- with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
|
+ with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:
|
|
|
scaled_loss.backward()
|
|
|
else:
|
|
|
loss.backward()
|
|
|
@@ -498,12 +498,12 @@ def train(tr_iter, va_iter, model, para_model, model_config, optimizer,
|
|
|
with para_model.no_sync():
|
|
|
train_loss_chunk = train_iteration(
|
|
|
para_model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
- optimizer, device, args
|
|
|
+ optimizer, device, True, args
|
|
|
)
|
|
|
else:
|
|
|
train_loss_chunk = train_iteration(
|
|
|
para_model, i, mems, data_chunks, target_chunks, scaler,
|
|
|
- optimizer, device, args
|
|
|
+ optimizer, device, False, args
|
|
|
)
|
|
|
|
|
|
train_loss += train_loss_chunk
|