# log_helper.py

import atexit
import os

import numpy as np
from tensorboardX import SummaryWriter

import dllogger as DLLogger
from dllogger import StdOutBackend, JSONStreamBackend, Verbosity


def stdout_step_format(step):
    """Format a logging step tuple as 'epoch    E | iter   I/N'.

    Plain string steps pass through unchanged.
    """
    if isinstance(step, str):
        return step
    fields = []
    if len(step) > 0:
        fields.append("epoch {:>4}".format(step[0]))
    if len(step) > 1:
        fields.append("iter {:>3}".format(step[1]))
    if len(step) > 2:
        fields[-1] += "/{}".format(step[2])
    return " | ".join(fields)
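
# Illustrative outputs (worked by hand, not part of the original module):
#   stdout_step_format((3,))         -> "epoch    3"
#   stdout_step_format((3, 15, 100)) -> "epoch    3 | iter  15/100"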


def stdout_metric_format(metric, metadata, value):
    """Render one metric as '| <name> <value> <unit>', skipping absent parts."""
    name = metadata.get("name", metric + " : ")
    unit = metadata.get("unit", None)
    fmt = "{" + metadata["format"] + "}" if "format" in metadata else "{}"
    fields = [name, fmt.format(value) if value is not None else value, unit]
    fields = [f for f in fields if f is not None]
    return "| " + " ".join(fields)
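
# Illustrative output (worked by hand; the metadata mirrors an entry
# registered below in init_dllogger):
#   stdout_metric_format("train_loss",
#                        {"name": "loss", "format": ":>5.2f"}, 5.239)
#   -> "| loss  5.24"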


def init_dllogger(log_fpath=None, dummy=False):
    if dummy:
        # A logger with no backends silently drops all calls; useful for
        # non-zero ranks in distributed training.
        DLLogger.init(backends=[])
        return
    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE,
                      step_format=stdout_step_format,
                      metric_format=stdout_metric_format),
    ])
    # Display names and formats for the logged metrics; the leading spaces
    # in the validation names align the stdout columns.
    DLLogger.metadata("train_loss", {"name": "loss", "format": ":>5.2f"})
    DLLogger.metadata("train_mel_loss", {"name": "mel loss", "format": ":>5.2f"})
    DLLogger.metadata("avg_train_loss",
                      {"name": "avg train loss", "format": ":>5.2f"})
    DLLogger.metadata("avg_train_mel_loss",
                      {"name": "avg train mel loss", "format": ":>5.2f"})
    DLLogger.metadata("val_loss", {"name": " avg val loss", "format": ":>5.2f"})
    DLLogger.metadata("val_mel_loss",
                      {"name": " avg val mel loss", "format": ":>5.2f"})
    DLLogger.metadata("val_ema_loss",
                      {"name": " EMA val loss", "format": ":>5.2f"})
    DLLogger.metadata("val_ema_mel_loss",
                      {"name": " EMA val mel loss", "format": ":>5.2f"})
    DLLogger.metadata("train_frames/s",
                      {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata("avg_train_frames/s",
                      {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata("val_frames/s",
                      {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata("val_ema_frames/s",
                      {"name": None, "unit": "frames/s", "format": ":>10.2f"})
    DLLogger.metadata("took", {"name": "took", "unit": "s", "format": ":>3.2f"})
    DLLogger.metadata("lrate_change", {"name": "lrate"})
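
# A hedged usage sketch (the variables `epoch`, `iters`, `num_iters`, and
# the logged values are hypothetical, not part of this module):
#   init_dllogger("nvlog.json")
#   DLLogger.log(step=(epoch, iters, num_iters),
#                data={"train_loss": 5.24, "train_frames/s": 14000.0})
#   DLLogger.flush()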


class TBLogger:
    """Rank-0 TensorBoard logger that aggregates values over `interval` steps.

    dummies: stretch the screen with empty 'aaa'/'zzz' plots so the legend
    always fits for the other plots.
    """
    def __init__(self, local_rank, log_dir, name, interval=1, dummies=False):
        self.enabled = (local_rank == 0)
        self.interval = interval
        self.cache = {}
        if self.enabled:
            self.summary_writer = SummaryWriter(
                log_dir=os.path.join(log_dir, name),
                flush_secs=120, max_queue=200)
            atexit.register(self.summary_writer.close)
            if dummies:
                for key in ('aaa', 'zzz'):
                    self.summary_writer.add_scalar(key, 0.0, 1)

    def log_value(self, step, key, val, stat='mean'):
        # Buffer values and emit one aggregated scalar every `interval` calls.
        if self.enabled:
            if key not in self.cache:
                self.cache[key] = []
            self.cache[key].append(val)
            if len(self.cache[key]) == self.interval:
                agg_val = getattr(np, stat)(self.cache[key])
                self.summary_writer.add_scalar(key, agg_val, step)
                del self.cache[key]

    def log_meta(self, step, meta):
        # `meta` is expected to map names to 0-dim tensors (hence .item()).
        for k, v in meta.items():
            self.log_value(step, k, v.item())

    def log_grads(self, step, model):
        # Log the max/min/mean of the per-parameter gradient norms.
        if self.enabled:
            norms = [p.grad.norm().item() for p in model.parameters()
                     if p.grad is not None]
            for stat in ('max', 'min', 'mean'):
                self.log_value(step, f'grad_{stat}',
                               getattr(np, stat)(norms), stat=stat)
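
# A hedged usage sketch (assumes a PyTorch model and an integer global
# `step`; `model` and `loss` are hypothetical stand-ins):
#   tb_logger = TBLogger(local_rank=0, log_dir='./logs', name='train',
#                        interval=10, dummies=True)
#   tb_logger.log_value(step, 'train_loss', loss.item())
#   tb_logger.log_grads(step, model)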