utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import errno
  15. import os
  16. import time
  17. from collections import defaultdict, deque
  18. import dllogger
  19. import torch
  20. import torch.distributed as dist
  21. from dlrm.utils.distributed import is_dist_avail_and_initialized
  22. class SmoothedValue(object):
  23. """Track a series of values and provide access to smoothed values over a
  24. window or the global series average.
  25. """
  26. def __init__(self, window_size=20, fmt=None):
  27. if fmt is None:
  28. fmt = "{median:.4f} ({global_avg:.4f})"
  29. self.deque = deque(maxlen=window_size)
  30. self.total = 0.0
  31. self.count = 0
  32. self.fmt = fmt
  33. def update(self, value, n=1):
  34. self.deque.append(value)
  35. self.count += n
  36. self.total += value * n
  37. def synchronize_between_processes(self):
  38. """
  39. Warning: does not synchronize the deque!
  40. """
  41. if not is_dist_avail_and_initialized():
  42. return
  43. t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
  44. dist.barrier()
  45. dist.all_reduce(t)
  46. t = t.tolist()
  47. self.count = int(t[0])
  48. self.total = t[1]
  49. @property
  50. def median(self):
  51. d = torch.tensor(list(self.deque))
  52. return d.median().item()
  53. @property
  54. def avg(self):
  55. d = torch.tensor(list(self.deque), dtype=torch.float32)
  56. return d.mean().item()
  57. @property
  58. def global_avg(self):
  59. return self.total / self.count
  60. @property
  61. def max(self):
  62. return max(self.deque)
  63. @property
  64. def value(self):
  65. return self.deque[-1]
  66. def __str__(self):
  67. return self.fmt.format(
  68. median=self.median,
  69. avg=self.avg,
  70. global_avg=self.global_avg,
  71. max=self.max,
  72. value=self.value)
  73. class MetricLogger(object):
  74. def __init__(self, delimiter="\t"):
  75. self.meters = defaultdict(SmoothedValue)
  76. self.delimiter = delimiter
  77. def update(self, **kwargs):
  78. for k, v in kwargs.items():
  79. if isinstance(v, torch.Tensor):
  80. v = v.item()
  81. assert isinstance(v, (float, int))
  82. self.meters[k].update(v)
  83. def __getattr__(self, attr):
  84. if attr in self.meters:
  85. return self.meters[attr]
  86. if attr in self.__dict__:
  87. return self.__dict__[attr]
  88. raise AttributeError("'{}' object has no attribute '{}'".format(
  89. type(self).__name__, attr))
  90. def __str__(self):
  91. loss_str = []
  92. for name, meter in self.meters.items():
  93. loss_str.append(
  94. "{}: {}".format(name, str(meter))
  95. )
  96. return self.delimiter.join(loss_str)
  97. def synchronize_between_processes(self):
  98. for meter in self.meters.values():
  99. meter.synchronize_between_processes()
  100. def add_meter(self, name, meter):
  101. self.meters[name] = meter
  102. def print(self, header=None):
  103. if not header:
  104. header = ''
  105. print_str = header
  106. for name, meter in self.meters.items():
  107. print_str += f" {name}: {meter}"
  108. print(print_str)
  109. def accuracy(output, target, topk=(1,)):
  110. """Computes the accuracy over the k top predictions for the specified values of k"""
  111. with torch.no_grad():
  112. maxk = max(topk)
  113. batch_size = target.size(0)
  114. _, pred = output.topk(maxk, 1, True, True)
  115. pred = pred.t()
  116. correct = pred.eq(target[None])
  117. res = []
  118. for k in topk:
  119. correct_k = correct[:k].flatten().sum(dtype=torch.float32)
  120. res.append(correct_k * (100.0 / batch_size))
  121. return res
  122. def lr_step(optim, num_warmup_iter, current_step, base_lr, warmup_factor, decay_steps=0, decay_start_step=None):
  123. if decay_start_step is None:
  124. decay_start_step = num_warmup_iter
  125. new_lr = base_lr
  126. if decay_start_step < num_warmup_iter:
  127. raise ValueError('Learning rate warmup must finish before decay starts')
  128. if current_step <= num_warmup_iter:
  129. warmup_step = base_lr / (num_warmup_iter * (2 ** warmup_factor))
  130. new_lr = base_lr - (num_warmup_iter - current_step) * warmup_step
  131. steps_since_decay_start = current_step - decay_start_step
  132. if decay_steps != 0 and steps_since_decay_start > 0:
  133. already_decayed_steps = min(steps_since_decay_start, decay_steps)
  134. new_lr = base_lr * ((decay_steps - already_decayed_steps) / decay_steps) ** 2
  135. min_lr = 0.0000001
  136. new_lr = max(min_lr, new_lr)
  137. for param_group in optim.param_groups:
  138. param_group['lr'] = new_lr
  139. def mkdir(path):
  140. try:
  141. os.makedirs(path)
  142. except OSError as e:
  143. if e.errno != errno.EEXIST:
  144. raise
