| 1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the BSD 3-Clause License (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # https://opensource.org/licenses/BSD-3-Clause
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- import torch.nn as nn
class LabelSmoothing(nn.Module):
    """
    NLL loss with label smoothing.

    The loss is a weighted sum of two terms computed from the log-softmax
    of the input over its last dimension:

    * the negative log-probability of the target class, weighted by
      ``1 - smoothing``;
    * the negative mean log-probability over all classes, weighted by
      ``smoothing``.
    """

    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor; ``0.0`` reduces the loss
            to the plain mean negative log-likelihood
        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """
        Compute the label-smoothed loss, averaged over all elements.

        :param x: unnormalized scores (logits); the last dimension indexes
            the classes, any number of leading dimensions is accepted
        :param target: integer class indices with the same shape as ``x``
            minus its last dimension (must be a valid ``gather`` index)
        :return: scalar tensor holding the mean loss
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        # The class axis is the last one, so the index must be expanded (and
        # the result collapsed) on that same axis.  Using -1 instead of a
        # hard-coded 1 is identical for (N, C) inputs and stays correct for
        # inputs with extra leading dimensions such as (N, T, C).
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)

        # Uniform-prior term: average negative log-probability over classes.
        smooth_loss = -logprobs.mean(dim=-1)

        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
|