@@ -301,12 +301,12 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
     self.beta_2 = beta_2
     self.epsilon = epsilon
     self.exclude_from_weight_decay = exclude_from_weight_decay
-    self.steps = 0

-  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+  def apply_gradients(self, grads_and_vars, global_step, name=None,
                       manual_fp16=False):
     """See base class."""
     assignments = []
+    steps = tf.cast(global_step, tf.float32)
     for (grad, param) in grads_and_vars:
       if grad is None or param is None:
         continue
@@ -343,9 +343,8 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                     tf.square(grad)))

-      self.steps += 1
-      beta1_correction = (1 - self.beta_1 ** self.steps)
-      beta2_correction = (1 - self.beta_2 ** self.steps)
+      beta1_correction = (1 - self.beta_1 ** steps)
+      beta2_correction = (1 - self.beta_2 ** steps)

       next_m_unbiased = next_m / beta1_correction
       next_v_unbiased = next_v / beta2_correction
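The motivation for the change: in TF1-style graph mode, `apply_gradients` runs once while the graph is constructed, so a Python-side counter like `self.steps += 1` advances once per `(grad, param)` pair at build time and is then baked into the graph as a constant, leaving the bias corrections `1 - beta^t` frozen instead of tracking the training step. Casting `global_step` to `float32` keeps the exponent inside the graph, so it grows with every optimizer step; making `global_step` a required argument follows from this, since the correction cannot be computed without it. Below is a minimal standalone sketch of the corrected bias-correction terms, not the repository's code; the names and the tiny session loop are illustrative only.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

beta_1, beta_2 = 0.9, 0.999  # illustrative hyperparameters

# global_step lives in the graph and is advanced by the training loop,
# so anything derived from it is re-evaluated on every session.run step.
global_step = tf.compat.v1.train.get_or_create_global_step()
steps = tf.cast(global_step, tf.float32)

# Adam/LAMB bias corrections, 1 - beta^t, with t read from the graph.
# tf.pow here is the same computation as `self.beta_1 ** steps` in the diff.
beta1_correction = 1.0 - tf.pow(beta_1, steps)
beta2_correction = 1.0 - tf.pow(beta_2, steps)

increment_step = tf.compat.v1.assign_add(global_step, 1)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for _ in range(3):
        sess.run(increment_step)
        # The corrections now change every step: ~[0.1, 0.001], [0.19, 0.002], ...
        print(sess.run([beta1_correction, beta2_correction]))
```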