// meter.h
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <cstdint>
#include <iosfwd>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "dictionary.h"
#include "real.h"
#include "utils.h"
  14. namespace fasttext {
  15. class Meter {
  16. struct Metrics {
  17. uint64_t gold;
  18. uint64_t predicted;
  19. uint64_t predictedGold;
  20. mutable std::vector<std::pair<real, real>> scoreVsTrue;
  21. Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
  22. double precision() const {
  23. if (predicted == 0) {
  24. return std::numeric_limits<double>::quiet_NaN();
  25. }
  26. return predictedGold / double(predicted);
  27. }
  28. double recall() const {
  29. if (gold == 0) {
  30. return std::numeric_limits<double>::quiet_NaN();
  31. }
  32. return predictedGold / double(gold);
  33. }
  34. double f1Score() const {
  35. if (predicted + gold == 0) {
  36. return std::numeric_limits<double>::quiet_NaN();
  37. }
  38. return 2 * predictedGold / double(predicted + gold);
  39. }
  40. std::vector<std::pair<real, real>> getScoreVsTrue() {
  41. return scoreVsTrue;
  42. }
  43. };
  44. std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
  45. int32_t labelId) const;
  46. public:
  47. Meter() = delete;
  48. explicit Meter(bool falseNegativeLabels)
  49. : metrics_(),
  50. nexamples_(0),
  51. labelMetrics_(),
  52. falseNegativeLabels_(falseNegativeLabels) {}
  53. void log(const std::vector<int32_t>& labels, const Predictions& predictions);
  54. double precision(int32_t);
  55. double recall(int32_t);
  56. double f1Score(int32_t);
  57. std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
  58. double precisionAtRecall(int32_t labelId, double recall) const;
  59. double precisionAtRecall(double recall) const;
  60. double recallAtPrecision(int32_t labelId, double recall) const;
  61. double recallAtPrecision(double recall) const;
  62. std::vector<std::pair<double, double>> precisionRecallCurve(
  63. int32_t labelId) const;
  64. std::vector<std::pair<double, double>> precisionRecallCurve() const;
  65. double precision() const;
  66. double recall() const;
  67. double f1Score() const;
  68. uint64_t nexamples() const {
  69. return nexamples_;
  70. }
  71. void writeGeneralMetrics(std::ostream& out, int32_t k) const;
  72. private:
  73. Metrics metrics_{};
  74. uint64_t nexamples_;
  75. std::unordered_map<int32_t, Metrics> labelMetrics_;
  76. bool falseNegativeLabels_;
  77. };
  78. } // namespace fasttext