
Merge: [NCF/PyT] Stop using deprecated apex AMP and apex DDP

Krzysztof Kudrynski, 2 years ago
Parent
Commit 2a7c251dcb
2 changed files with 10 additions and 30 deletions
  1. +1 −13 PyTorch/Recommendation/NCF/README.md
  2. +9 −17 PyTorch/Recommendation/NCF/ncf.py

+ 1 - 13
PyTorch/Recommendation/NCF/README.md

@@ -143,23 +143,11 @@ The ability to train deep learning networks with lower precision was introduced
 For information about:
 -   How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.
 -   Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
--   APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
 
 #### Enabling mixed precision
 
-Using the Automatic Mixed Precision (AMP) package requires two modifications in the source code.
-The first one is to initialize the model and the optimizer using the `amp.initialize` function:
-```python
-model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
-                                          keep_batchnorm_fp32=False, loss_scale='dynamic')
-```
-
-The second one is to use the AMP's loss scaling context manager:
-```python
-with amp.scale_loss(loss, optimizer) as scaled_loss:
-    scaled_loss.backward()
-```
+Mixed precision training is turned off by default. To turn it on, pass the `--amp` flag to the `main.py` script.
 
 #### Enabling TF32

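The README change above swaps the apex snippets for a single flag; the commit itself switches to PyTorch's native `torch.cuda.amp` API. A minimal sketch of that pattern follows — the toy `Linear` model and the helper name `train_step` are illustrative, not from the repository, and `enabled` is derived from GPU availability so the sketch also runs on CPU:

```python
import torch

def train_step(model, criterion, optimizer, scaler, batch, target, use_amp):
    # Forward pass under autocast: eligible ops run in half precision when enabled.
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(batch)
        loss = criterion(output, target)
    # Scale the loss to avoid FP16 gradient underflow, then backpropagate.
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # unscales gradients; skips the step on inf/NaN
    scaler.update()          # adjusts the loss scale for the next iteration
    optimizer.zero_grad()
    return loss.item()

model = torch.nn.Linear(8, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_amp = torch.cuda.is_available()
# With enabled=False, GradScaler and autocast become no-ops (plain FP32 training).
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x, y = torch.randn(4, 8), torch.randn(4, 1)
if use_amp:
    model, x, y = model.cuda(), x.cuda(), y.cuda()
loss = train_step(model, criterion, optimizer, scaler, x, y, use_amp)
```

Because the scaler and autocast degrade to no-ops when disabled, a single code path serves both FP32 and mixed-precision runs, which is why the diff below can drop the `if args.amp:` branching entirely.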
+ 9 - 17
PyTorch/Recommendation/NCF/ncf.py

@@ -47,9 +47,6 @@ from neumf_constants import USER_CHANNEL_NAME, ITEM_CHANNEL_NAME, LABEL_CHANNEL_
 
 import dllogger
 
-from apex.parallel import DistributedDataParallel as DDP
-from apex import amp
-
 
 def synchronized_timestamp():
     torch.cuda.synchronize()
@@ -252,12 +249,8 @@ def main():
     model = model.cuda()
     criterion = criterion.cuda()
 
-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
-                                          keep_batchnorm_fp32=False, loss_scale='dynamic')
-
     if args.distributed:
-        model = DDP(model)
+        model = torch.nn.parallel.DistributedDataParallel(model)
 
     local_batch = args.batch_size // args.world_size
     traced_criterion = torch.jit.trace(criterion.forward,
@@ -291,6 +284,7 @@ def main():
     best_epoch = 0
     best_model_timestamp = synchronized_timestamp()
     train_throughputs, eval_throughputs = [], []
+    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
 
     for epoch in range(args.epochs):
 
@@ -311,16 +305,14 @@ def main():
                 label_features = batch_dict[LABEL_CHANNEL_NAME]
                 label_batch = label_features[label_feature_name]
 
-                outputs = model(user_batch, item_batch)
-                loss = traced_criterion(outputs, label_batch.view(-1, 1)).float()
-                loss = torch.mean(loss.view(-1), 0)
+                with torch.cuda.amp.autocast(enabled=args.amp):
+                    outputs = model(user_batch, item_batch)
+                    loss = traced_criterion(outputs, label_batch.view(-1, 1))
+                    loss = torch.mean(loss.float().view(-1), 0)
 
-                if args.amp:
-                    with amp.scale_loss(loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                else:
-                    loss.backward()
-            optimizer.step()
+                scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
 
             for p in model.parameters():
                 p.grad = None
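The other half of the migration replaces apex's `DistributedDataParallel` with `torch.nn.parallel.DistributedDataParallel`, which requires an initialized `torch.distributed` process group. A minimal single-process sketch — the `gloo` backend, the address/port values, and the dummy `Linear` model are illustrative choices so it runs without GPUs; real multi-GPU training would launch one process per GPU (e.g. with `torchrun`) and use the `nccl` backend:

```python
import os
import torch
import torch.distributed as dist

# DistributedDataParallel needs a process group before construction.
# A world of size 1 on the CPU-friendly "gloo" backend keeps this runnable anywhere.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 1)
# Wrapping the model registers hooks that all-reduce gradients across ranks
# during backward(), so no explicit gradient averaging is needed afterwards.
ddp_model = torch.nn.parallel.DistributedDataParallel(model)

out = ddp_model(torch.randn(4, 8)).sum()
out.backward()

dist.destroy_process_group()
```

Unlike apex's wrapper, the native one validates that the process group exists and handles gradient bucketing itself, which is why the diff can drop the apex imports without any other synchronization code.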
                 p.grad = None