|
@@ -47,9 +47,6 @@ from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_
|
|
|
|
|
|
|
|
import dllogger
|
|
import dllogger
|
|
|
|
|
|
|
|
-from apex.parallel import DistributedDataParallel as DDP
|
|
|
|
|
-from apex import amp
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
def synchronized_timestamp():
|
|
def synchronized_timestamp():
|
|
|
torch.cuda.synchronize()
|
|
torch.cuda.synchronize()
|
|
@@ -252,12 +249,8 @@ def main():
|
|
|
model = model.cuda()
|
|
model = model.cuda()
|
|
|
criterion = criterion.cuda()
|
|
criterion = criterion.cuda()
|
|
|
|
|
|
|
|
- if args.amp:
|
|
|
|
|
- model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
|
|
|
|
|
- keep_batchnorm_fp32=False, loss_scale='dynamic')
|
|
|
|
|
-
|
|
|
|
|
if args.distributed:
|
|
if args.distributed:
|
|
|
- model = DDP(model)
|
|
|
|
|
|
|
+ model = torch.nn.parallel.DistributedDataParallel(model)
|
|
|
|
|
|
|
|
local_batch = args.batch_size // args.world_size
|
|
local_batch = args.batch_size // args.world_size
|
|
|
traced_criterion = torch.jit.trace(criterion.forward,
|
|
traced_criterion = torch.jit.trace(criterion.forward,
|
|
@@ -291,6 +284,7 @@ def main():
|
|
|
best_epoch = 0
|
|
best_epoch = 0
|
|
|
best_model_timestamp = synchronized_timestamp()
|
|
best_model_timestamp = synchronized_timestamp()
|
|
|
train_throughputs, eval_throughputs = [], []
|
|
train_throughputs, eval_throughputs = [], []
|
|
|
|
|
+ scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
|
|
|
|
|
|
|
for epoch in range(args.epochs):
|
|
for epoch in range(args.epochs):
|
|
|
|
|
|
|
@@ -311,16 +305,14 @@ def main():
|
|
|
label_features = batch_dict[LABEL_CHANNEL_NAME]
|
|
label_features = batch_dict[LABEL_CHANNEL_NAME]
|
|
|
label_batch = label_features[label_feature_name]
|
|
label_batch = label_features[label_feature_name]
|
|
|
|
|
|
|
|
- outputs = model(user_batch, item_batch)
|
|
|
|
|
- loss = traced_criterion(outputs, label_batch.view(-1, 1)).float()
|
|
|
|
|
- loss = torch.mean(loss.view(-1), 0)
|
|
|
|
|
|
|
+ with torch.cuda.amp.autocast(enabled=args.amp):
|
|
|
|
|
+ outputs = model(user_batch, item_batch)
|
|
|
|
|
+ loss = traced_criterion(outputs, label_batch.view(-1, 1))
|
|
|
|
|
+ loss = torch.mean(loss.float().view(-1), 0)
|
|
|
|
|
|
|
|
- if args.amp:
|
|
|
|
|
- with amp.scale_loss(loss, optimizer) as scaled_loss:
|
|
|
|
|
- scaled_loss.backward()
|
|
|
|
|
- else:
|
|
|
|
|
- loss.backward()
|
|
|
|
|
- optimizer.step()
|
|
|
|
|
|
|
+ scaler.scale(loss).backward()
|
|
|
|
|
+ scaler.step(optimizer)
|
|
|
|
|
+ scaler.update()
|
|
|
|
|
|
|
|
for p in model.parameters():
|
|
for p in model.parameters():
|
|
|
p.grad = None
|
|
p.grad = None
|