lamb.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # MIT License
  15. #
  16. # Copyright (c) 2019 cybertronai
  17. #
  18. # Permission is hereby granted, free of charge, to any person obtaining a copy
  19. # of this software and associated documentation files (the "Software"), to deal
  20. # in the Software without restriction, including without limitation the rights
  21. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  22. # copies of the Software, and to permit persons to whom the Software is
  23. # furnished to do so, subject to the following conditions:
  24. #
  25. # The above copyright notice and this permission notice shall be included in all
  26. # copies or substantial portions of the Software.
  27. #
  28. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  29. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  30. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  31. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  32. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  33. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34. # SOFTWARE.
  35. """Lamb optimizer."""
  36. import collections
  37. import math
  38. import torch
  39. from torch.optim import Optimizer
  40. class Lamb(Optimizer):
  41. r"""Implements Lamb algorithm.
  42. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
  43. Arguments:
  44. params (iterable): iterable of parameters to optimize or dicts defining
  45. parameter groups
  46. lr (float, optional): learning rate (default: 1e-3)
  47. betas (Tuple[float, float], optional): coefficients used for computing
  48. running averages of gradient and its square (default: (0.9, 0.999))
  49. eps (float, optional): term added to the denominator to improve
  50. numerical stability (default: 1e-8)
  51. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  52. adam (bool, optional): always use trust ratio = 1, which turns this into
  53. Adam. Useful for comparison purposes.
  54. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
  55. https://arxiv.org/abs/1904.00962
  56. """
  57. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
  58. weight_decay=0, adam=False):
  59. if not 0.0 <= lr:
  60. raise ValueError("Invalid learning rate: {}".format(lr))
  61. if not 0.0 <= eps:
  62. raise ValueError("Invalid epsilon value: {}".format(eps))
  63. if not 0.0 <= betas[0] < 1.0:
  64. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  65. if not 0.0 <= betas[1] < 1.0:
  66. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  67. defaults = dict(lr=lr, betas=betas, eps=eps,
  68. weight_decay=weight_decay)
  69. self.adam = adam
  70. super(Lamb, self).__init__(params, defaults)
  71. def step(self, closure=None):
  72. """Performs a single optimization step.
  73. Arguments:
  74. closure (callable, optional): A closure that reevaluates the model
  75. and returns the loss.
  76. """
  77. loss = None
  78. if closure is not None:
  79. loss = closure()
  80. for group in self.param_groups:
  81. for p in group['params']:
  82. if p.grad is None:
  83. continue
  84. grad = p.grad.data
  85. if grad.is_sparse:
  86. raise RuntimeError('Lamb does not support sparse gradients.')
  87. state = self.state[p]
  88. # State initialization
  89. if len(state) == 0:
  90. state['step'] = 0
  91. # Exponential moving average of gradient values
  92. state['exp_avg'] = torch.zeros_like(p.data)
  93. # Exponential moving average of squared gradient values
  94. state['exp_avg_sq'] = torch.zeros_like(p.data)
  95. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  96. beta1, beta2 = group['betas']
  97. state['step'] += 1
  98. # Decay the first and second moment running average coefficient
  99. # m_t
  100. exp_avg.mul_(beta1).add_(1 - beta1, grad)
  101. # v_t
  102. exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
  103. # Paper v3 does not use debiasing.
  104. # bias_correction1 = 1 - beta1 ** state['step']
  105. # bias_correction2 = 1 - beta2 ** state['step']
  106. # Apply bias to lr to avoid broadcast.
  107. step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
  108. weight_norm = p.data.norm(p=2).clamp_(0, 10)
  109. adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
  110. if group['weight_decay'] != 0:
  111. adam_step.add_(group['weight_decay'], p.data)
  112. adam_norm = adam_step.norm(p=2)
  113. trust_ratio = weight_norm / (adam_norm + group['eps'])
  114. state['weight_norm'] = weight_norm
  115. state['adam_norm'] = adam_norm
  116. state['trust_ratio'] = trust_ratio
  117. if self.adam:
  118. trust_ratio = 1
  119. p.data.add_(-step_size * trust_ratio, adam_step)
  120. return loss
  121. @torch.jit.script
  122. def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float,
  123. beta2: float, step_size: float, eps: float, weight_decay: float):
  124. exp_avg = exp_avg * beta1 + (1 - beta1) * grad
  125. exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad)
  126. adam_step = exp_avg / (exp_avg_sq.sqrt() + eps)
  127. adam_step = adam_step + weight_decay * param
  128. weight_norm = param.norm(p=2).clamp_(0, 10)
  129. adam_norm = adam_step.norm(p=2)
  130. trust_ratio = weight_norm / (adam_norm + eps)
  131. param = param - step_size * trust_ratio * adam_step
  132. return param, exp_avg, exp_avg_sq
  133. class JITLamb(Optimizer):
  134. r"""Implements Lamb algorithm.
  135. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
  136. Arguments:
  137. params (iterable): iterable of parameters to optimize or dicts defining
  138. parameter groups
  139. lr (float, optional): learning rate (default: 1e-3)
  140. betas (Tuple[float, float], optional): coefficients used for computing
  141. running averages of gradient and its square (default: (0.9, 0.999))
  142. eps (float, optional): term added to the denominator to improve
  143. numerical stability (default: 1e-8)
  144. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  145. adam (bool, optional): always use trust ratio = 1, which turns this into
  146. Adam. Useful for comparison purposes.
  147. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
  148. https://arxiv.org/abs/1904.00962
  149. """
  150. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
  151. weight_decay=0, adam=False):
  152. if not 0.0 <= lr:
  153. raise ValueError("Invalid learning rate: {}".format(lr))
  154. if not 0.0 <= eps:
  155. raise ValueError("Invalid epsilon value: {}".format(eps))
  156. if not 0.0 <= betas[0] < 1.0:
  157. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  158. if not 0.0 <= betas[1] < 1.0:
  159. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  160. defaults = dict(lr=lr, betas=betas, eps=eps,
  161. weight_decay=weight_decay)
  162. self.adam = adam
  163. super().__init__(params, defaults)
  164. def step(self, closure=None):
  165. """Performs a single optimization step.
  166. Arguments:
  167. closure (callable, optional): A closure that reevaluates the model
  168. and returns the loss.
  169. """
  170. loss = None
  171. if closure is not None:
  172. loss = closure()
  173. for group in self.param_groups:
  174. for p in group['params']:
  175. if p.grad is None:
  176. continue
  177. grad = p.grad.data
  178. if grad.is_sparse:
  179. raise RuntimeError('Lamb does not support sparse gradients.')
  180. state = self.state[p]
  181. # State initialization
  182. if len(state) == 0:
  183. state['step'] = 0
  184. # Exponential moving average of gradient values
  185. state['exp_avg'] = torch.zeros_like(p.data)
  186. # Exponential moving average of squared gradient values
  187. state['exp_avg_sq'] = torch.zeros_like(p.data)
  188. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  189. beta1, beta2 = group['betas']
  190. state['step'] += 1
  191. step_size = group['lr']
  192. param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg,
  193. exp_avg_sq, beta1,
  194. beta2, step_size,
  195. group['eps'],
  196. group['weight_decay'],
  197. )
  198. state['exp_avg'] = exp_avg
  199. state['exp_avg_sq'] = exp_avg_sq
  200. p.data = param
  201. return loss