# utils.py
  1. # Copyright (c) 2017 Elad Hoffer
  2. # Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Permission is hereby granted, free of charge, to any person obtaining a copy
  5. # of this software and associated documentation files (the "Software"), to deal
  6. # in the Software without restriction, including without limitation the rights
  7. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  8. # copies of the Software, and to permit persons to whom the Software is
  9. # furnished to do so, subject to the following conditions:
  10. #
  11. # The above copyright notice and this permission notice shall be included in all
  12. # copies or substantial portions of the Software.
  13. #
  14. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  15. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  16. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  17. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  18. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  19. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  20. # SOFTWARE.
  21. import logging.config
  22. import os
  23. import random
  24. import sys
  25. import time
  26. from contextlib import contextmanager
  27. import dllogger
  28. import numpy as np
  29. import torch
  30. import torch.distributed as dist
  31. import torch.nn.init as init
  32. import torch.utils.collect_env
  33. def init_lstm_(lstm, init_weight=0.1):
  34. """
  35. Initializes weights of LSTM layer.
  36. Weights and biases are initialized with uniform(-init_weight, init_weight)
  37. distribution.
  38. :param lstm: instance of torch.nn.LSTM
  39. :param init_weight: range for the uniform initializer
  40. """
  41. # Initialize hidden-hidden weights
  42. init.uniform_(lstm.weight_hh_l0.data, -init_weight, init_weight)
  43. # Initialize input-hidden weights:
  44. init.uniform_(lstm.weight_ih_l0.data, -init_weight, init_weight)
  45. # Initialize bias. PyTorch LSTM has two biases, one for input-hidden GEMM
  46. # and the other for hidden-hidden GEMM. Here input-hidden bias is
  47. # initialized with uniform distribution and hidden-hidden bias is
  48. # initialized with zeros.
  49. init.uniform_(lstm.bias_ih_l0.data, -init_weight, init_weight)
  50. init.zeros_(lstm.bias_hh_l0.data)
  51. if lstm.bidirectional:
  52. init.uniform_(lstm.weight_hh_l0_reverse.data, -init_weight, init_weight)
  53. init.uniform_(lstm.weight_ih_l0_reverse.data, -init_weight, init_weight)
  54. init.uniform_(lstm.bias_ih_l0_reverse.data, -init_weight, init_weight)
  55. init.zeros_(lstm.bias_hh_l0_reverse.data)
  56. def generate_seeds(rng, size):
  57. """
  58. Generate list of random seeds
  59. :param rng: random number generator
  60. :param size: length of the returned list
  61. """
  62. seeds = [rng.randint(0, 2**32 - 1) for _ in range(size)]
  63. return seeds
  64. def broadcast_seeds(seeds, device):
  65. """
  66. Broadcasts random seeds to all distributed workers.
  67. Returns list of random seeds (broadcasted from workers with rank 0).
  68. :param seeds: list of seeds (integers)
  69. :param device: torch.device
  70. """
  71. if torch.distributed.is_available() and torch.distributed.is_initialized():
  72. seeds_tensor = torch.tensor(seeds, dtype=torch.int64, device=device)
  73. torch.distributed.broadcast(seeds_tensor, 0)
  74. seeds = seeds_tensor.tolist()
  75. return seeds
  76. def setup_seeds(master_seed, epochs, device):
  77. """
  78. Generates seeds from one master_seed.
  79. Function returns (worker_seeds, shuffling_seeds), worker_seeds are later
  80. used to initialize per-worker random number generators (mostly for
  81. dropouts), shuffling_seeds are for RNGs resposible for reshuffling the
  82. dataset before each epoch.
  83. Seeds are generated on worker with rank 0 and broadcasted to all other
  84. workers.
  85. :param master_seed: master RNG seed used to initialize other generators
  86. :param epochs: number of epochs
  87. :param device: torch.device (used for distributed.broadcast)
  88. """
  89. if master_seed is None:
  90. # random master seed, random.SystemRandom() uses /dev/urandom on Unix
  91. master_seed = random.SystemRandom().randint(0, 2**32 - 1)
  92. if get_rank() == 0:
  93. # master seed is reported only from rank=0 worker, it's to avoid
  94. # confusion, seeds from rank=0 are later broadcasted to other
  95. # workers
  96. logging.info(f'Using random master seed: {master_seed}')
  97. else:
  98. # master seed was specified from command line
  99. logging.info(f'Using master seed from command line: {master_seed}')
  100. # initialize seeding RNG
  101. seeding_rng = random.Random(master_seed)
  102. # generate worker seeds, one seed for every distributed worker
  103. worker_seeds = generate_seeds(seeding_rng, get_world_size())
  104. # generate seeds for data shuffling, one seed for every epoch
  105. shuffling_seeds = generate_seeds(seeding_rng, epochs)
  106. # broadcast seeds from rank=0 to other workers
  107. worker_seeds = broadcast_seeds(worker_seeds, device)
  108. shuffling_seeds = broadcast_seeds(shuffling_seeds, device)
  109. return worker_seeds, shuffling_seeds
  110. def barrier():
  111. """
  112. Call torch.distributed.barrier() if distritubed is in use
  113. """
  114. if torch.distributed.is_available() and torch.distributed.is_initialized():
  115. torch.distributed.barrier()
  116. def get_rank():
  117. """
  118. Gets distributed rank or returns zero if distributed is not initialized.
  119. """
  120. if torch.distributed.is_available() and torch.distributed.is_initialized():
  121. rank = torch.distributed.get_rank()
  122. else:
  123. rank = 0
  124. return rank
  125. def get_world_size():
  126. """
  127. Gets total number of distributed workers or returns one if distributed is
  128. not initialized.
  129. """
  130. if torch.distributed.is_available() and torch.distributed.is_initialized():
  131. world_size = torch.distributed.get_world_size()
  132. else:
  133. world_size = 1
  134. return world_size
  135. @contextmanager
  136. def sync_workers():
  137. """
  138. Yields distributed rank and synchronizes all workers on exit.
  139. """
  140. rank = get_rank()
  141. yield rank
  142. barrier()
  143. @contextmanager
  144. def timer(name, ndigits=2, sync_gpu=True):
  145. if sync_gpu:
  146. torch.cuda.synchronize()
  147. start = time.time()
  148. yield
  149. if sync_gpu:
  150. torch.cuda.synchronize()
  151. stop = time.time()
  152. elapsed = round(stop - start, ndigits)
  153. logging.info(f'TIMER {name} {elapsed}')
  154. def setup_logging(log_all_ranks=True, log_file=os.devnull):
  155. """
  156. Configures logging.
  157. By default logs from all workers are printed to the console, entries are
  158. prefixed with "N: " where N is the rank of the worker. Logs printed to the
  159. console don't include timestaps.
  160. Full logs with timestamps are saved to the log_file file.
  161. """
  162. class RankFilter(logging.Filter):
  163. def __init__(self, rank, log_all_ranks):
  164. self.rank = rank
  165. self.log_all_ranks = log_all_ranks
  166. def filter(self, record):
  167. record.rank = self.rank
  168. if self.log_all_ranks:
  169. return True
  170. else:
  171. return (self.rank == 0)
  172. rank = get_rank()
  173. rank_filter = RankFilter(rank, log_all_ranks)
  174. for handler in logging.root.handlers[:]:
  175. logging.root.removeHandler(handler)
  176. handler.close()
  177. logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s"
  178. logging.basicConfig(level=logging.DEBUG,
  179. format=logging_format,
  180. datefmt="%Y-%m-%d %H:%M:%S",
  181. filename=log_file,
  182. filemode='w')
  183. console = logging.StreamHandler(sys.stdout)
  184. console.setLevel(logging.INFO)
  185. formatter = logging.Formatter('%(rank)s: %(message)s')
  186. console.setFormatter(formatter)
  187. logging.getLogger('').addHandler(console)
  188. logging.getLogger('').addFilter(rank_filter)
  189. def setup_dllogger(enabled=True, filename=os.devnull):
  190. rank = get_rank()
  191. if enabled and rank == 0:
  192. backends = [
  193. dllogger.JSONStreamBackend(
  194. dllogger.Verbosity.VERBOSE,
  195. filename,
  196. ),
  197. ]
  198. dllogger.init(backends)
  199. else:
  200. dllogger.init([])
  201. dllogger.metadata("test_bleu", {"unit": None})
  202. dllogger.metadata("eval_90%_latency", {"unit": "ms"})
  203. dllogger.metadata("eval_avg_latency", {"unit": "ms"})
  204. dllogger.metadata("train_elapsed", {"unit": "s"})
  205. dllogger.metadata("eval_throughput", {"unit": "tokens/s"})
  206. dllogger.metadata("train_throughput", {"unit": "tokens/s"})
  207. def set_device(cuda, local_rank):
  208. """
  209. Sets device based on local_rank and returns instance of torch.device.
  210. :param cuda: if True: use cuda
  211. :param local_rank: local rank of the worker
  212. """
  213. if cuda:
  214. torch.cuda.set_device(local_rank)
  215. device = torch.device('cuda')
  216. else:
  217. device = torch.device('cpu')
  218. return device
  219. def init_distributed(cuda):
  220. """
  221. Initializes distributed backend.
  222. :param cuda: (bool) if True initializes nccl backend, if False initializes
  223. gloo backend
  224. """
  225. world_size = int(os.environ.get('WORLD_SIZE', 1))
  226. distributed = (world_size > 1)
  227. if distributed:
  228. backend = 'nccl' if cuda else 'gloo'
  229. dist.init_process_group(backend=backend,
  230. init_method='env://')
  231. assert dist.is_initialized()
  232. return distributed
  233. def log_env_info():
  234. """
  235. Prints information about execution environment.
  236. """
  237. logging.info('Collecting environment information...')
  238. env_info = torch.utils.collect_env.get_pretty_env_info()
  239. logging.info(f'{env_info}')
  240. def pad_vocabulary(math):
  241. if math == 'tf32' or math == 'fp16' or math == 'manual_fp16':
  242. pad_vocab = 8
  243. elif math == 'fp32':
  244. pad_vocab = 1
  245. return pad_vocab
  246. def benchmark(test_acc, target_acc, test_perf, target_perf):
  247. def test(achieved, target, name):
  248. passed = True
  249. if target is not None and achieved is not None:
  250. logging.info(f'{name} achieved: {achieved:.2f} '
  251. f'target: {target:.2f}')
  252. if achieved >= target:
  253. logging.info(f'{name} test passed')
  254. else:
  255. logging.info(f'{name} test failed')
  256. passed = False
  257. return passed
  258. passed = True
  259. passed &= test(test_acc, target_acc, 'Accuracy')
  260. passed &= test(test_perf, target_perf, 'Performance')
  261. return passed
  262. def debug_tensor(tensor, name):
  263. """
  264. Simple utility which helps with debugging.
  265. Takes a tensor and outputs: min, max, avg, std, number of NaNs, number of
  266. INFs.
  267. :param tensor: torch tensor
  268. :param name: name of the tensor (only for logging)
  269. """
  270. logging.info(name)
  271. tensor = tensor.detach().float().cpu().numpy()
  272. logging.info(f'MIN: {tensor.min()} MAX: {tensor.max()} '
  273. f'AVG: {tensor.mean()} STD: {tensor.std()} '
  274. f'NAN: {np.isnan(tensor).sum()} INF: {np.isinf(tensor).sum()}')
  275. class AverageMeter:
  276. """
  277. Computes and stores the average and current value
  278. """
  279. def __init__(self, warmup=0, keep=False):
  280. self.reset()
  281. self.warmup = warmup
  282. self.keep = keep
  283. def reset(self):
  284. self.val = 0
  285. self.avg = 0
  286. self.sum = 0
  287. self.count = 0
  288. self.iters = 0
  289. self.vals = []
  290. def update(self, val, n=1):
  291. self.iters += 1
  292. self.val = val
  293. if self.iters > self.warmup:
  294. self.sum += val * n
  295. self.count += n
  296. self.avg = self.sum / self.count
  297. if self.keep:
  298. self.vals.append(val)
  299. def reduce(self, op):
  300. """
  301. Reduces average value over all workers.
  302. :param op: 'sum' or 'mean', reduction operator
  303. """
  304. if op not in ('sum', 'mean'):
  305. raise NotImplementedError
  306. distributed = (get_world_size() > 1)
  307. if distributed:
  308. backend = dist.get_backend()
  309. cuda = (backend == dist.Backend.NCCL)
  310. if cuda:
  311. avg = torch.cuda.FloatTensor([self.avg])
  312. _sum = torch.cuda.FloatTensor([self.sum])
  313. else:
  314. avg = torch.FloatTensor([self.avg])
  315. _sum = torch.FloatTensor([self.sum])
  316. dist.all_reduce(avg)
  317. dist.all_reduce(_sum)
  318. self.avg = avg.item()
  319. self.sum = _sum.item()
  320. if op == 'mean':
  321. self.avg /= get_world_size()
  322. self.sum /= get_world_size()