|
@@ -253,7 +253,10 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
|
|
|
local = tf.constant(local)
|
|
local = tf.constant(local)
|
|
|
|
|
|
|
|
# concat all local variables into a single tensor
|
|
# concat all local variables into a single tensor
|
|
|
- local = tf.concat(local, 0)
|
|
|
|
|
|
|
+ if local is local_total_losses:
|
|
|
|
|
+ local = tf.stack(local, 0)
|
|
|
|
|
+ else:
|
|
|
|
|
+ local = tf.concat(local, 0)
|
|
|
|
|
|
|
|
# for single element lists, tf.concat will produce shape=() instead of shape=(1,).
|
|
# for single element lists, tf.concat will produce shape=() instead of shape=(1,).
|
|
|
# reshape it for hvd.allgather to work
|
|
# reshape it for hvd.allgather to work
|