Преглед изворни кода

Merge: [SIM/TF2] Fix concat bug from TensorFlow 2.11

Krzysztof Kudrynski пре 3 година
родитељ
комит
dff49355d6
1 измењених фајлова са 4 додато и 1 уклоњено
  1. 4 1
      TensorFlow2/Recommendation/SIM/main.py

+ 4 - 1
TensorFlow2/Recommendation/SIM/main.py

@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
             local = tf.constant(local)
 
         # concat all local variables into a single tensor
-        local = tf.concat(local, 0)
+        if local is local_total_losses:
+            local = tf.stack(local, 0)
+        else:
+            local = tf.concat(local, 0)
 
         # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
         # reshape it for hvd.allgather to work