# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
  17. import re
  18. import collections
  19. import tensorflow as tf
  20. import tensorflow_addons.optimizers as tfa_optimizers
  21. from tensorflow.python.ops import control_flow_ops
  22. from tensorflow.python.ops import math_ops
  23. from tensorflow.python.ops import state_ops
  24. from tensorflow.python.training import training_ops
  25. from utils import log
  26. class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  27. """Applys a warmup schedule on a given learning rate decay schedule."""
  28. def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None):
  29. super().__init__()
  30. self.initial_learning_rate = initial_learning_rate
  31. self.warmup_steps = warmup_steps
  32. self.power = power
  33. self.decay_schedule_fn = decay_schedule_fn
  34. self.name = name
  35. def __call__(self, step):
  36. with tf.name_scope(self.name or "WarmUp") as name:
  37. # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
  38. # learning rate will be `global_step/num_warmup_steps * init_lr`.
  39. global_step_float = tf.cast(step, tf.float32)
  40. warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
  41. warmup_percent_done = global_step_float / warmup_steps_float
  42. warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
  43. return tf.cond(
  44. global_step_float < warmup_steps_float,
  45. lambda: warmup_learning_rate,
  46. lambda: self.decay_schedule_fn(step - self.warmup_steps),
  47. name=name,
  48. )
  49. def get_config(self):
  50. return {
  51. "initial_learning_rate": self.initial_learning_rate,
  52. "decay_schedule_fn": self.decay_schedule_fn,
  53. "warmup_steps": self.warmup_steps,
  54. "power": self.power,
  55. "name": self.name,
  56. }
  57. def create_optimizer(init_lr, num_train_steps, num_warmup_steps, weight_decay_rate=0.01,
  58. layerwise_lr_decay=-1, n_transformer_layers=None, clip_norm=1.0,
  59. optimizer="adam", skip_adaptive=False, power=1.0, beta_1=0.9, beta_2=0.999, end_lr=0.0):
  60. """Creates an optimizer with learning rate schedule."""
  61. # Implements linear decay of the learning rate.
  62. learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
  63. initial_learning_rate=init_lr, decay_steps=num_train_steps - num_warmup_steps, end_learning_rate=end_lr, power=power
  64. )
  65. if num_warmup_steps:
  66. learning_rate_fn = WarmUp(
  67. initial_learning_rate=init_lr, decay_schedule_fn=learning_rate_fn, warmup_steps=num_warmup_steps
  68. )
  69. layer_decay = None
  70. if layerwise_lr_decay > 0 and n_transformer_layers is not None:
  71. layer_decay = _get_layer_decay(layerwise_lr_decay, n_transformer_layers)
  72. if optimizer == "adam":
  73. optimizer = AdamWeightDecay(
  74. learning_rate=learning_rate_fn,
  75. weight_decay_rate=weight_decay_rate,
  76. layer_decay=layer_decay,
  77. beta_1=beta_1,
  78. beta_2=beta_2,
  79. epsilon=1e-6,
  80. exclude_from_weight_decay=["layer_norm", "bias", "LayerNorm"],
  81. clip_norm=clip_norm,
  82. )
  83. else:
  84. if skip_adaptive:
  85. skip_list = ["layer_norm", "bias", "LayerNorm"]
  86. else:
  87. skip_list = ["None"]
  88. log("Skip list for LAMB {}".format(skip_list))
  89. optimizer = tfa_optimizers.LAMB(
  90. learning_rate=learning_rate_fn,
  91. weight_decay_rate=weight_decay_rate,
  92. beta_1=beta_1,
  93. beta_2=beta_2,
  94. epsilon=1e-6,
  95. exclude_from_weight_decay=["layer_norm", "bias", "LayerNorm"],
  96. exclude_from_layer_adaptation=skip_list,
  97. )
  98. return optimizer
