# distributed.py
  1. # BSD 3-Clause License
  2. # Copyright (c) 2018-2020, NVIDIA Corporation
  3. # All rights reserved.
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright notice, this
  7. # list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright notice,
  9. # this list of conditions and the following disclaimer in the documentation
  10. # and/or other materials provided with the distribution.
  11. # * Neither the name of the copyright holder nor the names of its
  12. # contributors may be used to endorse or promote products derived from
  13. # this software without specific prior written permission.
  14. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  15. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  16. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  17. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  18. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  19. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  20. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  21. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  22. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  23. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  24. """https://github.com/NVIDIA/tacotron2"""
  25. import torch
  26. import torch.distributed as dist
  27. from torch.nn.modules import Module
  28. from torch.autograd import Variable
  29. def _flatten_dense_tensors(tensors):
  30. """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
  31. same dense type.
  32. Since inputs are dense, the resulting tensor will be a concatenated 1D
  33. buffer. Element-wise operation on this buffer will be equivalent to
  34. operating individually.
  35. Arguments:
  36. tensors (Iterable[Tensor]): dense tensors to flatten.
  37. Returns:
  38. A contiguous 1D buffer containing input tensors.
  39. """
  40. if len(tensors) == 1:
  41. return tensors[0].contiguous().view(-1)
  42. flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
  43. return flat
  44. def _unflatten_dense_tensors(flat, tensors):
  45. """View a flat buffer using the sizes of tensors. Assume that tensors are of
  46. same dense type, and that flat is given by _flatten_dense_tensors.
  47. Arguments:
  48. flat (Tensor): flattened dense tensors to unflatten.
  49. tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
  50. unflatten flat.
  51. Returns:
  52. Unflattened dense tensors with sizes same as tensors and values from
  53. flat.
  54. """
  55. outputs = []
  56. offset = 0
  57. for tensor in tensors:
  58. numel = tensor.numel()
  59. outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
  60. offset += numel
  61. return tuple(outputs)
'''
This version of DistributedDataParallel is designed to be used in conjunction with
the multiproc.py launcher included with this example. It assumes that your run uses
multiprocessing with 1 GPU/process, that the model is on the correct device, and that
torch.cuda.set_device has been used to set the device.

Parameters are broadcast to the other processes on initialization of
DistributedDataParallel, and will be allreduced at the finish of the backward pass.
'''
  70. class DistributedDataParallel(Module):
  71. def __init__(self, module):
  72. super(DistributedDataParallel, self).__init__()
  73. #fallback for PyTorch 0.3
  74. if not hasattr(dist, '_backend'):
  75. self.warn_on_half = True
  76. else:
  77. self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  78. self.module = module
  79. for p in self.module.state_dict().values():
  80. if not torch.is_tensor(p):
  81. continue
  82. dist.broadcast(p, 0)
  83. def allreduce_params():
  84. if(self.needs_reduction):
  85. self.needs_reduction = False
  86. buckets = {}
  87. for param in self.module.parameters():
  88. if param.requires_grad and param.grad is not None:
  89. tp = type(param.data)
  90. if tp not in buckets:
  91. buckets[tp] = []
  92. buckets[tp].append(param)
  93. if self.warn_on_half:
  94. if torch.cuda.HalfTensor in buckets:
  95. print("WARNING: gloo dist backend for half parameters may be extremely slow." +
  96. " It is recommended to use the NCCL backend in this case. This currently requires" +
  97. "PyTorch built from top of tree master.")
  98. self.warn_on_half = False
  99. for tp in buckets:
  100. bucket = buckets[tp]
  101. grads = [param.grad.data for param in bucket]
  102. coalesced = _flatten_dense_tensors(grads)
  103. dist.all_reduce(coalesced)
  104. coalesced /= dist.get_world_size()
  105. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
  106. buf.copy_(synced)
  107. for param in list(self.module.parameters()):
  108. def allreduce_hook(*unused):
  109. param._execution_engine.queue_callback(allreduce_params)
  110. if param.requires_grad:
  111. param.register_hook(allreduce_hook)
  112. def forward(self, *inputs, **kwargs):
  113. self.needs_reduction = True
  114. return self.module(*inputs, **kwargs)
  115. '''
  116. def _sync_buffers(self):
  117. buffers = list(self.module._all_buffers())
  118. if len(buffers) > 0:
  119. # cross-node buffer sync
  120. flat_buffers = _flatten_dense_tensors(buffers)
  121. dist.broadcast(flat_buffers, 0)
  122. for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
  123. buf.copy_(synced)
  124. def train(self, mode=True):
  125. # Clear NCCL communicator and CUDA event cache of the default group ID,
  126. # These cache will be recreated at the later call. This is currently a
  127. # work-around for a potential NCCL deadlock.
  128. if dist._backend == dist.dist_backend.NCCL:
  129. dist._clear_group_cache()
  130. super(DistributedDataParallel, self).train(mode)
  131. self.module.train(mode)
  132. '''
'''
Modifies an existing model to do gradient allreduce without changing its class,
so you don't need to access the wrapped model through a "module" attribute.
'''
  137. def apply_gradient_allreduce(module):
  138. if not hasattr(dist, '_backend'):
  139. module.warn_on_half = True
  140. else:
  141. module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  142. for p in module.state_dict().values():
  143. if not torch.is_tensor(p):
  144. continue
  145. dist.broadcast(p, 0)
  146. def allreduce_params():
  147. if(module.needs_reduction):
  148. module.needs_reduction = False
  149. buckets = {}
  150. for param in module.parameters():
  151. if param.requires_grad and param.grad is not None:
  152. tp = param.data.dtype
  153. if tp not in buckets:
  154. buckets[tp] = []
  155. buckets[tp].append(param)
  156. if module.warn_on_half:
  157. if torch.cuda.HalfTensor in buckets:
  158. print("WARNING: gloo dist backend for half parameters may be extremely slow." +
  159. " It is recommended to use the NCCL backend in this case. This currently requires" +
  160. "PyTorch built from top of tree master.")
  161. module.warn_on_half = False
  162. for tp in buckets:
  163. bucket = buckets[tp]
  164. grads = [param.grad.data for param in bucket]
  165. coalesced = _flatten_dense_tensors(grads)
  166. dist.all_reduce(coalesced)
  167. coalesced /= dist.get_world_size()
  168. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
  169. buf.copy_(synced)
  170. for param in list(module.parameters()):
  171. def allreduce_hook(*unused):
  172. Variable._execution_engine.queue_callback(allreduce_params)
  173. if param.requires_grad:
  174. param.register_hook(allreduce_hook)
  175. def set_needs_reduction(self, input, output):
  176. self.needs_reduction = True
  177. module.register_forward_hook(set_needs_reduction)
  178. return module