global_metrics.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import numpy as np
  class CompositeMeter:
      """Group of named meters that can be updated and queried together.

      Every registered meter is expected to provide ``update(value)`` and
      ``get()``; this class only dispatches to them by name.
      """

      def __init__(self):
          # metric name -> meter object
          self.register = {}

      def register_metric(self, name, metric):
          """Store ``metric`` under ``name``, replacing any meter already there."""
          self.register[name] = metric

      def _validate(self, metric_name):
          """Raise ValueError unless ``metric_name`` has been registered."""
          if metric_name in self.register:
              return
          raise ValueError('{} is not registered metric'.format(metric_name))

      def update_metric(self, metric_name, value):
          """Pass ``value`` to the meter registered as ``metric_name``.

          Raises:
              ValueError: if ``metric_name`` is not registered.
          """
          self._validate(metric_name)
          meter = self.register[metric_name]
          meter.update(value)

      def update_dict(self, dict_metric):
          """Update each registered meter named in ``dict_metric``.

          Keys that do not correspond to a registered meter are skipped
          silently rather than raising.
          """
          known_pairs = [
              (key, value)
              for key, value in dict_metric.items()
              if key in self.register
          ]
          for key, value in known_pairs:
              self.update_metric(key, value)

      def get(self, metric_name=None):
          """Return one meter's value, or a dict of all values.

          With ``metric_name`` given, return that meter's ``get()`` result
          (ValueError if it is not registered). Without it, return a new dict
          mapping every registered name to its meter's ``get()`` result.
          """
          if metric_name is None:
              return {key: meter.get() for key, meter in self.register.items()}
          self._validate(metric_name)
          return self.register[metric_name].get()
  class MaxMeter:
      """Track the largest value passed to ``update`` since the last reset.

      ``get()`` returns None until the first ``update``; ``n`` counts how
      many values have been observed.
      """

      def __init__(self):
          self.max = None  # running maximum; None means "no data yet"
          self.n = 0       # number of update() calls since the last reset

      def reset(self):
          """Forget every observed value."""
          self.max = None
          self.n = 0

      def update(self, val):
          """Fold ``val`` into the running maximum and count it."""
          # Compare against None explicitly so a falsy value such as 0 is a
          # legitimate maximum rather than being treated as "no data".
          if self.max is None:
              self.max = val
          else:
              self.max = max(self.max, val)
          # n used to stay 0 forever; keep it in step with the observations.
          self.n += 1

      def get(self):
          """Return the maximum seen so far, or None if nothing was recorded."""
          return self.max
  class MinMeter:
      """Track the smallest value passed to ``update`` since the last reset.

      ``get()`` returns None until the first ``update``; ``n`` counts how
      many values have been observed.
      """

      def __init__(self):
          self.min = None  # running minimum; None means "no data yet"
          self.n = 0       # number of update() calls since the last reset

      def reset(self):
          """Forget every observed value."""
          self.min = None
          self.n = 0

      def update(self, val):
          """Fold ``val`` into the running minimum and count it."""
          # Compare against None explicitly so a falsy value such as 0 is a
          # legitimate minimum rather than being treated as "no data".
          if self.min is None:
              self.min = val
          else:
              self.min = min(self.min, val)
          # n used to stay 0 forever; keep it in step with the observations.
          self.n += 1

      def get(self):
          """Return the minimum seen so far, or None if nothing was recorded."""
          return self.min
  class AvgMeter:
      """Running arithmetic mean of the values passed to ``update``."""

      def __init__(self):
          self.sum = 0  # running total of observed values
          self.n = 0    # number of observed values

      def reset(self):
          """Forget every observed value."""
          self.sum = 0
          self.n = 0

      def update(self, val):
          """Add ``val`` to the running total and count it."""
          self.sum += val
          self.n += 1

      def get(self):
          """Return the mean of the observed values, or None if there are none.

          An empty meter used to raise ZeroDivisionError. Returning None
          matches MaxMeter/MinMeter and keeps CompositeMeter.get() from
          failing when one of its meters has not received data yet.
          """
          if self.n == 0:
              return None
          return self.sum / self.n
  class PercentileMeter:
      """Collect values and report their ``q``-th percentile via numpy."""

      def __init__(self, q):
          self.data = []  # every value recorded since the last reset
          self.q = q      # percentile forwarded to np.percentile

      def reset(self):
          """Drop all recorded values; ``q`` is kept."""
          self.data = []

      def update(self, data):
          """Record ``data``, which may be an iterable of values or one scalar.

          Accepting a scalar lets this meter be driven like the other meters
          (one value per call, e.g. via CompositeMeter.update_dict); before,
          list.extend raised TypeError for a scalar. Iterables are extended
          exactly as before.
          """
          try:
              values = iter(data)
          except TypeError:
              # Not iterable (int, float, 0-d array): store as a single value.
              self.data.append(data)
          else:
              self.data.extend(values)

      def get(self):
          """Return the ``q``-th percentile of the recorded values.

          Returns None when nothing has been recorded, consistent with the
          other meters, instead of calling np.percentile on an empty list.
          """
          if not self.data:
              return None
          return np.percentile(self.data, self.q)