def init_logging(log_path):
    """Initialize dllogger with a JSON-file backend and a stdout backend.

    Args:
        log_path: file path the JSON backend writes its log stream to.
    """
    json_backend = dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
                                              filename=log_path)
    stdout_backend = dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
    # Attach per-metric display formats to the stdout backend.
    # NOTE(review): this reaches into dllogger's private `_metadata` dict —
    # verify it still exists when upgrading dllogger.
    stdout_backend._metadata['best_auc'].update({'format': '0:.5f'})
    stdout_backend._metadata['best_epoch'].update({'format': '0:.2f'})
    stdout_backend._metadata['average_train_throughput'].update({'format': ':.2e'})
    stdout_backend._metadata['average_test_throughput'].update({'format': ':.2e'})
    dllogger.init(backends=[json_backend, stdout_backend])
  154. class StepTimer():
  155. def __init__(self):
  156. self._previous = None
  157. self._new = None
  158. self.measured = None
  159. def click(self):
  160. self._previous = self._new
  161. self._new = time.time()
  162. if self._previous is not None:
  163. self.measured = self._new - self._previous
  164. class LearningRateScheduler:
  165. """Polynomial learning rate decay for multiple optimizers and multiple param groups
  166. Args:
  167. optimizers (list): optimizers for which to apply the learning rate changes
  168. base_lrs (list): a nested list of base_lrs to use for each param_group of each optimizer
  169. warmup_steps (int): number of linear warmup steps to perform at the beginning of training
  170. warmup_factor (int)
  171. decay_steps (int): number of steps over which to apply poly LR decay from base_lr to 0
  172. decay_start_step (int): the optimization step at which to start decaying the learning rate
  173. if None will start the decay immediately after
  174. decay_power (float): polynomial learning rate decay power
  175. end_lr_factor (float): for each optimizer and param group:
  176. lr = max(current_lr_factor, end_lr_factor) * base_lr
  177. Example:
  178. lr_scheduler = LearningRateScheduler(optimizers=[optimizer], base_lrs=[[lr]],
  179. warmup_steps=100, warmup_factor=0,
  180. decay_start_step=1000, decay_steps=2000,
  181. decay_power=2, end_lr_factor=1e-6)
  182. for batch in data_loader:
  183. lr_scheduler.step()
  184. # foward, backward, weight update
  185. """
  186. def __init__(self, optimizers, base_lrs, warmup_steps, warmup_factor,
  187. decay_steps, decay_start_step, decay_power=2, end_lr_factor=0):
  188. self.current_step = 0
  189. self.optimizers = optimizers
  190. self.base_lrs = base_lrs
  191. self.warmup_steps = warmup_steps
  192. self.warmup_factor = warmup_factor
  193. self.decay_steps = decay_steps
  194. self.decay_start_step = decay_start_step
  195. self.decay_power = decay_power
  196. self.end_lr_factor = end_lr_factor
  197. self.decay_end_step = self.decay_start_step + self.decay_steps
  198. if self.decay_start_step < self.warmup_steps:
  199. raise ValueError('Learning rate warmup must finish before decay starts')
  200. def _compute_lr_factor(self):
  201. lr_factor = 1
  202. if self.current_step <= self.warmup_steps:
  203. warmup_step = 1 / (self.warmup_steps * (2 ** self.warmup_factor))
  204. lr_factor = 1 - (self.warmup_steps - self.current_step) * warmup_step
  205. elif self.decay_start_step < self.current_step <= self.decay_end_step:
  206. lr_factor = ((self.decay_end_step - self.current_step) / self.decay_steps) ** self.decay_power
  207. lr_factor = max(lr_factor, self.end_lr_factor)
  208. elif self.current_step > self.decay_end_step:
  209. lr_factor = self.end_lr_factor
  210. return lr_factor
  211. def step(self):
  212. self.current_step += 1
  213. lr_factor = self._compute_lr_factor()
  214. for optim, base_lrs in zip(self.optimizers, self.base_lrs):
  215. for group_id, base_lr in enumerate(base_lrs):
  216. optim.param_groups[group_id]['lr'] = base_lr * lr_factor
  217. def roc_auc_score(y_true, y_score):
  218. """ROC AUC score in PyTorch
  219. Args:
  220. y_true (Tensor):
  221. y_score (Tensor):
  222. """
  223. device = y_true.device
  224. y_true.squeeze_()
  225. y_score.squeeze_()
  226. if y_true.shape != y_score.shape:
  227. raise TypeError(f"Shape of y_true and y_score must match. Got {y_true.shape()} and {y_score.shape()}.")
  228. desc_score_indices = torch.argsort(y_score, descending=True)
  229. y_score = y_score[desc_score_indices]
  230. y_true = y_true[desc_score_indices]
  231. distinct_value_indices = torch.nonzero(y_score[1:] - y_score[:-1], as_tuple=False).squeeze()
  232. threshold_idxs = torch.cat([distinct_value_indices, torch.tensor([y_true.numel() - 1], device=device)])
  233. tps = torch.cumsum(y_true, dim=0)[threshold_idxs]
  234. fps = 1 + threshold_idxs - tps
  235. tps = torch.cat([torch.zeros(1, device=device), tps])
  236. fps = torch.cat([torch.zeros(1, device=device), fps])
  237. fpr = fps / fps[-1]
  238. tpr = tps / tps[-1]
  239. area = torch.trapz(tpr, fpr).item()
  240. return area