optimizers.py

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from common.fairseq.optim.adam import FairseqAdam
from common.fairseq.optim.fp16_optimizer import FP16Optimizer
from common.fairseq.optim.fused_adam import get_fused_adam_class
from common.utils import print_once


def lr_poly_policy(step, optimizer, lr, initial_lr_scale=0.0,
                   final_lr_scale=0.0, warmup_steps=1000, hold_steps=0,
                   num_steps=None, power=1.0):
    """Polynomial decay LR policy with an optional hold period."""
    assert step >= 1
    assert num_steps is not None
    assert power is not None

    start_lr = initial_lr_scale * lr
    end_lr = final_lr_scale * lr

    if step <= warmup_steps:
        new_lr = start_lr + step / warmup_steps * (lr - start_lr)
    elif step <= warmup_steps + hold_steps:
        new_lr = lr
    elif warmup_steps + hold_steps < step <= num_steps:
        remain = 1 - (step - warmup_steps) / (num_steps - warmup_steps)
        new_lr = (lr - end_lr) * remain ** power + end_lr
    else:
        new_lr = end_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
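

# Usage sketch (added for illustration, not part of the original module): the
# policy writes the new LR into `optimizer.param_groups` in place, so it is
# meant to be called once per training step with a 1-based step index. The
# helper name, the SGD optimizer and all numbers below are assumptions made
# for this example only.
def _demo_lr_poly_policy():
    import torch

    param = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([param], lr=1e-3)

    for step in (1, 500, 1000, 5000, 10000):
        lr_poly_policy(step, opt, lr=1e-3, warmup_steps=1000,
                       hold_steps=0, num_steps=10000, power=1.0)
        # Linear warmup to 1e-3 over the first 1000 steps, then linear
        # (power=1.0) decay towards final_lr_scale * lr == 0.0 at step 10000.
        print(step, opt.param_groups[0]['lr'])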


def lr_exp_policy(step, optimizer, initial_lr_scale, lr, final_lr_scale=0.0,
                  warmup_steps=1000, hold_steps=0, num_steps=float('inf'),
                  decay=None):
    """Exponential LR policy with an optional hold period.

    If the `decay` factor is not supplied, it is calculated so that the LR
    reaches `end_lr` after `num_steps` steps.

    Args:
        num_steps (int): Limits the number of decay steps.
        end_lr (float): The lowest possible LR, equal to `final_lr_scale * lr`.
        decay (float or None): Decay factor; if None, it will be derived
            from `num_steps` and `end_lr`.
    """
    assert step >= 1

    start_lr = initial_lr_scale * lr
    end_lr = final_lr_scale * lr

    if decay is None:
        assert not math.isinf(num_steps) and end_lr > 0.0
        decay_steps = num_steps - warmup_steps - hold_steps
        decay = math.log(end_lr / lr) / decay_steps
    else:
        decay = math.log(decay)

    if step <= warmup_steps:
        new_lr = start_lr + step / warmup_steps * (lr - start_lr)
    elif step <= warmup_steps + hold_steps:
        new_lr = lr
    else:
        a = math.exp(decay * (min(step, num_steps) - warmup_steps - hold_steps))
        new_lr = max(a * lr, end_lr)

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
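

# Worked example of the derived decay factor (illustrative numbers only, not
# from the source): with lr=1e-3 and final_lr_scale=0.01, end_lr = 1e-5; with
# warmup_steps=1000, hold_steps=0 and num_steps=10000, decay_steps = 9000 and
#     decay = log(1e-5 / 1e-3) / 9000 = log(0.01) / 9000 ≈ -5.12e-4,
# so exp(decay * decay_steps) * lr == end_lr, i.e. the schedule reaches
# `end_lr` exactly at step `num_steps` and stays clamped there afterwards.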


def get_optimizer(model, args):
    kw = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'adam' and (args.fp16 or args.bf16):
        print_once('WARNING: Using Fairseq FP16Optimizer')

        # based on fairseq.optim.FP16Optimizer.build_optimizer
        flatten = True  # not args.fp16_no_flatten_grads
        args.betas = args.adam_betas
        args.eps = args.adam_eps
        params = list(filter(lambda p: p.requires_grad, model.parameters()))
        fp32_params = FP16Optimizer.build_fp32_params(args, params,
                                                      flatten=flatten)

        # based on fairseq.optim.build_optimizer
        def build_optimizer(cfg, params, *extra_args, **extra_kwargs):
            if all(isinstance(p, dict) for p in params):
                params = [t for p in params for t in p.values()]
            params = list(filter(lambda p: p.requires_grad, params))
            return FairseqAdam(cfg, params, *extra_args, **extra_kwargs)

        if flatten:
            fp32_optimizer = build_optimizer(args, [fp32_params])
        else:
            fp32_optimizer = build_optimizer(args, fp32_params)

        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                f"chosen optimizer {fp32_optimizer.__class__.__name__} does "
                "not support flat params, please set --fp16-no-flatten-grads"
            )

        kwargs = {}
        optimizer = FP16Optimizer(args, params, fp32_optimizer, fp32_params,
                                  **kwargs)

    elif args.optimizer == 'adam' and not (args.fp16 or args.bf16):
        print_once('WARNING: Using FusedAdam instead of Adam')
        kw.update({'betas': args.adam_betas, 'eps': args.adam_eps})
        fused_adam_cls = get_fused_adam_class()
        optimizer = fused_adam_cls(model.parameters(), **kw)

    else:
        raise ValueError(f'Invalid optimizer "{args.optimizer}"')

    return optimizer
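

# End-to-end sketch (added for illustration, not part of the original module):
# `args` is faked here with a bare argparse.Namespace carrying the attributes
# `get_optimizer` reads. The attribute names mirror the code above, but the
# real training script, the model and every number below are assumptions; the
# FP32 'adam' branch additionally assumes the fused Adam CUDA extension is
# installed and a GPU is available.
if __name__ == '__main__':
    from argparse import Namespace

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    args = Namespace(optimizer='adam', lr=5e-4, weight_decay=0.01,
                     adam_betas=(0.9, 0.98), adam_eps=1e-8,
                     fp16=False, bf16=False)
    optimizer = get_optimizer(model, args)

    for step in range(1, 101):
        # Update the LR in place before each optimizer step.
        lr_exp_policy(step, optimizer, initial_lr_scale=0.0, lr=args.lr,
                      final_lr_scale=0.05, warmup_steps=10, num_steps=100)
        optimizer.zero_grad()
        model(torch.randn(4, 8, device='cuda')).sum().backward()
        optimizer.step()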