calculate_metrics.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. r"""
Using the `calculate_metrics.py` script, you can obtain model accuracy/error metrics using a defined `MetricsCalculator` class.
Data provided to the `MetricsCalculator` are obtained from npz dump files
stored in the directory pointed to by the `--dump-dir` argument.
These files are prepared by the `run_inference_on_fw.py` and `run_inference_on_triton.py` scripts.
Output data is stored in the csv file pointed to by the `--csv` argument.
  21. Example call:
  22. ```shell script
  23. python ./triton/calculate_metrics.py \
  24. --dump-dir /results/dump_triton \
  25. --csv /results/accuracy_results.csv \
  26. --metrics metrics.py \
  27. --metric-class-param1 value
  28. ```
  29. """
  30. import argparse
  31. import csv
  32. import logging
  33. import string
  34. from pathlib import Path
  35. import numpy as np
# method from PEP-366 to support relative import in executed modules
# (when run as a script, __package__ is None, so the relative imports below
# would fail; synthesize the package name from the parent directory first)
if __package__ is None:
    __package__ = Path(__file__).parent.name
from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import BaseMetricsCalculator, load_from_file
from .deployment_toolkit.dump import pad_except_batch_axis

LOGGER = logging.getLogger("calculate_metrics")
# name of the csv column holding the number of evaluated samples
TOTAL_COLUMN_NAME = "_total_"
  44. def get_data(dump_dir, prefix):
  45. """Loads and concatenates dump files for given prefix (ex. inputs, outputs, labels, ids)"""
  46. dump_dir = Path(dump_dir)
  47. npz_files = sorted(dump_dir.glob(f"{prefix}*.npz"))
  48. data = None
  49. if npz_files:
  50. # assume that all npz files with given prefix contain same set of names
  51. names = list(np.load(npz_files[0].as_posix()).keys())
  52. # calculate target shape
  53. target_shape = {
  54. name: tuple(np.max([np.load(npz_file.as_posix())[name].shape for npz_file in npz_files], axis=0))
  55. for name in names
  56. }
  57. # pad and concatenate data
  58. data = {
  59. name: np.concatenate(
  60. [pad_except_batch_axis(np.load(npz_file.as_posix())[name], target_shape[name]) for npz_file in npz_files]
  61. )
  62. for name in names
  63. }
  64. return data
  65. def main():
  66. logging.basicConfig(level=logging.INFO)
  67. parser = argparse.ArgumentParser(description="Run models with given dataloader", allow_abbrev=False)
  68. parser.add_argument("--metrics", help=f"Path to python module containing metrics calculator", required=True)
  69. parser.add_argument("--csv", help="Path to csv file", required=True)
  70. parser.add_argument("--dump-dir", help="Path to directory with dumped outputs (and labels)", required=True)
  71. args, *_ = parser.parse_known_args()
  72. MetricsCalculator = load_from_file(args.metrics, "metrics", "MetricsCalculator")
  73. ArgParserGenerator(MetricsCalculator).update_argparser(parser)
  74. args = parser.parse_args()
  75. LOGGER.info(f"args:")
  76. for key, value in vars(args).items():
  77. LOGGER.info(f" {key} = {value}")
  78. MetricsCalculator = load_from_file(args.metrics, "metrics", "MetricsCalculator")
  79. metrics_calculator: BaseMetricsCalculator = ArgParserGenerator(MetricsCalculator).from_args(args)
  80. ids = get_data(args.dump_dir, "ids")["ids"]
  81. x = get_data(args.dump_dir, "inputs")
  82. y_true = get_data(args.dump_dir, "labels")
  83. y_pred = get_data(args.dump_dir, "outputs")
  84. common_keys = list({k for k in (y_true or [])} & {k for k in (y_pred or [])})
  85. for key in common_keys:
  86. if y_true[key].shape != y_pred[key].shape:
  87. LOGGER.warning(
  88. f"Model predictions and labels shall have equal shapes. "
  89. f"y_pred[{key}].shape={y_pred[key].shape} != "
  90. f"y_true[{key}].shape={y_true[key].shape}"
  91. )
  92. metrics = metrics_calculator.calc(ids=ids, x=x, y_pred=y_pred, y_real=y_true)
  93. metrics = {TOTAL_COLUMN_NAME: len(ids), **metrics}
  94. metric_names_with_space = [name for name in metrics if any([c in string.whitespace for c in name])]
  95. if metric_names_with_space:
  96. raise ValueError(f"Metric names shall have no spaces; Incorrect names: {', '.join(metric_names_with_space)}")
  97. csv_path = Path(args.csv)
  98. csv_path.parent.mkdir(parents=True, exist_ok=True)
  99. with csv_path.open("w") as csv_file:
  100. writer = csv.DictWriter(csv_file, fieldnames=list(metrics.keys()))
  101. writer.writeheader()
  102. writer.writerow(metrics)
  103. if __name__ == "__main__":
  104. main()