Pārlūkot izejas kodu

[SIM/TF2] Fix concat bug from TensorFlow 2.11

Jakub Tomsia 3 gadi atpakaļ
vecāks
revīzija
17c268ff33
1 mainītis faili ar 4 papildinājumiem un 1 dzēšanām
  1. 4 1
      TensorFlow2/Recommendation/SIM/main.py

+ 4 - 1
TensorFlow2/Recommendation/SIM/main.py

@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
             local = tf.constant(local)
 
         # concat all local variables into a single tensor
-        local = tf.concat(local, 0)
+        if local is local_total_losses:
+            local = tf.stack(local, 0)
+        else:
+            local = tf.concat(local, 0)
 
         # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
         # reshape it for hvd.allgather to work