Просмотр исходного кода

Merge: [DLRM/PyT] Stop using apex AMP and DDP

Krzysztof Kudrynski 2 года назад
Родитель
Commit
370a221cc9

+ 4 - 2
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/fused_gather_embedding.py

@@ -17,7 +17,7 @@ Fused Buckle Embedding
 """
 
 from absl import logging
-from apex import amp
+import torch
 from torch.autograd import Function
 
 from dlrm.cuda_ext import fused_embedding
@@ -26,12 +26,14 @@ from dlrm.cuda_ext import fused_embedding
 class BuckleEmbeddingFusedGatherFunction(Function):
     """Customized embedding gather """
     @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, embedding, indices, offsets, amp_train):
         output = fused_embedding.gather_gpu_fused_fwd(embedding, indices, offsets, amp_train)
         ctx.save_for_backward(embedding, indices, offsets)
         return output
 
     @staticmethod
+    @torch.cuda.amp.custom_bwd
     def backward(ctx, grad_output):
         embedding, indices, offsets = ctx.saved_tensors
 
@@ -40,4 +42,4 @@ class BuckleEmbeddingFusedGatherFunction(Function):
         return grad_weights, None, None, None
 
 
-buckle_embedding_fused_gather = amp.float_function(BuckleEmbeddingFusedGatherFunction.apply)
+buckle_embedding_fused_gather = BuckleEmbeddingFusedGatherFunction.apply

+ 4 - 4
PyTorch/Recommendation/DLRM/dlrm/cuda_ext/sparse_embedding.py

@@ -15,7 +15,7 @@
 import copy
 
 import torch
-from apex import amp
+from torch.cuda import amp
 from dlrm.cuda_ext import sparse_gather
 from torch import nn
 from torch.autograd import Function
@@ -24,6 +24,7 @@ from torch.autograd import Function
 class EmbeddingGatherFunction(Function):
     """Customized embedding gather with fused plain SGD"""
     @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, embedding, indices):
         output = sparse_gather.gather_gpu_fwd(embedding, indices)
         ctx.save_for_backward(indices)
@@ -31,11 +32,10 @@ class EmbeddingGatherFunction(Function):
         return output
 
     @staticmethod
+    @torch.cuda.amp.custom_bwd
     def backward(ctx, grad_output):
         indices = ctx.saved_tensors[0]
-
         grad_embedding = sparse_gather.gather_gpu_bwd(grad_output, indices, ctx.num_features)
-
         return grad_embedding, None
 
 
@@ -66,4 +66,4 @@ class JointSparseEmbedding(nn.Module):
         return embedding_out
 
 
-embedding_gather = amp.float_function(EmbeddingGatherFunction.apply)
+embedding_gather = EmbeddingGatherFunction.apply

+ 2 - 5
PyTorch/Recommendation/DLRM/dlrm/scripts/main.py

@@ -17,7 +17,7 @@ import itertools
 import os
 import sys
 from absl import app, flags, logging
-from apex import amp, parallel, optimizers as apex_optim
+from apex import optimizers as apex_optim
 
 from dlrm.data.feature_spec import FeatureSpec
 from dlrm.model.distributed import DistributedDlrm
@@ -500,10 +500,7 @@ def main(argv):
         if world_size <= 1:
             return model
 
-        if use_gpu:
-            model.top_model = parallel.DistributedDataParallel(model.top_model)
-        else:  # Use other backend for CPU
-            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
+        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
         return model
 
     if FLAGS.mode == 'test':