# lamb.py
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Lamb optimizer."""
import torch
from torch.optim import Optimizer
  38. class Lamb(Optimizer):
  39. r"""Implements Lamb algorithm.
  40. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
  41. Arguments:
  42. params (iterable): iterable of parameters to optimize or dicts defining
  43. parameter groups
  44. lr (float, optional): learning rate (default: 1e-3)
  45. betas (Tuple[float, float], optional): coefficients used for computing
  46. running averages of gradient and its square (default: (0.9, 0.999))
  47. eps (float, optional): term added to the denominator to improve
  48. numerical stability (default: 1e-8)
  49. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  50. adam (bool, optional): always use trust ratio = 1, which turns this into
  51. Adam. Useful for comparison purposes.
  52. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
  53. https://arxiv.org/abs/1904.00962
  54. """
  55. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
  56. weight_decay=0, adam=False):
  57. if not 0.0 <= lr:
  58. raise ValueError("Invalid learning rate: {}".format(lr))
  59. if not 0.0 <= eps:
  60. raise ValueError("Invalid epsilon value: {}".format(eps))
  61. if not 0.0 <= betas[0] < 1.0:
  62. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  63. if not 0.0 <= betas[1] < 1.0:
  64. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  65. defaults = dict(lr=lr, betas=betas, eps=eps,
  66. weight_decay=weight_decay)
  67. self.adam = adam
  68. super(Lamb, self).__init__(params, defaults)
  69. def step(self, closure=None):
  70. """Performs a single optimization step.
  71. Arguments:
  72. closure (callable, optional): A closure that reevaluates the model
  73. and returns the loss.
  74. """
  75. loss = None
  76. if closure is not None:
  77. loss = closure()
  78. for group in self.param_groups:
  79. for p in group['params']:
  80. if p.grad is None:
  81. continue
  82. grad = p.grad.data
  83. if grad.is_sparse:
  84. raise RuntimeError('Lamb does not support sparse gradients.')
  85. state = self.state[p]
  86. # State initialization
  87. if len(state) == 0:
  88. state['step'] = 0
  89. # Exponential moving average of gradient values
  90. state['exp_avg'] = torch.zeros_like(p.data)
  91. # Exponential moving average of squared gradient values
  92. state['exp_avg_sq'] = torch.zeros_like(p.data)
  93. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  94. beta1, beta2 = group['betas']
  95. state['step'] += 1
  96. # Decay the first and second moment running average coefficient
  97. # m_t
  98. exp_avg.mul_(beta1).add_(1 - beta1, grad)
  99. # v_t
  100. exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
  101. # Paper v3 does not use debiasing.
  102. # bias_correction1 = 1 - beta1 ** state['step']
  103. # bias_correction2 = 1 - beta2 ** state['step']
  104. # Apply bias to lr to avoid broadcast.
  105. step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
  106. weight_norm = p.data.norm(p=2).clamp_(0, 10)
  107. adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
  108. if group['weight_decay'] != 0:
  109. adam_step.add_(group['weight_decay'], p.data)
  110. adam_norm = adam_step.norm(p=2)
  111. if weight_norm == 0.0 or adam_norm == 0.0:
  112. trust_ratio = 1
  113. else:
  114. trust_ratio = weight_norm / (adam_norm + group['eps'])
  115. state['weight_norm'] = weight_norm
  116. state['adam_norm'] = adam_norm
  117. state['trust_ratio'] = trust_ratio
  118. if self.adam:
  119. trust_ratio = 1
  120. p.data.add_(-step_size * trust_ratio, adam_step)
  121. return loss
  122. @torch.jit.script
  123. def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float,
  124. beta2: float, step_size: float, eps: float, weight_decay: float):
  125. exp_avg = exp_avg * beta1 + (1 - beta1) * grad
  126. exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad)
  127. adam_step = exp_avg / (exp_avg_sq.sqrt() + eps)
  128. adam_step = adam_step + weight_decay * param
  129. weight_norm = param.norm(p=2).clamp(0, 10)
  130. adam_norm = adam_step.norm(p=2)
  131. trust_ratio = weight_norm / (adam_norm + eps)
  132. trust_ratio = (weight_norm == 0.0) * 1.0 + (weight_norm != 0.0) * trust_ratio
  133. trust_ratio = (adam_norm == 0.0) * 1.0 + (adam_norm != 0.0) * trust_ratio
  134. trust_ratio = trust_ratio.float()
  135. param = param - step_size * trust_ratio * adam_step
  136. return param, exp_avg, exp_avg_sq
  137. class JITLamb(Optimizer):
  138. r"""Implements Lamb algorithm.
  139. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
  140. Arguments:
  141. params (iterable): iterable of parameters to optimize or dicts defining
  142. parameter groups
  143. lr (float, optional): learning rate (default: 1e-3)
  144. betas (Tuple[float, float], optional): coefficients used for computing
  145. running averages of gradient and its square (default: (0.9, 0.999))
  146. eps (float, optional): term added to the denominator to improve
  147. numerical stability (default: 1e-8)
  148. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  149. adam (bool, optional): always use trust ratio = 1, which turns this into
  150. Adam. Useful for comparison purposes.
  151. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
  152. https://arxiv.org/abs/1904.00962
  153. """
  154. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
  155. weight_decay=0, adam=False):
  156. if not 0.0 <= lr:
  157. raise ValueError("Invalid learning rate: {}".format(lr))
  158. if not 0.0 <= eps:
  159. raise ValueError("Invalid epsilon value: {}".format(eps))
  160. if not 0.0 <= betas[0] < 1.0:
  161. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  162. if not 0.0 <= betas[1] < 1.0:
  163. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  164. defaults = dict(lr=lr, betas=betas, eps=eps,
  165. weight_decay=weight_decay)
  166. self.adam = adam
  167. super().__init__(params, defaults)
  168. def step(self, closure=None):
  169. """Performs a single optimization step.
  170. Arguments:
  171. closure (callable, optional): A closure that reevaluates the model
  172. and returns the loss.
  173. """
  174. loss = None
  175. if closure is not None:
  176. loss = closure()
  177. for group in self.param_groups:
  178. for p in group['params']:
  179. if p.grad is None:
  180. continue
  181. grad = p.grad.data
  182. if grad.is_sparse:
  183. raise RuntimeError('Lamb does not support sparse gradients.')
  184. state = self.state[p]
  185. # State initialization
  186. if len(state) == 0:
  187. state['step'] = 0
  188. # Exponential moving average of gradient values
  189. state['exp_avg'] = torch.zeros_like(p.data)
  190. # Exponential moving average of squared gradient values
  191. state['exp_avg_sq'] = torch.zeros_like(p.data)
  192. exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
  193. beta1, beta2 = group['betas']
  194. state['step'] += 1
  195. step_size = group['lr']
  196. param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg,
  197. exp_avg_sq, beta1,
  198. beta2, step_size,
  199. group['eps'],
  200. group['weight_decay'],
  201. )
  202. state['exp_avg'] = exp_avg
  203. state['exp_avg_sq'] = exp_avg_sq
  204. p.data = param
  205. return loss