| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
- """Lamb optimizer."""
- import torch
- from torch.optim import Optimizer
- class Lamb(Optimizer):
- r"""Implements Lamb algorithm.
- It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
- Arguments:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- adam (bool, optional): always use trust ratio = 1, which turns this into
- Adam. Useful for comparison purposes.
- .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
- https://arxiv.org/abs/1904.00962
- """
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
- weight_decay=0, adam=False):
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
- if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
- if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
- self.adam = adam
- super(Lamb, self).__init__(params, defaults)
- def step(self, closure=None):
- """Performs a single optimization step.
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- loss = closure()
- for group in self.param_groups:
- for p in group['params']:
- if p.grad is None:
- continue
- grad = p.grad.data
- if grad.is_sparse:
- raise RuntimeError('Lamb does not support sparse gradients.')
- state = self.state[p]
- # State initialization
- if len(state) == 0:
- state['step'] = 0
- # Exponential moving average of gradient values
- state['exp_avg'] = torch.zeros_like(p.data)
- # Exponential moving average of squared gradient values
- state['exp_avg_sq'] = torch.zeros_like(p.data)
- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
- beta1, beta2 = group['betas']
- state['step'] += 1
- # Decay the first and second moment running average coefficient
- # m_t
- exp_avg.mul_(beta1).add_(1 - beta1, grad)
- # v_t
- exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
- # Paper v3 does not use debiasing.
- # bias_correction1 = 1 - beta1 ** state['step']
- # bias_correction2 = 1 - beta2 ** state['step']
- # Apply bias to lr to avoid broadcast.
- step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
- weight_norm = p.data.norm(p=2).clamp_(0, 10)
- adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
- if group['weight_decay'] != 0:
- adam_step.add_(group['weight_decay'], p.data)
- adam_norm = adam_step.norm(p=2)
- if weight_norm == 0.0 or adam_norm == 0.0:
- trust_ratio = 1
- else:
- trust_ratio = weight_norm / (adam_norm + group['eps'])
- state['weight_norm'] = weight_norm
- state['adam_norm'] = adam_norm
- state['trust_ratio'] = trust_ratio
- if self.adam:
- trust_ratio = 1
- p.data.add_(-step_size * trust_ratio, adam_step)
- return loss
- @torch.jit.script
- def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float,
- beta2: float, step_size: float, eps: float, weight_decay: float):
- exp_avg = exp_avg * beta1 + (1 - beta1) * grad
- exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad)
- adam_step = exp_avg / (exp_avg_sq.sqrt() + eps)
- adam_step = adam_step + weight_decay * param
- weight_norm = param.norm(p=2).clamp(0, 10)
- adam_norm = adam_step.norm(p=2)
- trust_ratio = weight_norm / (adam_norm + eps)
- trust_ratio = (weight_norm == 0.0) * 1.0 + (weight_norm != 0.0) * trust_ratio
- trust_ratio = (adam_norm == 0.0) * 1.0 + (adam_norm != 0.0) * trust_ratio
- trust_ratio = trust_ratio.float()
- param = param - step_size * trust_ratio * adam_step
- return param, exp_avg, exp_avg_sq
class JITLamb(Optimizer):
    r"""Lamb optimizer whose per-parameter math runs in ``lamb_kernel``.

    Proposed in `Large Batch Optimization for Deep Learning: Training BERT
    in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
            NOTE(review): unlike ``Lamb``, ``step`` below never consults this
            flag — the trust ratio is always applied; confirm intentional.

    .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        # Reject out-of-range hyperparameters up front.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        self.adam = adam
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Hyperparameters are constant across the group's parameters.
            beta1, beta2 = group['betas']
            step_size = group['lr']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients.')

                state = self.state[p]
                if not state:
                    # Lazily create the per-parameter moment buffers.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                state['step'] += 1

                # The kernel returns fresh tensors; store them back.
                updated, exp_avg, exp_avg_sq = lamb_kernel(
                    p.data, grad, state['exp_avg'], state['exp_avg_sq'],
                    beta1, beta2, step_size, group['eps'],
                    group['weight_decay'],
                )
                state['exp_avg'] = exp_avg
                state['exp_avg_sq'] = exp_avg_sq
                p.data = updated

        return loss
|