class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Adam enables L2 weight decay and clip_by_global_norm on gradients.

    Just adding the square of the weights to the loss function is *not* the
    correct way of using L2 regularization/weight decay with Adam, since that
    will interact with the m and v parameters in strange ways.

    Instead we want to decay the weights in a manner that doesn't interact with
    the m/v parameters. This is equivalent to adding the square of the weights
    to the loss with plain (non-momentum) SGD.
    """

    def __init__(
        self,
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        weight_decay_rate=0.0,
        include_in_weight_decay=None,
        exclude_from_weight_decay=None,
        layer_decay=None,
        clip_norm=1.0,
        name="AdamWeightDecay",
        **kwargs
    ):
        """Extends Adam with decoupled weight decay and optional layerwise LRs.

        Args:
            weight_decay_rate: coefficient for the decoupled decay applied in
                `_decay_weights_op` (0 disables decay entirely).
            include_in_weight_decay: regex patterns; matching variables are
                decayed. Checked before the exclude list.
            exclude_from_weight_decay: regex patterns; matching variables are
                never decayed (unless matched by the include list first).
            layer_decay: optional dict mapping variable-name substrings to LR
                multipliers (see `_get_layer_decay`); `None` disables it.
            clip_norm: stored for interface compatibility — the actual
                global-norm clipping happens in train_step, not here.
        """
        super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay
        self.layer_decay = layer_decay
        self.clip_norm = clip_norm

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {"WarmUp": WarmUp}
        return super().from_config(config, custom_objects=custom_objects)

    def _prepare_local(self, var_device, var_dtype, apply_state):
        # Stash the decay rate once per apply step so _decay_weights_op can
        # read it as a constant tensor (top-level key, not per device/dtype).
        super()._prepare_local(var_device, var_dtype, apply_state)
        apply_state["weight_decay_rate"] = tf.constant(self.weight_decay_rate, name="adam_weight_decay_rate")

    def _decay_weights_op(self, var, learning_rate, apply_state):
        """Returns an op applying decoupled decay `var -= lr * wd * var` (or a no-op)."""
        do_decay = self._do_use_weight_decay(var.name)
        if do_decay:
            return var.assign_sub(
                learning_rate * var * apply_state["weight_decay_rate"], use_locking=self._use_locking
            )
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, name=None, experimental_aggregate_gradients=True):
        grads, tvars = list(zip(*grads_and_vars))
        # Being done in train_step
        ##(grads, _) = tf.clip_by_global_norm(grads, clip_norm=self.clip_norm)
        return super().apply_gradients(zip(grads, tvars), name=name,
                                       experimental_aggregate_gradients=experimental_aggregate_gradients)

    def _get_lr(self, var, apply_state):
        """Retrieves the learning rate with the given state.

        Returns `(lr_t, lr, coefficients, kwargs)` where lr_t/lr are scaled by
        the variable's layerwise multiplier when `layer_decay` is set. Raises
        ValueError if `layer_decay` is set but no key matches the variable.
        """
        # if apply_state is None:
        # return self._decayed_lr_t[var_dtype], {}
        var_name, var_device, var_dtype = var.name, var.device, var.dtype.base_dtype
        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            # Cache miss: rebuild per-device/dtype coefficients and memoize.
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients
        lr_t = coefficients["lr_t"]
        lr = coefficients["lr"]
        if self.layer_decay is not None:
            update_for_var = False
            # First matching substring wins; every variable must match a key.
            for key in self.layer_decay:
                if key in var_name:
                    update_for_var = True
                    lr_t *= self.layer_decay[key]
                    lr *= self.layer_decay[key]
                    break
            if not update_for_var:
                raise ValueError("No learning rate specified for variable", var)
        return lr_t, lr, coefficients, dict(apply_state=apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        # print("Dense: {} {} {}".format(var.name, var.device, var.dtype.base_dtype))
        lr_t, _, coefficients, kwargs = self._get_lr(var, apply_state)
        # Decay first, then run the fused Adam kernel with the (possibly
        # layerwise-scaled) lr_t instead of deferring to the parent class.
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            m = self.get_slot(var, 'm')
            v = self.get_slot(var, 'v')
            if not self.amsgrad:
                return training_ops.resource_apply_adam(
                    var.handle,
                    m.handle,
                    v.handle,
                    coefficients['beta_1_power'],
                    coefficients['beta_2_power'],
                    lr_t,
                    coefficients['beta_1_t'],
                    coefficients['beta_2_t'],
                    coefficients['epsilon'],
                    grad,
                    use_locking=self._use_locking)
            else:
                vhat = self.get_slot(var, 'vhat')
                return training_ops.resource_apply_adam_with_amsgrad(
                    var.handle,
                    m.handle,
                    v.handle,
                    vhat.handle,
                    coefficients['beta_1_power'],
                    coefficients['beta_2_power'],
                    lr_t,
                    coefficients['beta_1_t'],
                    coefficients['beta_2_t'],
                    coefficients['epsilon'],
                    grad,
                    use_locking=self._use_locking)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        # print("Sparse: {} {} {}".format(var.name, var.device, var.dtype.base_dtype))
        lr_t, lr, coefficients, kwargs = self._get_lr(var, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            # m_t = beta1 * m + (1 - beta1) * g_t
            m = self.get_slot(var, 'm')
            m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
            m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                                   use_locking=self._use_locking)
            with tf.control_dependencies([m_t]):
                # Scatter only the touched rows into the momentum slot.
                m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
            # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
            v = self.get_slot(var, 'v')
            v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
            v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                                   use_locking=self._use_locking)
            with tf.control_dependencies([v_t]):
                v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
            if not self.amsgrad:
                v_sqrt = math_ops.sqrt(v_t)
                var_update = state_ops.assign_sub(
                    var, lr * m_t / (v_sqrt + coefficients['epsilon']),
                    use_locking=self._use_locking)
                return control_flow_ops.group(*[var_update, m_t, v_t])
            else:
                # AMSGrad: use the running elementwise max of v_t.
                v_hat = self.get_slot(var, 'vhat')
                v_hat_t = math_ops.maximum(v_hat, v_t)
                with tf.control_dependencies([v_hat_t]):
                    v_hat_t = state_ops.assign(
                        v_hat, v_hat_t, use_locking=self._use_locking)
                v_hat_sqrt = math_ops.sqrt(v_hat_t)
                var_update = state_ops.assign_sub(
                    var,
                    lr * m_t / (v_hat_sqrt + coefficients['epsilon']),
                    use_locking=self._use_locking)
                return control_flow_ops.group(*[var_update, m_t, v_t, v_hat_t])

    def get_config(self):
        # NOTE(review): layer_decay / clip_norm / include-exclude lists are not
        # serialized here — only weight_decay_rate survives a config round-trip.
        config = super().get_config()
        config.update({"weight_decay_rate": self.weight_decay_rate})
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False
        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True
        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True
# Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
    """Distribution strategies-aware gradient accumulation utility.

    Call the instance with a list of gradients to add them into internal
    accumulator variables; read `.gradients` to get the running sums and
    `.step` for how many accumulations happened; `reset()` zeroes both.
    """

    def __init__(self):
        """Initializes the accumulator."""
        # Accumulator variables, created lazily on the first __call__ so they
        # match the gradients' shapes.
        self._gradients = []
        # ONLY_FIRST_REPLICA: the step count is shared, not summed per replica.
        self._accum_steps = tf.Variable(
            initial_value=0, dtype=tf.int64, trainable=False, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA
        )

    @property
    def step(self):
        """Number of accumulated steps."""
        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients (None entries are passed through)."""
        return list(
            gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients()
        )

    def __call__(self, gradients):
        """Accumulates :obj:`gradients`.

        Raises:
            ValueError: if `gradients` has a different length than the list
                seen on the first call.
        """
        if not self._gradients:
            # First call: create one zero-initialized variable per gradient,
            # keeping None placeholders as None.
            self._gradients.extend(
                [
                    tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
        for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
            if accum_gradient is not None and gradient is not None:
                accum_gradient.assign_add(gradient)
        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients."""
        if self._gradients:
            self._accum_steps.assign(0)
            for gradient in self._get_replica_gradients():
                if gradient is not None:
                    gradient.assign(tf.zeros_like(gradient))

    def _get_replica_gradients(self):
        """Yields this replica's view of the accumulator variables."""
        if tf.distribute.has_strategy():
            # In a replica context, we want to accumulate gradients on each replica
            # without synchronization, so we directly assign the value of the
            # current replica.
            replica_context = tf.distribute.get_replica_context()
            if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
                return self._gradients
            # NOTE(review): `device_map.select_for_current_replica` is a TF
            # internal of DistributedVariable — verify against the TF version in use.
            return (
                gradient.device_map.select_for_current_replica(gradient.values, replica_context)
                for gradient in self._gradients
                if gradient is not None
            )
        else:
            return self._gradients
  318. def _get_layer_decay(layer_decay, n_layers):
  319. """Have lower learning rates for layers closer to the input."""
  320. key_to_depths = collections.OrderedDict({
  321. "/embeddings/": 0,
  322. "/embeddings_project/": 0,
  323. "/start_logits/": n_layers + 2,
  324. "/end_logits/": n_layers + 2,
  325. "/answer_class/": n_layers + 2,
  326. "/qa_outputs/": n_layers + 2,
  327. })
  328. for layer in range(n_layers):
  329. key_to_depths["encoder/layer_._" + str(layer) + "/"] = layer + 1
  330. return {
  331. key: layer_decay ** (n_layers + 2 - depth)
  332. for key, depth in key_to_depths.items()
  333. }