Просмотр исходного кода

[BERT/TF] bug fix in beta bias correction terms (#395)

Swetha Mandava 6 лет назад
Родитель
Commit
0845aaa901
1 изменённый файл: 4 добавлено и 5 удалено
  1. 4 5
      TensorFlow/LanguageModeling/BERT/optimization.py

+ 4 - 5
TensorFlow/LanguageModeling/BERT/optimization.py

@@ -301,12 +301,12 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
     self.beta_2 = beta_2
     self.epsilon = epsilon
     self.exclude_from_weight_decay = exclude_from_weight_decay
-    self.steps = 0
 
-  def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+  def apply_gradients(self, grads_and_vars, global_step, name=None,
       manual_fp16=False):
     """See base class."""
     assignments = []
+    steps = tf.cast(global_step, tf.float32)
     for (grad, param) in grads_and_vars:
       if grad is None or param is None:
         continue
@@ -343,9 +343,8 @@ class LAMBOptimizer(tf.compat.v1.train.Optimizer):
           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
                                                     tf.square(grad)))
 
-      self.steps += 1
-      beta1_correction = (1 - self.beta_1 ** self.steps)
-      beta2_correction = (1 - self.beta_2 ** self.steps)
+      beta1_correction = (1 - self.beta_1 ** steps)
+      beta2_correction = (1 - self.beta_2 ** steps)
 
       next_m_unbiased = next_m / beta1_correction
       next_v_unbiased = next_v / beta2_correction