hydra_callbacks.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import mlflow
  16. import pandas as pd
  17. from omegaconf import OmegaConf
  18. from hydra.experimental.callback import Callback
  19. from loggers.log_helper import jsonlog_2_df
  20. from mlflow.entities import Metric, Param
  21. class MergeLogs(Callback):
  22. def on_multirun_end(self, config, **kwargs):
  23. OmegaConf.resolve(config)
  24. ALLOWED_KEYS=['timestamp', 'elapsed_time', 'step', 'loss', 'val_loss', 'MAE', 'MSE', 'RMSE', 'P50', 'P90', 'SMAPE', 'TDI']
  25. dfs = []
  26. for p, sub_dirs, files in os.walk(config.hydra.sweep.dir):
  27. if 'log.json' in files:
  28. path = os.path.join(p, 'log.json')
  29. df = jsonlog_2_df(path, ALLOWED_KEYS)
  30. dfs.append(df)
  31. # Transpose dataframes
  32. plots = {}
  33. for c in dfs[0].columns:
  34. joint_plots = pd.DataFrame({i : df[c] for i, df in enumerate(dfs)})
  35. metrics = {}
  36. metrics['mean'] = joint_plots.mean(axis=1)
  37. metrics['std'] = joint_plots.std(axis=1)
  38. metrics['mean_m_std'] = metrics['mean'] - metrics['std']
  39. metrics['mean_p_std'] = metrics['mean'] + metrics['std']
  40. metrics_df = pd.DataFrame(metrics)
  41. plots[c] = metrics_df[~metrics_df.isna().all(axis=1)] # Drop rows which contain only NaNs
  42. timestamps = plots.pop('timestamp')['mean']
  43. timestamps = (timestamps * 1000).astype(int)
  44. if not timestamps.is_monotonic:
  45. raise ValueError('Timestamps are not monotonic')
  46. metrics = [Metric('_'.join((k,name)), v, timestamp, step)
  47. for k, df in plots.items()
  48. for timestamp, (step, series) in zip(timestamps, df.iterrows())
  49. for name, v in series.items()
  50. ]
  51. client = mlflow.tracking.MlflowClient(tracking_uri=config.trainer.config.mlflow_store)
  52. exp = client.get_experiment_by_name(config.trainer.config.get('experiment_name', ''))
  53. run = client.create_run(exp.experiment_id if exp else '0')
  54. for i in range(0, len(metrics), 1000):
  55. client.log_batch(run.info.run_id, metrics=metrics[i:i+1000])
  56. client.set_terminated(run.info.run_id)