|
|
@@ -17,7 +17,7 @@ Fused Buckle Embedding
|
|
|
"""
|
|
|
|
|
|
from absl import logging
|
|
|
-from apex import amp
|
|
|
+import torch
|
|
|
from torch.autograd import Function
|
|
|
|
|
|
from dlrm.cuda_ext import fused_embedding
|
|
|
@@ -26,12 +26,14 @@ from dlrm.cuda_ext import fused_embedding
|
|
|
class BuckleEmbeddingFusedGatherFunction(Function):
|
|
|
"""Customized embedding gather """
|
|
|
@staticmethod
|
|
|
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
|
def forward(ctx, embedding, indices, offsets, amp_train):
|
|
|
output = fused_embedding.gather_gpu_fused_fwd(embedding, indices, offsets, amp_train)
|
|
|
ctx.save_for_backward(embedding, indices, offsets)
|
|
|
return output
|
|
|
|
|
|
@staticmethod
|
|
|
+ @torch.cuda.amp.custom_bwd
|
|
|
def backward(ctx, grad_output):
|
|
|
embedding, indices, offsets = ctx.saved_tensors
|
|
|
|
|
|
@@ -40,4 +42,4 @@ class BuckleEmbeddingFusedGatherFunction(Function):
|
|
|
return grad_weights, None, None, None
|
|
|
|
|
|
|
|
|
-buckle_embedding_fused_gather = amp.float_function(BuckleEmbeddingFusedGatherFunction.apply)
|
|
|
+buckle_embedding_fused_gather = BuckleEmbeddingFusedGatherFunction.apply
|