# metrics.py (4.7 KB)

# NOTE(review): `timer` looks like a typo for `time` (this file only calls
# time.time()); remove it once confirmed unused.
import time
import timer
from collections import defaultdict
  3. class Metrics(defaultdict):
  4. # TODO Where to measure - gpu:0 or all gpus?
  5. def __init__(self, tb_keys=[], benchmark_epochs=10):
  6. super().__init__(float)
  7. # dll_tb_keys=['loss_gen', 'loss_discrim', 'loss_mel', 'took']:
  8. self.tb_keys = tb_keys #_ = {'dll': dll_keys, 'tb': tb_keys, 'dll+tb': dll_tb_keys}
  9. self.iter_start_time = None
  10. self.iter_metrics = defaultdict(float)
  11. self.epoch_start_time = None
  12. self.epoch_metrics = defaultdict(float)
  13. self.benchmark_epochs = benchmark_epochs
  14. def start_epoch(self, epoch, start_timer=True):
  15. self.epoch = epoch
  16. if start_timer:
  17. self.epoch_start_time = time.time()
  18. def start_iter(self, iter, start_timer=True):
  19. self.iter = iter
  20. self.accum_steps = 0
  21. self.step_metrics.clear()
  22. if start_timer:
  23. self.iter_start_time = time.time()
  24. def update_iter(self, ...):
  25. # do stuff
  26. pass
  27. def accumulate(self, scope='step'):
  28. tgt = {'step': self.step_metrics, 'epoch': self.epoch_metrics}[scope]
  29. for k, v in self.items():
  30. tgt[k] += v
  31. self.clear()
  32. def update_iter(self, metrics={}, stop_timer=True):
  33. is not self.started_iter:
  34. return
  35. self.accumulate(metrics)
  36. self.accumulate(self.iter_metrics, scope='epoch')
  37. if stop_timer:
  38. self.iter_metrics['took'] = time.time() - self.iter_time_start
  39. def update_epoch(self, stop_timer=True):
  40. # tb_total_steps=None,
  41. # subset='train_avg',
  42. # data=OrderedDict([
  43. # ('loss', epoch_loss[-1]),
  44. # ('mel_loss', epoch_mel_loss[-1]),
  45. # ('frames/s', epoch_num_frames[-1] / epoch_time[-1]),
  46. # ('took', epoch_time[-1])]),
  47. # )
  48. if stop_timer:
  49. self.['epoch_time'] = time.time() - self.epoch_time_start
  50. if steps % args.stdout_interval == 0:
  51. # with torch.no_grad():
  52. # mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
  53. took = time.time() - self.start_b
  54. self.sws['train'].add_scalar("gen_loss_total", loss_gen_all.item(), steps)
  55. self.sws['train'].add_scalar("mel_spec_error", mel_error.item(), steps)
  56. for key, val in meta.items():
  57. sw_name = 'train'
  58. for name_ in keys_mpd + keys_msd:
  59. if name_ in key:
  60. sw_name = 'train_' + name_
  61. key = key.replace('loss_', 'loss/')
  62. key = re.sub('mpd\d+', 'mpd-msd', key)
  63. key = re.sub('msd\d+', 'mpd-msd', key)
  64. self.sws[sw_name].add_scalar(key, val / h.batch_size, steps)
  65. def iter_metrics(self, target='dll+tb'):
  66. return {self.iter_metrics[k] for k in self.keys_[target]}
  67. def foo
  68. Steps : 40, Gen Loss Total : 57.993, Mel-Spec. Error : 47.374, s/b : 1.013
  69. logger.log((epoch, epoch_iter, num_iters),
  70. tb_total_steps=total_iter,
  71. subset='train',
  72. data=OrderedDict([
  73. ('loss', iter_loss),
  74. ('mel_loss', iter_mel_loss),
  75. ('frames/s', iter_num_frames / iter_time),
  76. ('took', iter_time),
  77. ('lrate', optimizer.param_groups[0]['lr'])]),
  78. )
  79. class Meter:
  80. def __init__(self, sink_type, scope, downstream=None, end_points=None, verbosity=dllogger.Verbosity.DEFAULT):
  81. self.verbosity = verbosity
  82. self.sink_type = sink_type
  83. self.scope = scope
  84. self.downstream = downstream
  85. self.end_points = end_points or []
  86. def start(self):
  87. ds = None if self.downstream is None else self.downstream.sink
  88. end_pt_fn = lambda x: list(map(lambda f: f(x), self.end_points)) # call all endpoint functions
  89. self.sink = self.sink_type(end_pt_fn, ds)
  90. def end(self):
  91. self.sink.close()
  92. def send(self, data):
  93. self.sink.send(data)
  94. def meters(self):
  95. if self.downstream is not None:
  96. downstream_meters = self.downstream.meters()
  97. else:
  98. downstream_meters = []
  99. return [self] + downstream_meters
  100. def add_end_point(self, new_endpoint):
  101. self.end_points.append(new_endpoint)
  102. def __or__(self, other):
  103. """for easy chaining of meters"""
  104. if self.downstream is None:
  105. self.downstream = other
  106. else:
  107. self.downstream | other
  108. return self