#!/usr/bin/env bash
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script launches UNet training in FP32-AMP (with --xla enabled) on
# 4 GPUs using a global batch size of 16 (4 per GPU).
#
# Usage:
#   ./UNet_AMP_4GPU.sh <path to result repository> <path to dataset> <dagm classID (1-10)>
#
# Arguments:
#   $1 - directory passed to main.py as --results_dir
#   $2 - directory passed to main.py as --data_dir
#   $3 - value passed to main.py as --dataset_classID
#
# Exit status:
#   2 when fewer than three arguments are supplied, 1 when the script
#   directory cannot be resolved, otherwise the exit status of mpirun.

set -euo pipefail

# Refuse to start a 4-process job with empty paths: without this check a
# missing argument would silently become --results_dir="" (etc.).
if (( $# < 3 )); then
  printf 'Usage: %s <path to result repository> <path to dataset> <dagm classID (1-10)>\n' \
    "${0##*/}" >&2
  exit 2
fi

# Absolute directory containing this script; main.py is expected one level up.
# cd output is discarded so only the result of pwd is captured.
BASEDIR="$(cd "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" || {
  printf 'error: cannot resolve the directory of %s\n' "${BASH_SOURCE[0]}" >&2
  exit 1
}

# NOTE(review): presumably limits TensorFlow native (C++) log output - confirm.
export TF_CPP_MIN_LOG_LEVEL=3

# Launcher options: 4 processes, all on localhost.
mpi_args=(
  -np 4
  -H localhost:4
  -bind-to none
  -map-by slot
  -x NCCL_DEBUG=VERSION
  -x LD_LIBRARY_PATH
  -x PATH
  -mca pml ob1 -mca btl ^openib
  --allow-run-as-root
)

# Training options forwarded verbatim to main.py.
train_args=(
  --unet_variant='tinyUNet'
  --activation_fn='relu'
  --exec_mode='train_and_evaluate'
  --iter_unit='batch'
  --num_iter=2500
  --batch_size=4
  --warmup_step=10
  --results_dir="${1}"
  --data_dir="${2}"
  --dataset_name='DAGM2007'
  --dataset_classID="${3}"
  --data_format='NCHW'
  --use_auto_loss_scaling
  --amp
  --xla
  --learning_rate=1e-4
  --learning_rate_decay_factor=0.8
  --learning_rate_decay_steps=500
  --rmsprop_decay=0.9
  --rmsprop_momentum=0.8
  --loss_fn_name='adaptive_loss'
  --weight_decay=1e-5
  --weight_init_method='he_uniform'
  --augment_data
  --display_every=250
  --debug_verbosity=0
)

# Last command of the script, so its exit status becomes the script's.
mpirun "${mpi_args[@]}" \
  python "${BASEDIR}/../main.py" "${train_args[@]}"