optimizer.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import logging
  16. from paddle import optimizer as optim
  17. class Momentum:
  18. """
  19. Simple Momentum optimizer with velocity state.
  20. Args:
  21. args(Namespace): Arguments obtained from ArgumentParser.
  22. learning_rate(float|LRScheduler): The learning rate used to update parameters.
  23. Can be a float value or a paddle.optimizer.lr.LRScheduler.
  24. """
  25. def __init__(self, args, learning_rate):
  26. super().__init__()
  27. self.learning_rate = learning_rate
  28. self.momentum = args.momentum
  29. self.weight_decay = args.weight_decay
  30. self.grad_clip = None
  31. self.multi_precision = args.amp
  32. def __call__(self):
  33. # model_list is None in static graph
  34. parameters = None
  35. opt = optim.Momentum(
  36. learning_rate=self.learning_rate,
  37. momentum=self.momentum,
  38. weight_decay=self.weight_decay,
  39. grad_clip=self.grad_clip,
  40. multi_precision=self.multi_precision,
  41. parameters=parameters)
  42. return opt
  43. def build_optimizer(args, lr):
  44. """
  45. Build a raw optimizer with learning rate scheduler.
  46. Args:
  47. args(Namespace): Arguments obtained from ArgumentParser.
  48. lr(paddle.optimizer.lr.LRScheduler): A LRScheduler used for training.
  49. return:
  50. optim(paddle.optimizer): A normal optmizer.
  51. """
  52. optimizer_mod = sys.modules[__name__]
  53. opt = getattr(optimizer_mod, args.optimizer)(args, learning_rate=lr)()
  54. logging.info("build optimizer %s success..", opt)
  55. return opt