Browse Source

Merge: [nnUNet/TF2] Update AMP API

Krzysztof Kudrynski 3 năm trước cách đây
mục cha
commit
9152d5fddb

+ 1 - 1
TensorFlow2/Segmentation/nnUNet/runtime/utils.py

@@ -58,7 +58,7 @@ def set_tf_flags(args):
     tf.config.threading.set_inter_op_parallelism_threads(max(2, (multiprocessing.cpu_count() // hvd.size()) - 2))
     tf.config.threading.set_inter_op_parallelism_threads(max(2, (multiprocessing.cpu_count() // hvd.size()) - 2))
 
 
     if args.amp:
     if args.amp:
-        tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
+        tf.keras.mixed_precision.set_global_policy("mixed_float16")
 
 
 
 
 def is_main_process():
 def is_main_process():