# distributed_train.py
  1. #!/usr/bin/env python3 -u
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the LICENSE file in
  6. # the root directory of this source tree. An additional grant of patent rights
  7. # can be found in the PATENTS file in the same directory.
  8. #
  9. #-------------------------------------------------------------------------
  10. #
  11. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. import os
  24. import socket
  25. import subprocess
  26. from train import main as single_process_main
  27. from fairseq import distributed_utils, options
  28. def main(args):
  29. if args.distributed_init_method is None and args.distributed_port > 0:
  30. # We can determine the init method automatically for Slurm.
  31. node_list = os.environ.get('SLURM_JOB_NODELIST')
  32. if node_list is not None:
  33. try:
  34. hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
  35. args.distributed_init_method = 'tcp://{host}:{port}'.format(
  36. host=hostnames.split()[0].decode('utf-8'),
  37. port=args.distributed_port)
  38. args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
  39. args.device_id = int(os.environ.get('SLURM_LOCALID'))
  40. except subprocess.CalledProcessError as e: # scontrol failed
  41. raise e
  42. except FileNotFoundError as e: # Slurm is not installed
  43. pass
  44. if args.distributed_init_method is None:
  45. raise ValueError('--distributed-init-method or --distributed-port '
  46. 'must be specified for distributed training')
  47. args.distributed_rank = distributed_utils.distributed_init(args)
  48. args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))
  49. print('| initialized host {} as rank {} and device id {}'.format(socket.gethostname(), args.distributed_rank, args.device_id))
  50. single_process_main(args)
  51. if __name__ == '__main__':
  52. parser = options.get_training_parser()
  53. args = options.parse_args_and_arch(parser)
  54. main(args)