logger.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # Copyright (c) 2018-2019, NVIDIA CORPORATION
  2. # Copyright (c) 2017- Facebook, Inc
  3. #
  4. # All rights reserved.
  5. #
  6. # Redistribution and use in source and binary forms, with or without
  7. # modification, are permitted provided that the following conditions are met:
  8. #
  9. # * Redistributions of source code must retain the above copyright notice, this
  10. # list of conditions and the following disclaimer.
  11. #
  12. # * Redistributions in binary form must reproduce the above copyright notice,
  13. # this list of conditions and the following disclaimer in the documentation
  14. # and/or other materials provided with the distribution.
  15. #
  16. # * Neither the name of the copyright holder nor the names of its
  17. # contributors may be used to endorse or promote products derived from
  18. # this software without specific prior written permission.
  19. #
  20. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  23. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  24. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  25. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  26. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  27. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  28. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  29. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  30. from collections import OrderedDict
  31. import dllogger
  32. import numpy as np
  33. def format_step(step):
  34. if isinstance(step, str):
  35. return step
  36. s = ""
  37. if len(step) > 0:
  38. s += "Epoch: {} ".format(step[0])
  39. if len(step) > 1:
  40. s += "Iteration: {} ".format(step[1])
  41. if len(step) > 2:
  42. s += "Validation Iteration: {} ".format(step[2])
  43. if len(step) == 0:
  44. s = "Summary:"
  45. return s
  46. PERF_METER = lambda: Meter(AverageMeter(), AverageMeter(), AverageMeter())
  47. LOSS_METER = lambda: Meter(AverageMeter(), AverageMeter(), MinMeter())
  48. ACC_METER = lambda: Meter(AverageMeter(), AverageMeter(), MaxMeter())
  49. LR_METER = lambda: Meter(LastMeter(), LastMeter(), LastMeter())
  50. LAT_100 = lambda: Meter(QuantileMeter(1), QuantileMeter(1), QuantileMeter(1))
  51. LAT_99 = lambda: Meter(QuantileMeter(0.99), QuantileMeter(0.99), QuantileMeter(0.99))
  52. LAT_95 = lambda: Meter(QuantileMeter(0.95), QuantileMeter(0.95), QuantileMeter(0.95))
  53. class Meter(object):
  54. def __init__(self, iteration_aggregator, epoch_aggregator, run_aggregator):
  55. self.run_aggregator = run_aggregator
  56. self.epoch_aggregator = epoch_aggregator
  57. self.iteration_aggregator = iteration_aggregator
  58. def record(self, val, n=1):
  59. self.iteration_aggregator.record(val, n=n)
  60. def get_iteration(self):
  61. v, n = self.iteration_aggregator.get_val()
  62. return v
  63. def reset_iteration(self):
  64. v, n = self.iteration_aggregator.get_data()
  65. self.iteration_aggregator.reset()
  66. if v is not None:
  67. self.epoch_aggregator.record(v, n=n)
  68. def get_epoch(self):
  69. v, n = self.epoch_aggregator.get_val()
  70. return v
  71. def reset_epoch(self):
  72. v, n = self.epoch_aggregator.get_data()
  73. self.epoch_aggregator.reset()
  74. if v is not None:
  75. self.run_aggregator.record(v, n=n)
  76. def get_run(self):
  77. v, n = self.run_aggregator.get_val()
  78. return v
  79. def reset_run(self):
  80. self.run_aggregator.reset()
  81. class QuantileMeter(object):
  82. def __init__(self, q):
  83. self.q = q
  84. self.reset()
  85. def reset(self):
  86. self.vals = []
  87. self.n = 0
  88. def record(self, val, n=1):
  89. if isinstance(val, list):
  90. self.vals += val
  91. self.n += len(val)
  92. else:
  93. self.vals += [val] * n
  94. self.n += n
  95. def get_val(self):
  96. if not self.vals:
  97. return None, self.n
  98. return np.quantile(self.vals, self.q, interpolation="nearest"), self.n
  99. def get_data(self):
  100. return self.vals, self.n
  101. class MaxMeter(object):
  102. def __init__(self):
  103. self.reset()
  104. def reset(self):
  105. self.max = None
  106. self.n = 0
  107. def record(self, val, n=1):
  108. if self.max is None:
  109. self.max = val
  110. else:
  111. self.max = max(self.max, val)
  112. self.n = n
  113. def get_val(self):
  114. return self.max, self.n
  115. def get_data(self):
  116. return self.max, self.n
  117. class MinMeter(object):
  118. def __init__(self):
  119. self.reset()
  120. def reset(self):
  121. self.min = None
  122. self.n = 0
  123. def record(self, val, n=1):
  124. if self.min is None:
  125. self.min = val
  126. else:
  127. self.min = max(self.min, val)
  128. self.n = n
  129. def get_val(self):
  130. return self.min, self.n
  131. def get_data(self):
  132. return self.min, self.n
  133. class LastMeter(object):
  134. def __init__(self):
  135. self.reset()
  136. def reset(self):
  137. self.last = None
  138. self.n = 0
  139. def record(self, val, n=1):
  140. self.last = val
  141. self.n = n
  142. def get_val(self):
  143. return self.last, self.n
  144. def get_data(self):
  145. return self.last, self.n
  146. class AverageMeter(object):
  147. def __init__(self):
  148. self.reset()
  149. def reset(self):
  150. self.n = 0
  151. self.val = 0
  152. def record(self, val, n=1):
  153. self.n += n
  154. self.val += val * n
  155. def get_val(self):
  156. if self.n == 0:
  157. return None, 0
  158. return self.val / self.n, self.n
  159. def get_data(self):
  160. if self.n == 0:
  161. return None, 0
  162. return self.val / self.n, self.n
  163. class Logger(object):
  164. def __init__(self, print_interval, backends, start_epoch=-1, verbose=False):
  165. self.epoch = start_epoch
  166. self.iteration = -1
  167. self.val_iteration = -1
  168. self.metrics = OrderedDict()
  169. self.backends = backends
  170. self.print_interval = print_interval
  171. self.verbose = verbose
  172. dllogger.init(backends)
  173. def log_parameter(self, data, verbosity=0):
  174. dllogger.log(step="PARAMETER", data=data, verbosity=verbosity)
  175. def register_metric(self, metric_name, meter, verbosity=0, metadata={}):
  176. if self.verbose:
  177. print("Registering metric: {}".format(metric_name))
  178. self.metrics[metric_name] = {"meter": meter, "level": verbosity}
  179. dllogger.metadata(metric_name, metadata)
  180. def log_metric(self, metric_name, val, n=1):
  181. self.metrics[metric_name]["meter"].record(val, n=n)
  182. def start_iteration(self, val=False):
  183. if val:
  184. self.val_iteration += 1
  185. else:
  186. self.iteration += 1
  187. def end_iteration(self, val=False):
  188. it = self.val_iteration if val else self.iteration
  189. if it % self.print_interval == 0:
  190. metrics = {
  191. n: m for n, m in self.metrics.items() if n.startswith("val") == val
  192. }
  193. step = (
  194. (self.epoch, self.iteration)
  195. if not val
  196. else (self.epoch, self.iteration, self.val_iteration)
  197. )
  198. verbositys = {m["level"] for _, m in metrics.items()}
  199. for ll in verbositys:
  200. llm = {n: m for n, m in metrics.items() if m["level"] == ll}
  201. dllogger.log(
  202. step=step,
  203. data={n: m["meter"].get_iteration() for n, m in llm.items()},
  204. verbosity=ll,
  205. )
  206. for n, m in metrics.items():
  207. m["meter"].reset_iteration()
  208. dllogger.flush()
  209. def start_epoch(self):
  210. self.epoch += 1
  211. self.iteration = 0
  212. self.val_iteration = 0
  213. for n, m in self.metrics.items():
  214. m["meter"].reset_epoch()
  215. def end_epoch(self):
  216. for n, m in self.metrics.items():
  217. m["meter"].reset_iteration()
  218. verbositys = {m["level"] for _, m in self.metrics.items()}
  219. for ll in verbositys:
  220. llm = {n: m for n, m in self.metrics.items() if m["level"] == ll}
  221. dllogger.log(
  222. step=(self.epoch,),
  223. data={n: m["meter"].get_epoch() for n, m in llm.items()},
  224. )
  225. def end(self):
  226. for n, m in self.metrics.items():
  227. m["meter"].reset_epoch()
  228. verbositys = {m["level"] for _, m in self.metrics.items()}
  229. for ll in verbositys:
  230. llm = {n: m for n, m in self.metrics.items() if m["level"] == ll}
  231. dllogger.log(
  232. step=tuple(), data={n: m["meter"].get_run() for n, m in llm.items()}
  233. )
  234. for n, m in self.metrics.items():
  235. m["meter"].reset_epoch()
  236. dllogger.flush()
  237. def iteration_generator_wrapper(self, gen, val=False):
  238. for g in gen:
  239. self.start_iteration(val=val)
  240. yield g
  241. self.end_iteration(val=val)
  242. def epoch_generator_wrapper(self, gen):
  243. for g in gen:
  244. self.start_epoch()
  245. yield g
  246. self.end_epoch()