meters.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import time
  8. import torch
  9. class AverageMeter(object):
  10. """Computes and stores the average and current value"""
  11. def __init__(self):
  12. self.reset()
  13. def reset(self):
  14. self.val = 0
  15. self.avg = 0
  16. self.sum = 0
  17. self.count = 0
  18. def update(self, val, n=1):
  19. self.val = val
  20. self.sum += val * n
  21. self.count += n
  22. self.avg = self.sum / self.count
  23. class TimeMeter(object):
  24. """Computes the average occurrence of some event per second"""
  25. def __init__(self, init=0):
  26. self.reset(init)
  27. def reset(self, init=0):
  28. self.init = init
  29. torch.cuda.synchronize()
  30. self.start = time.time()
  31. self.n = 0
  32. self.last_update = time.time()
  33. def update(self, val=1):
  34. self.n += val
  35. torch.cuda.synchronize()
  36. self.last_update = time.time()
  37. @property
  38. def avg(self):
  39. return self.n / self.elapsed_time
  40. @property
  41. def elapsed_time(self):
  42. torch.cuda.synchronize()
  43. return self.init + (time.time() - self.start)
  44. @property
  45. def u_avg(self):
  46. return self.n / (self.last_update - self.start)
  47. class StopwatchMeter(object):
  48. """Computes the sum/avg duration of some event in seconds"""
  49. def __init__(self):
  50. self.reset()
  51. self.intervals = []
  52. def start(self):
  53. torch.cuda.synchronize()
  54. self.start_time = time.time()
  55. def stop(self, n=1):
  56. torch.cuda.synchronize()
  57. if self.start_time is not None:
  58. delta = time.time() - self.start_time
  59. self.intervals.append(delta)
  60. self.sum += delta
  61. self.n += n
  62. self.start_time = None
  63. def reset(self):
  64. self.sum = 0
  65. self.n = 0
  66. self.start_time = None
  67. self.intervals = []
  68. @property
  69. def avg(self):
  70. return self.sum / self.n
  71. def p(self, i):
  72. assert i <= 100
  73. idx = int(len(self.intervals) * i / 100)
  74. return sorted(self.intervals)[idx]