train_multi_gpu.sh

#!/usr/bin/env bash
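# Usage: bash train_multi_gpu.sh [BATCH_SIZE] [AMP] [NUM_EPOCHS] [LEARNING_RATE] [WEIGHT_DECAY]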
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-130}
LEARNING_RATE=${4:-0.01}
WEIGHT_DECAY=${5:-0.1}

# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
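
# Launch one training process per visible GPU on this node;
# --nproc_per_node=gpu makes torchrun detect the GPU count automatically.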
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
  se3_transformer.runtime.training \
  --amp "$AMP" \
  --batch_size "$BATCH_SIZE" \
  --epochs "$NUM_EPOCHS" \
  --lr "$LEARNING_RATE" \
  --min_lr 0.00001 \
  --weight_decay "$WEIGHT_DECAY" \
  --use_layer_norm \
  --norm \
  --save_ckpt_path model_qm9.pth \
  --precompute_bases \
  --seed 42 \
  --task "$TASK"
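
A sample invocation that overrides all five positional defaults (the values here are illustrative, not tuned hyperparameters):

bash train_multi_gpu.sh 120 true 100 0.002 0.1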