Browse Source

Merge: [SIM/TF2] Fix concat bug from TensorFlow 2.11

Krzysztof Kudrynski 3 năm trước cách đây
mục cha
commit
dff49355d6
1 tập tin đã thay đổi với 4 bổ sung1 xóa
  1. 4 1
      TensorFlow2/Recommendation/SIM/main.py

+ 4 - 1
TensorFlow2/Recommendation/SIM/main.py

@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
             local = tf.constant(local)
 
         # concat all local variables into a single tensor
-        local = tf.concat(local, 0)
+        if local is local_total_losses:
+            local = tf.stack(local, 0)
+        else:
+            local = tf.concat(local, 0)
 
         # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
         # reshape it for hvd.allgather to work