# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math

import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Optimizer
from torch.optim.optimizer import required

#from fused_adam_local import FusedAdam
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from apex.optimizers import FusedAdam

from utils import is_main_process
# Module-level aliases for the fused CUDA kernels exposed by apex's amp_C
# extension: multi-tensor L2 norm, the two LAMB update stages, and scaling.
# NOTE(review): these are bound here for use by code outside this view —
# nothing in the visible source calls them directly.
multi_tensor_l2norm = amp_C.multi_tensor_l2norm
lamb_compute_update = amp_C.multi_tensor_lamb_stage1_cuda
lamb_apply_update = amp_C.multi_tensor_lamb_stage2_cuda
scale = amp_C.multi_tensor_scale
  31. def warmup_cosine(x, warmup=0.002):
  32. if x < warmup:
  33. return x/warmup
  34. return 0.5 * (1.0 + torch.cos(math.pi * x))
  35. def warmup_constant(x, warmup=0.002):
  36. if x < warmup:
  37. return x/warmup
  38. return 1.0
  39. def warmup_linear(x, warmup=0.002):
  40. if x < warmup:
  41. return x/warmup
  42. return max((x - 1. )/ (warmup - 1.), 0.)
  43. def warmup_poly(x, warmup=0.002, degree=0.5):
  44. if x < warmup:
  45. return x/warmup
  46. return (1.0 - x)**degree
# Registry mapping schedule names (the `schedule` argument of BertAdam)
# to their warmup/decay functions defined above.
SCHEDULES = {
    'warmup_cosine':warmup_cosine,
    'warmup_constant':warmup_constant,
    'warmup_linear':warmup_linear,
    'warmup_poly':warmup_poly,
}
  53. class BertAdam(Optimizer):
  54. """Implements BERT version of Adam algorithm with weight decay fix.
  55. Params:
  56. lr: learning rate
  57. warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
  58. t_total: total number of training steps for the learning
  59. rate schedule, -1 means constant learning rate. Default: -1
  60. schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
  61. b1: Adams b1. Default: 0.9
  62. b2: Adams b2. Default: 0.999
  63. e: Adams epsilon. Default: 1e-6
  64. weight_decay: Weight decay. Default: 0.01
  65. max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
  66. """
  67. def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
  68. b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
  69. max_grad_norm=1.0):
  70. if lr is not required and lr < 0.0:
  71. raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
  72. if schedule not in SCHEDULES:
  73. raise ValueError("Invalid schedule parameter: {}".format(schedule))
  74. if not 0.0 <= warmup < 1.0 and not warmup == -1:
  75. raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
  76. if not 0.0 <= b1 < 1.0:
  77. raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
  78. if not 0.0 <= b2 < 1.0:
  79. raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
  80. if not e >= 0.0:
  81. raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
  82. defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
  83. b1=b1, b2=b2, e=e, weight_decay=weight_decay,
  84. max_grad_norm=max_grad_norm)
  85. super(BertAdam, self).__init__(params, defaults)
  86. def get_lr(self):
  87. lr = []
  88. for group in self.param_groups:
  89. for p in group['params']:
  90. state = self.state[p]
  91. if len(state) == 0:
  92. return [0]
  93. if group['t_total'] != -1:
  94. schedule_fct = SCHEDULES[group['schedule']]
  95. lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
  96. else:
  97. lr_scheduled = group['lr']
  98. lr.append(lr_scheduled)
  99. return lr
  100. def step(self, closure=None):
  101. """Performs a single optimization step.
  102. Arguments:
  103. closure (callable, optional): A closure that reevaluates the model
  104. and returns the loss.
  105. """
  106. loss = None
  107. if closure is not None:
  108. loss = closure()
  109. for group in self.param_groups:
  110. for p in group['params']:
  111. if p.grad is None:
  112. continue
  113. grad = p.grad.data
  114. if grad.is_sparse:
  115. raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
  116. state = self.state[p]
  117. # State initialization
  118. if len(state) == 0:
  119. state['step'] = 0
  120. # Exponential moving average of gradient values
  121. state['next_m'] = torch.zeros_like(p.data)
  122. # Exponential moving average of squared gradient values
  123. state['next_v'] = torch.zeros_like(p.data)
  124. next_m, next_v = state['next_m'], state['next_v']
  125. beta1, beta2 = group['b1'], group['b2']
  126. # Add grad clipping
  127. if group['max_grad_norm'] > 0:
  128. clip_grad_norm_(p, group['max_grad_norm'], error_if_nonfinite=False)
  129. # Decay the first and second moment running average coefficient
  130. # In-place operations to update the averages at the same time
  131. next_m.mul_(beta1).add_(1 - beta1, grad)
  132. next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
  133. update = next_m / (next_v.sqrt() + group['e'])
  134. # Just adding the square of the weights to the loss function is *not*
  135. # the correct way of using L2 regularization/weight decay with Adam,
  136. # since that will interact with the m and v parameters in strange ways.
  137. #
  138. # Instead we want to decay the weights in a manner that doesn't interact
  139. # with the m/v parameters. This is equivalent to adding the square
  140. # of the weights to the loss with plain (non-momentum) SGD.
  141. if group['weight_decay'] > 0.0:
  142. update += group['weight_decay'] * p.data
  143. if group['t_total'] != -1:
  144. schedule_fct = SCHEDULES[group['schedule']]
  145. lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
  146. else:
  147. lr_scheduled = group['lr']
  148. update_with_lr = lr_scheduled * update
  149. p.data.add_(-update_with_lr)
  150. state['step'] += 1
  151. return loss