| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import atexit
- import os
- import numpy as np
- from tensorboardX import SummaryWriter
- import dllogger as DLLogger
- from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
def stdout_step_format(step):
    """Format a DLLogger step for stdout.

    A plain string passes through untouched. A tuple/list is rendered as
    up to "epoch NNNN | iter NNN/total" depending on how many elements
    are present (epoch, iteration, total iterations).
    """
    if isinstance(step, str):
        return step

    parts = []
    if len(step) > 0:
        parts.append(f"epoch {step[0]:>4}")
    if len(step) > 1:
        # Optional "/total" suffix is attached to the iteration field.
        suffix = f"/{step[2]}" if len(step) > 2 else ""
        parts.append(f"iter {step[1]:>3}{suffix}")
    return " | ".join(parts)
def stdout_metric_format(metric, metadata, value):
    """Format a single metric for the stdout backend.

    ``metadata`` may carry an optional display ``name``, ``unit`` and a
    ``format`` spec (the part between braces, e.g. ``:>5.2f``). Fields
    that resolve to None are dropped from the output line.
    """
    name = metadata.get("name", metric + " : ")
    unit = metadata.get("unit")
    fmt = "{" + metadata["format"] + "}" if "format" in metadata else "{}"
    rendered = fmt.format(value) if value is not None else None
    parts = [f for f in (name, rendered, unit) if f is not None]
    return "| " + " ".join(parts)
def init_dllogger(log_fpath=None, dummy=False):
    """Initialize the global DLLogger.

    With ``dummy=True`` no backends are attached (useful on non-zero
    ranks). Otherwise a JSON stream backend writing to ``log_fpath`` and
    a formatted stdout backend are installed, and display metadata is
    registered for every metric this project logs.
    """
    if dummy:
        DLLogger.init(backends=[])
        return

    DLLogger.init(backends=[
        JSONStreamBackend(Verbosity.DEFAULT, log_fpath),
        StdOutBackend(Verbosity.VERBOSE,
                      step_format=stdout_step_format,
                      metric_format=stdout_metric_format),
    ])

    # Per-metric display metadata: name (None hides it), unit, format spec.
    metric_metadata = [
        ("train_loss", {"name": "loss", "format": ":>5.2f"}),
        ("train_mel_loss", {"name": "mel loss", "format": ":>5.2f"}),
        ("avg_train_loss", {"name": "avg train loss", "format": ":>5.2f"}),
        ("avg_train_mel_loss", {"name": "avg train mel loss", "format": ":>5.2f"}),
        ("val_loss", {"name": " avg val loss", "format": ":>5.2f"}),
        ("val_mel_loss", {"name": " avg val mel loss", "format": ":>5.2f"}),
        ("val_ema_loss", {"name": " EMA val loss", "format": ":>5.2f"}),
        ("val_ema_mel_loss", {"name": " EMA val mel loss", "format": ":>5.2f"}),
        ("train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"}),
        ("avg_train_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"}),
        ("val_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"}),
        ("val_ema_frames/s", {"name": None, "unit": "frames/s", "format": ":>10.2f"}),
        ("took", {"name": "took", "unit": "s", "format": ":>3.2f"}),
        ("lrate_change", {"name": "lrate"}),
    ]
    for metric, meta in metric_metadata:
        DLLogger.metadata(metric, meta)
class TBLogger:
    """Rank-aware TensorBoard logger with interval-based aggregation.

    Only rank 0 writes to TensorBoard; other ranks become no-ops.
    Values are buffered per key and flushed as a single aggregated
    scalar every ``interval`` calls.

    xyz_dummies: stretch the screen with empty plots so the legend would
    always fit for other plots
    """

    def __init__(self, local_rank, log_dir, name, interval=1, dummies=False):
        self.enabled = local_rank == 0
        self.interval = interval
        self.cache = {}
        if not self.enabled:
            return
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(log_dir, name),
            flush_secs=120, max_queue=200)
        # Make sure buffered events reach disk on interpreter exit.
        atexit.register(self.summary_writer.close)
        if dummies:
            for tag in ('aaa', 'zzz'):
                self.summary_writer.add_scalar(tag, 0.0, 1)

    def log_value(self, step, key, val, stat='mean'):
        """Buffer ``val`` under ``key``; flush the ``stat`` aggregate
        (a numpy reduction name, e.g. 'mean'/'max') every ``interval``
        values."""
        if not self.enabled:
            return
        self.cache.setdefault(key, []).append(val)
        if len(self.cache[key]) == self.interval:
            aggregated = getattr(np, stat)(self.cache[key])
            self.summary_writer.add_scalar(key, aggregated, step)
            del self.cache[key]

    def log_meta(self, step, meta):
        """Log every (key, 0-dim tensor) pair from ``meta``."""
        for key, tensor in meta.items():
            self.log_value(step, key, tensor.item())

    def log_grads(self, step, model):
        """Log max/min/mean of per-parameter gradient norms."""
        if not self.enabled:
            return
        norms = [p.grad.norm().item()
                 for p in model.parameters() if p.grad is not None]
        for stat in ('max', 'min', 'mean'):
            self.log_value(step, f'grad_{stat}', getattr(np, stat)(norms),
                           stat=stat)
|