# optimization.py
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

from horovod.tensorflow.compression import Compression
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None,
                     manual_fp16=False, use_fp16=False, num_accumulation_steps=1,
                     optimizer_type="adam", allreduce_post_accumulation=False,
                     init_loss_scale=2**32):
    """Creates an optimizer training op.

    Builds a polynomial-decay learning-rate schedule with linear warmup,
    constructs the requested optimizer (LAMB or AdamWeightDecay), optionally
    wraps it for Horovod distribution and fp16 loss scaling, and returns a
    single op that applies (possibly accumulated) gradients and advances the
    global step.

    Args:
        loss: Scalar loss tensor to minimize.
        init_lr: Peak learning rate reached at the end of warmup.
        num_train_steps: Total number of training steps (decay horizon).
        num_warmup_steps: Steps of linear warmup; falsy disables warmup.
        hvd: Optional Horovod module; enables distributed allreduce.
        manual_fp16: Use the contrib loss-scale optimizer for manually-cast
            fp16 models.
        use_fp16: Use the automatic mixed-precision graph rewrite with a
            dynamic loss scale.
        num_accumulation_steps: Number of micro-batches over which gradients
            are accumulated before one optimizer update is applied.
        optimizer_type: "adam" or "lamb"; selects schedule power and optimizer.
        allreduce_post_accumulation: With hvd, allreduce the accumulated
            gradients once per update instead of wrapping the optimizer for
            per-micro-batch allreduce.
        init_loss_scale: Initial value for the loss scaler (fp16 paths).

    Returns:
        A grouped training op that applies gradients and updates the
        global step.
    """
    global_step = tf.compat.v1.train.get_or_create_global_step()

    # avoid step change in learning rate at end of warmup phase
    if optimizer_type == "adam":
        power = 1.0
        decayed_learning_rate_at_crossover_point = init_lr * (
            (1.0 - float(num_warmup_steps) / float(num_train_steps)) ** power)
    else:
        power = 0.5
        decayed_learning_rate_at_crossover_point = init_lr
    # Scale the schedule's peak so the decayed curve passes through init_lr
    # at the warmup/decay crossover step (see comment above).
    adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
    print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' % (decayed_learning_rate_at_crossover_point, adjusted_init_lr))

    learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.compat.v1.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=power,
        cycle=False)

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done
        # Blend: warmup LR before the crossover, decayed LR after.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = (
            (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

    if optimizer_type == "lamb":
        print("Initializing LAMB Optimizer")
        optimizer = LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    else:
        print("Initializing ADAM Weight Decay Optimizer")
        # It is recommended that you use this optimizer for fine tuning, since this
        # is how the model was trained (note that the Adam m/v variables are NOT
        # loaded from init_checkpoint.)
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    # Wrap for per-micro-batch allreduce; skipped when the accumulated
    # gradients are allreduced once per update instead (post-accumulation).
    if hvd is not None and (num_accumulation_steps == 1 or (not allreduce_post_accumulation)):
        optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none)
    if use_fp16:
        loss_scaler = tf.train.experimental.DynamicLossScale(initial_loss_scale=init_loss_scale, increment_period=1000, multiplier=2.0)
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler)
        # Expose the current loss scale in the graph under a stable name.
        loss_scale_value = tf.identity(loss_scaler(), name="loss_scale")
    if manual_fp16:
        loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
            init_loss_scale=init_loss_scale,
            incr_every_n_steps=1000,
            decr_every_n_nan_or_inf=2,
            decr_ratio=0.5)
        optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)

    tvars = tf.trainable_variables()
    # The loss is pre-divided so the accumulated gradient is an average
    # over the micro-batches.
    grads_and_vars = optimizer.compute_gradients(loss * 1.0 / num_accumulation_steps, tvars)

    if num_accumulation_steps > 1:
        local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False,
                                     initializer=tf.zeros_initializer)
        batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False,
                                       initializer=tf.ones_initializer)
        # One fp32 accumulator per trainable variable.
        accum_vars = [tf.get_variable(
            name=tvar.name.split(":")[0] + "/accum",
            shape=tvar.shape.as_list(),
            dtype=tf.float32,
            trainable=False,
            initializer=tf.zeros_initializer()) for tvar in tf.trainable_variables()]

        # reset_step is True at the start of each accumulation window;
        # local_step then restarts from 1, otherwise it increments.
        reset_step = tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0), dtype=tf.bool)
        local_step = tf.cond(reset_step, lambda: local_step.assign(tf.ones_like(local_step)), lambda: local_step.assign_add(1))

        grads_and_vars_and_accums = [(gv[0], gv[1], accum_vars[i]) for i, gv in enumerate(grads_and_vars) if gv[0] is not None]
        grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums))

        # Track finiteness across the whole accumulation window (fp16 only).
        all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if manual_fp16 or use_fp16 else tf.constant(True, dtype=tf.bool)
        batch_finite = tf.cond(reset_step,
                               lambda: batch_finite.assign(tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)),
                               lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite)))

        # This is how the model was pre-trained.
        # ensure global norm is a finite number
        # to prevent clip_by_global_norm from having a hizzy fit.
        (clipped_grads, _) = tf.clip_by_global_norm(
            grads, clip_norm=1.0,
            use_norm=tf.cond(
                all_are_finite,
                lambda: tf.global_norm(grads),
                lambda: tf.constant(1.0)))

        # Start a fresh accumulation on reset, otherwise add in this
        # micro-batch's clipped gradients.
        accum_vars = tf.cond(reset_step,
                             lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)],
                             lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)])

        def update(accum_vars):
            # Apply the accumulated gradients; optionally allreduce them
            # once here instead of per micro-batch.
            if allreduce_post_accumulation and hvd is not None:
                accum_vars = [hvd.allreduce(tf.convert_to_tensor(accum_var), compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) if isinstance(accum_var, tf.IndexedSlices)
                              else hvd.allreduce(accum_var, compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) for accum_var in accum_vars]
            return optimizer.apply_gradients(list(zip(accum_vars, tvars)), global_step=global_step)

        # NOTE(review): update_step re-evaluates local_step % N after the
        # assign above, so it fires on the last micro-batch of the window.
        update_step = tf.identity(tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0), dtype=tf.bool), name="update_step")
        update_op = tf.cond(update_step,
                            lambda: update(accum_vars), lambda: tf.no_op())

        # Advance the global step only on update steps whose accumulated
        # batch was finite on every worker.
        new_global_step = tf.cond(tf.math.logical_and(update_step,
                                                      tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool) if hvd is not None else batch_finite),
                                  lambda: global_step + 1,
                                  lambda: global_step)
        new_global_step = tf.identity(new_global_step, name='step_update')
        train_op = tf.group(update_op, [global_step.assign(new_global_step)])
    else:
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        grads, tvars = list(zip(*grads_and_vars))
        all_are_finite = tf.reduce_all(
            [tf.reduce_all(tf.is_finite(g)) for g in grads]) if use_fp16 or manual_fp16 else tf.constant(True, dtype=tf.bool)

        # This is how the model was pre-trained.
        # ensure global norm is a finite number
        # to prevent clip_by_global_norm from having a hizzy fit.
        (clipped_grads, _) = tf.clip_by_global_norm(
            grads, clip_norm=1.0,
            use_norm=tf.cond(
                all_are_finite,
                lambda: tf.global_norm(grads),
                lambda: tf.constant(1.0)))

        train_op = optimizer.apply_gradients(
            list(zip(clipped_grads, tvars)), global_step=global_step)
        # Do not count steps whose gradients overflowed in fp16.
        new_global_step = tf.cond(all_are_finite, lambda: global_step + 1, lambda: global_step)
        new_global_step = tf.identity(new_global_step, name='step_update')
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
  161. class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer):
  162. """A basic Adam optimizer that includes "correct" L2 weight decay."""
  163. def __init__(self,
  164. learning_rate,
  165. weight_decay_rate=0.0,
  166. beta_1=0.9,
  167. beta_2=0.999,
  168. epsilon=1e-6,
  169. exclude_from_weight_decay=None,
  170. name="AdamWeightDecayOptimizer"):
  171. """Constructs a AdamWeightDecayOptimizer."""
  172. super(AdamWeightDecayOptimizer, self).__init__(False, name)
  173. self.learning_rate = tf.identity(learning_rate, name='learning_rate')
  174. self.weight_decay_rate = weight_decay_rate
  175. self.beta_1 = beta_1
  176. self.beta_2 = beta_2
  177. self.epsilon = epsilon
  178. self.exclude_from_weight_decay = exclude_from_weight_decay
  179. def apply_gradients(self, grads_and_vars, global_step=None, name=None,
  180. manual_fp16=False):
  181. """See base class."""
  182. assignments = []
  183. for (grad, param) in grads_and_vars:
  184. if grad is None or param is None:
  185. continue
  186. param_name = self._get_variable_name(param.name)
  187. has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
  188. if has_shadow:
  189. # create shadow fp32 weights for fp16 variable
  190. param_fp32 = tf.get_variable(
  191. name=param_name + "/shadow",
  192. dtype=tf.float32,
  193. trainable=False,
  194. initializer=tf.cast(param.initialized_value(),tf.float32))
  195. else:
  196. param_fp32 = param
  197. m = tf.get_variable(
  198. name=param_name + "/adam_m",
  199. shape=param.shape.as_list(),
  200. dtype=tf.float32,
  201. trainable=False,
  202. initializer=tf.zeros_initializer())
  203. v = tf.get_variable(
  204. name=param_name + "/adam_v",
  205. shape=param.shape.as_list(),
  206. dtype=tf.float32,
  207. trainable=False,
  208. initializer=tf.zeros_initializer())
  209. # Standard Adam update.
  210. next_m = (
  211. tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
  212. next_v = (
  213. tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
  214. tf.square(grad)))
  215. update = next_m / (tf.sqrt(next_v) + self.epsilon)
  216. # Just adding the square of the weights to the loss function is *not*
  217. # the correct way of using L2 regularization/weight decay with Adam,
  218. # since that will interact with the m and v parameters in strange ways.
  219. #
  220. # Instead we want to decay the weights in a manner that doesn't interact
  221. # with the m/v parameters. This is equivalent to adding the square
  222. # of the weights to the loss with plain (non-momentum) SGD.
  223. if self._do_use_weight_decay(param_name):
  224. update += self.weight_decay_rate * param_fp32
  225. update_with_lr = self.learning_rate * update
  226. next_param = param_fp32 - update_with_lr
  227. if has_shadow:
  228. # cast shadow fp32 weights to fp16 and assign to trainable variable
  229. param.assign(tf.cast(next_param, param.dtype.base_dtype))
  230. assignments.extend(
  231. [param_fp32.assign(next_param),
  232. m.assign(next_m),
  233. v.assign(next_v)])
  234. return tf.group(*assignments, name=name)
  235. def _do_use_weight_decay(self, param_name):
  236. """Whether to use L2 weight decay for `param_name`."""
  237. if not self.weight_decay_rate:
  238. return False
  239. if self.exclude_from_weight_decay:
  240. for r in self.exclude_from_weight_decay:
  241. if re.search(r, param_name) is not None:
  242. return False
  243. return True
  244. def _get_variable_name(self, param_name):
  245. """Get the variable name from the tensor name."""
  246. m = re.match("^(.*):\\d+$", param_name)
  247. if m is not None:
  248. param_name = m.group(1)
  249. return param_name
  250. class LAMBOptimizer(tf.compat.v1.train.Optimizer):
  251. """A LAMB optimizer that includes "correct" L2 weight decay."""
  252. def __init__(self,
  253. learning_rate,
  254. weight_decay_rate=0.0,
  255. beta_1=0.9,
  256. beta_2=0.999,
  257. epsilon=1e-6,
  258. exclude_from_weight_decay=None,
  259. name="LAMBOptimizer"):
  260. """Constructs a LAMBOptimizer."""
  261. super(LAMBOptimizer, self).__init__(False, name)
  262. self.learning_rate = tf.identity(learning_rate, name='learning_rate')
  263. self.weight_decay_rate = weight_decay_rate
  264. self.beta_1 = beta_1
  265. self.beta_2 = beta_2
  266. self.epsilon = epsilon
  267. self.exclude_from_weight_decay = exclude_from_weight_decay
  268. def apply_gradients(self, grads_and_vars, global_step, name=None,
  269. manual_fp16=False):
  270. """See base class."""
  271. assignments = []
  272. steps = tf.cast(global_step, tf.float32)
  273. for (grad, param) in grads_and_vars:
  274. if grad is None or param is None:
  275. continue
  276. param_name = self._get_variable_name(param.name)
  277. has_shadow = manual_fp16 and param.dtype.base_dtype != tf.float32
  278. if has_shadow:
  279. # create shadow fp32 weights for fp16 variable
  280. param_fp32 = tf.get_variable(
  281. name=param_name + "/shadow",
  282. dtype=tf.float32,
  283. trainable=False,
  284. initializer=tf.cast(param.initialized_value(),tf.float32))
  285. else:
  286. param_fp32 = param
  287. m = tf.get_variable(
  288. name=param_name + "/adam_m",
  289. shape=param.shape.as_list(),
  290. dtype=tf.float32,
  291. trainable=False,
  292. initializer=tf.zeros_initializer())
  293. v = tf.get_variable(
  294. name=param_name + "/adam_v",
  295. shape=param.shape.as_list(),
  296. dtype=tf.float32,
  297. trainable=False,
  298. initializer=tf.zeros_initializer())
  299. # LAMB update
  300. next_m = (
  301. tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
  302. next_v = (
  303. tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
  304. tf.square(grad)))
  305. beta1_correction = (1 - self.beta_1 ** steps)
  306. beta2_correction = (1 - self.beta_2 ** steps)
  307. next_m_unbiased = next_m / beta1_correction
  308. next_v_unbiased = next_v / beta2_correction
  309. update = next_m_unbiased / (tf.sqrt(next_v_unbiased) + self.epsilon)
  310. # Just adding the square of the weights to the loss function is *not*
  311. # the correct way of using L2 regularization/weight decay with Adam,
  312. # since that will interact with the m and v parameters in strange ways.
  313. #
  314. # Instead we want to decay the weights in a manner that doesn't interact
  315. # with the m/v parameters. This is equivalent to adding the square
  316. # of the weights to the loss with plain (non-momentum) SGD.
  317. if self._do_use_weight_decay(param_name):
  318. update += self.weight_decay_rate * param_fp32
  319. w_norm = linalg_ops.norm(param, ord=2)
  320. g_norm = linalg_ops.norm(update, ord=2)
  321. ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
  322. math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
  323. update_with_lr = ratio * self.learning_rate * update
  324. next_param = param_fp32 - update_with_lr
  325. if has_shadow:
  326. # cast shadow fp32 weights to fp16 and assign to trainable variable
  327. param.assign(tf.cast(next_param, param.dtype.base_dtype))
  328. assignments.extend(
  329. [param_fp32.assign(next_param),
  330. m.assign(next_m),
  331. v.assign(next_v)])
  332. return tf.group(*assignments, name=name)
  333. def _do_use_weight_decay(self, param_name):
  334. """Whether to use L2 weight decay for `param_name`."""
  335. if not self.weight_decay_rate:
  336. return False
  337. if self.exclude_from_weight_decay:
  338. for r in self.exclude_from_weight_decay:
  339. if re.search(r, param_name) is not None:
  340. return False
  341. return True
  342. def _get_variable_name(self, param_name):
  343. """Get the variable name from the tensor name."""
  344. m = re.match("^(.*):\\d+$", param_name)
  345. if m is not None:
  346. param_name = m.group(1)
  347. return param_name