|
|
@@ -44,8 +44,10 @@ def set_flags(params):
|
|
|
# we set tf_xla_async_io_level=0 for 2 reasons: 1) It turns out that XLA doesn't like
|
|
|
# hvd.allreduce ops used in the custom train_step. Because of this issue, training never started.
|
|
|
# 2) XLA doesn't like the tf.cond used in conditional mixing (model module).
|
|
|
- os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1 --tf_xla_async_io_level=0"
|
|
|
|
|
|
+ # Drop --tf_xla_async_io_level: the flag is obsolete and no longer recognized by XLA.
|
|
|
+ #os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1 --tf_xla_async_io_level=0"
|
|
|
+ os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1"
|
|
|
os.environ['TF_EXTRA_PTXAS_OPTIONS'] = "-sw200428197=true"
|
|
|
tf.keras.backend.clear_session()
|
|
|
tf.config.optimizer.set_jit(True)
|