smoothing.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the BSD 3-Clause License (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # https://opensource.org/licenses/BSD-3-Clause
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. class LabelSmoothing(nn.Module):
  17. """
  18. NLL loss with label smoothing.
  19. """
  20. def __init__(self, smoothing=0.0):
  21. """
  22. Constructor for the LabelSmoothing module.
  23. :param smoothing: label smoothing factor
  24. """
  25. super(LabelSmoothing, self).__init__()
  26. self.confidence = 1.0 - smoothing
  27. self.smoothing = smoothing
  28. def forward(self, x, target):
  29. logprobs = torch.nn.functional.log_softmax(x, dim=-1)
  30. nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
  31. nll_loss = nll_loss.squeeze(1)
  32. smooth_loss = -logprobs.mean(dim=-1)
  33. loss = self.confidence * nll_loss + self.smoothing * smooth_loss
  34. return loss.mean()