# sampler.py
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, TypeVar

import numpy as np
import torch
from torch.utils.data import (RandomSampler, Sampler,
                              DistributedSampler as TorchDistributedSampler)

from common.fairseq.data import data_utils

T = TypeVar('T')

  21. class DistributedSampler(Sampler):
  22. def __init__(self, dataset, batch_size, world_size, rank):
  23. """
  24. Constructor for the DistributedSampler.
  25. :param dataset: dataset
  26. :param batch_size: local batch size
  27. :param world_size: number of distributed workers
  28. :param rank: rank of the current process
  29. """
  30. self.dataset = dataset
  31. self.world_size = world_size
  32. self.rank = rank
  33. self.epoch = 0
  34. self.batch_size = batch_size
  35. self.global_batch_size = batch_size * world_size
  36. self.data_len = len(self.dataset)
  37. self.num_samples = self.data_len // self.global_batch_size \
  38. * self.global_batch_size
  39. def distribute_batches(self, indices):
  40. """
  41. Assigns batches to workers.
  42. Consecutive ranks are getting consecutive batches.
  43. :param indices: torch.tensor with batch indices
  44. """
  45. assert len(indices) == self.num_samples
  46. indices = indices.view(-1, self.batch_size)
  47. indices = indices[self.rank::self.world_size].contiguous()
  48. indices = indices.view(-1)
  49. indices = indices.tolist()
  50. assert len(indices) == self.num_samples // self.world_size
  51. return indices
  52. def reshuffle_batches(self, indices, rng):
  53. """
  54. Permutes global batches
  55. :param indices: torch.tensor with batch indices
  56. :param rng: instance of torch.Generator
  57. """
  58. indices = indices.view(-1, self.global_batch_size)
  59. num_batches = indices.shape[0]
  60. order = torch.randperm(num_batches, generator=rng)
  61. indices = indices[order, :]
  62. indices = indices.view(-1)
  63. return indices
  64. def __iter__(self):
  65. g = torch.Generator()
  66. g.manual_seed(self.epoch)
  67. # generate permutation
  68. indices = torch.randperm(self.data_len, generator=g)
  69. # make indices evenly divisible by (batch_size * world_size)
  70. indices = indices[:self.num_samples]
  71. # assign batches to workers
  72. indices = self.distribute_batches(indices)
  73. return iter(indices)
  74. def set_epoch(self, epoch):
  75. """
  76. Sets current epoch index.
  77. Epoch index is used to seed RNG in __iter__() function.
  78. :param epoch: index of current epoch
  79. """
  80. self.epoch = epoch
  81. def __len__(self):
  82. return self.num_samples // self.world_size
  83. class BucketingSampler(DistributedSampler):
  84. def __init__(self, dataset, batch_size, num_buckets, world_size, rank):
  85. """
  86. Bucketing sampler with approx. equally-sized buckets.
  87. :param dataset: dataset
  88. :param batch_size: local batch size
  89. :param seeds: list of seeds, one seed for each training epoch
  90. :param num_buckets: number of buckets
  91. :param world_size: number of distributed workers
  92. :param rank: rank of the current process
  93. """
  94. super().__init__(dataset, batch_size, world_size, rank)
  95. self.num_buckets = num_buckets
  96. len_ids = np.argsort([sample['duration']
  97. for sample in dataset.samples])
  98. self.buckets = [torch.from_numpy(t)
  99. for t in np.array_split(len_ids, num_buckets)]
  100. def __iter__(self):
  101. g = torch.Generator()
  102. g.manual_seed(self.epoch)
  103. global_bsz = self.global_batch_size
  104. indices = []
  105. for bid in range(self.num_buckets):
  106. # random shuffle within current bucket
  107. perm = torch.randperm(len(self.buckets[bid]), generator=g)
  108. bucket_indices = self.buckets[bid][perm]
  109. # add samples from current bucket to indices for current epoch
  110. indices.append(bucket_indices)
  111. indices = torch.cat(indices)
  112. # make indices evenly divisible by global batch size
  113. length = len(indices) // global_bsz * global_bsz
  114. indices = indices[:length]
  115. assert len(indices) % self.global_batch_size == 0
  116. # perform global reshuffle of all global batches
  117. indices = self.reshuffle_batches(indices, g)
  118. # distribute batches to individual workers
  119. indices = self.distribute_batches(indices)
  120. return iter(indices)
  121. class DistributedIndicesSampler(TorchDistributedSampler):
  122. """ DistributedSampler operating on indices.
  123. Differences wrt. DistributedSampler:
  124. 1) use Numpy RNG instead of PyTorch RNG
  125. 2) treat `self.dataset` as indices - DistributedSampler assumes indices
  126. are determined with `range(len(self.dataset))`
  127. 3) if `drop_last` is False, pad indices with `fillvalue`
  128. or don't pad at all if `fillvalue` is None (useful for validation)
  129. """
  130. def __init__(self, *args, fillvalue=None, **kwargs):
  131. super().__init__(*args, **kwargs)
  132. self.fillvalue = fillvalue
  133. if not self.drop_last and self.fillvalue is None:
  134. self.total_size = len(self.dataset)
  135. # possibly different num_samples for each device,
  136. # this will work with DDP only for validation
  137. self.num_samples = len(range(self.rank, self.total_size,
  138. self.num_replicas))
  139. def __iter__(self):
  140. indices = list(self.dataset)
  141. if self.shuffle:
  142. # deterministically shuffle based on epoch and seed
  143. with data_utils.numpy_seed(self.seed + self.epoch):
  144. np.random.shuffle(indices)
  145. if not self.drop_last:
  146. if self.fillvalue is not None:
  147. # add extra samples to make it evenly divisible
  148. padding_size = self.total_size - len(indices)
  149. indices += [self.fillvalue] * padding_size
  150. else:
  151. # remove tail of data to make it evenly divisible.
  152. indices = indices[:self.total_size]
  153. assert len(indices) == self.total_size
  154. # subsample
  155. indices = indices[self.rank:self.total_size:self.num_replicas]
  156. assert len(indices) == self.num_samples
  157. return iter(indices)
  158. class RandomSeedableSampler(RandomSampler):
  159. def __init__(self, *args, generator=None, seed=0, **kwargs):
  160. if generator is None:
  161. generator = torch.Generator()
  162. if seed is not None:
  163. generator.manual_seed(seed)
  164. super().__init__(*args, generator=generator, **kwargs)
  165. self.epoch = 0
  166. self.seed = seed
  167. def __iter__(self):
  168. self.generator.manual_seed(self.seed + self.epoch)
  169. return super().__iter__()
  170. def set_epoch(self, epoch: int) -> None:
  171. """ Allows reproducibility after resuming training. """
  172. self.epoch = epoch
  173. class IndexMappingSampler(Sampler[T]):
  174. """ Transforms index-based sampler to arbitrary one, e.g. batch-based. """
  175. def __init__(self, indices_map: List[T], base_sampler: Sampler[int]):
  176. super().__init__(indices_map)
  177. self.base_sampler = base_sampler
  178. self.indices_map = indices_map
  179. assert len(self.base_sampler) <= len(indices_map)
  180. def __iter__(self):
  181. return map(lambda ind: self.indices_map[ind], iter(self.base_sampler))
  182. def __len__(self):
  183. return len(self.base_sampler)
  184. def set_epoch(self, epoch: int) -> None:
  185. """ Allows reproducibility after resuming training. """
  186. self.base_sampler.set_epoch(epoch)