Explorar o código

[EfficientNet/TF2] remove tf async level flag

Mingyuan Ma %!s(int64=2) %!d(string=hai) anos
pai
achega
1e103522fe
Modificáronse 1 ficheiros con 3 adicións e 1 borrados
  1. 3 1
      TensorFlow2/Classification/ConvNets/utils/setup.py

+ 3 - 1
TensorFlow2/Classification/ConvNets/utils/setup.py

@@ -44,8 +44,10 @@ def set_flags(params):
         # we set tf_xla_async_io_level=0 for 2 reasons: 1) It turns out that XLA doesn't like 
         # hvd.allreduce ops used in the custom train_step. Because of this issue, training never started. 
         # 2) XLA doesn't like the tf.cond used in conditional mixing (model module).
-        os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1 --tf_xla_async_io_level=0"            
 
+        # remove async flag since it's obsolete
+        #os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1 --tf_xla_async_io_level=0"
+        os.environ['TF_XLA_FLAGS'] = TF_XLA_FLAGS + " --tf_xla_auto_jit=1"
         os.environ['TF_EXTRA_PTXAS_OPTIONS'] = "-sw200428197=true"
         tf.keras.backend.clear_session()
         tf.config.optimizer.set_jit(True)