瀏覽代碼

Merge: [SIM/TF2] Fix concat bug from TensorFlow 2.11

Krzysztof Kudrynski 3 年之前
父節點
當前提交
dff49355d6
共有 1 個文件被更改,包括 4 次插入1 次删除
  1. 4 1
      TensorFlow2/Recommendation/SIM/main.py

+ 4 - 1
TensorFlow2/Recommendation/SIM/main.py

@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
             local = tf.constant(local)
 
         # concat all local variables into a single tensor
-        local = tf.concat(local, 0)
+        if local is local_total_losses:
+            local = tf.stack(local, 0)
+        else:
+            local = tf.concat(local, 0)
 
         # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
         # reshape it for hvd.allgather to work