Просмотр исходного кода

precision/recall metrics

Summary:
This commit adds a precision/recall curve to the metrics API.
The `Meter` object is now exposed in python.

The precision/recall curve helps to decide the best threshold.
It can be retrieved from the model object as follows:

```python
ft = fasttext.load_model(model_file)
meter = ft.get_meter(test_file)

label = "__label__bakery"
y_scores, y_true = meter.score_vs_true(label)
precision, recall = meter.precision_recall_curve(label)
```

Reviewed By: EdouardGrave

Differential Revision: D19218524

fbshipit-source-id: 41a7c8e1aa991d076df04c5e497688daf0de4673
Onur Çelebi 5 лет назад
Родитель
Commit
2cc7f54ac0

+ 1 - 1
Makefile

@@ -65,7 +65,7 @@ meter.o: src/meter.cc src/meter.h
 fasttext.o: src/fasttext.cc src/*.h
 	$(CXX) $(CXXFLAGS) -c src/fasttext.cc
 
-fasttext: $(OBJS) src/fasttext.cc
+fasttext: $(OBJS) src/fasttext.cc src/main.cc
 	$(CXX) $(CXXFLAGS) $(OBJS) src/main.cc -o fasttext
 
 clean:

+ 18 - 0
docs/autotune.md

@@ -136,3 +136,21 @@ This is equivalent to manually optimize the f1-score we get when we test with `m
 Sometimes, you may be interested in predicting more than one label. For example, if you were optimizing the hyperparameters manually to get the best score to predict two labels, you would test with `model.test("cooking.valid", k=2)`. You can also tell autotune to optimize the parameters by testing two labels with the `autotunePredictions` argument.
 <!--END_DOCUSAURUS_CODE_TABS-->
 
+You can also force autotune to optimize for the best precision at a given recall, or the best recall at a given precision, either over all labels or for a specific label:
+
+For example, in order to get the best precision at recall = `30%`:
+```sh
+>> ./fasttext supervised [...] -autotune-metric precisionAtRecall:30
+```
+And to get the best precision at recall = `30%` for the label `__label__baking`:
+```sh
+>> ./fasttext supervised [...] -autotune-metric precisionAtRecall:30:__label__baking
+```
+
+Similarly, you can use `recallAtPrecision`:
+```sh
+>> ./fasttext supervised [...] -autotune-metric recallAtPrecision:30
+>> ./fasttext supervised [...] -autotune-metric recallAtPrecision:30:__label__baking
+```
+
+

+ 67 - 0
python/fasttext_module/fasttext/FastText.py

@@ -21,11 +21,66 @@ EOS = "</s>"
 BOW = "<"
 EOW = ">"
 
+displayed_errors = {}
+
 
 def eprint(*args, **kwargs):
     print(*args, file=sys.stderr, **kwargs)
 
 
+class _Meter(object):
+    def __init__(self, fasttext_model, meter):
+        self.f = fasttext_model
+        self.m = meter
+
+    def score_vs_true(self, label):
+        """Return scores and the gold of each sample for a specific label"""
+        label_id = self.f.get_label_id(label)
+        pair_list = self.m.scoreVsTrue(label_id)
+
+        if pair_list:
+            y_scores, y_true = zip(*pair_list)
+        else:
+            y_scores, y_true = ([], ())
+
+        return np.array(y_scores, copy=False), np.array(y_true, copy=False)
+
+    def precision_recall_curve(self, label=None):
+        """Return precision/recall curve"""
+        if label:
+            label_id = self.f.get_label_id(label)
+            pair_list = self.m.precisionRecallCurveLabel(label_id)
+        else:
+            pair_list = self.m.precisionRecallCurve()
+
+        if pair_list:
+            precision, recall = zip(*pair_list)
+        else:
+            precision, recall = ([], ())
+
+        return np.array(precision, copy=False), np.array(recall, copy=False)
+
+    def precision_at_recall(self, recall, label=None):
+        """Return precision for a given recall"""
+        if label:
+            label_id = self.f.get_label_id(label)
+            precision = self.m.precisionAtRecallLabel(label_id, recall)
+        else:
+            precision = self.m.precisionAtRecall(recall)
+
+        return precision
+
+    def recall_at_precision(self, precision, label=None):
+        """Return recall for a given precision"""
+        if label:
+            label_id = self.f.get_label_id(label)
+            recall = self.m.recallAtPrecisionLabel(label_id, precision)
+        else:
+            recall = self.m.recallAtPrecision(precision)
+
+        return recall
+
+
 class _FastText(object):
     """
     This class defines the API to inspect models and should not be used to
@@ -100,6 +155,13 @@ class _FastText(object):
         """
         return self.f.getWordId(word)
 
+    def get_label_id(self, label):
+        """
+        Given a label, get the label id within the dictionary.
+        Returns -1 if label is not in the dictionary.
+        """
+        return self.f.getLabelId(label)
+
     def get_subword_id(self, subword):
         """
         Given a subword, return the index (within input matrix) it hashes to.
@@ -258,6 +320,11 @@ class _FastText(object):
         """
         return self.f.testLabel(path, k, threshold)
 
+    def get_meter(self, path, k=-1):
+        meter = _Meter(self, self.f.getMeter(path, k))
+
+        return meter
+
     def quantize(
         self,
         input=None,

+ 62 - 8
python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

@@ -143,7 +143,15 @@ PYBIND11_MODULE(fasttext_pybind, m) {
 
   py::enum_<fasttext::metric_name>(m, "metric_name")
       .value("f1score", fasttext::metric_name::f1score)
-      .value("labelf1score", fasttext::metric_name::labelf1score)
+      .value("f1scoreLabel", fasttext::metric_name::f1scoreLabel)
+      .value("precisionAtRecall", fasttext::metric_name::precisionAtRecall)
+      .value(
+          "precisionAtRecallLabel",
+          fasttext::metric_name::precisionAtRecallLabel)
+      .value("recallAtPrecision", fasttext::metric_name::recallAtPrecision)
+      .value(
+          "recallAtPrecisionLabel",
+          fasttext::metric_name::recallAtPrecisionLabel)
       .export_values();
 
   m.def(
@@ -186,6 +194,34 @@ PYBIND11_MODULE(fasttext_pybind, m) {
              sizeof(fasttext::real) * (int64_t)1});
       });
 
+  py::class_<fasttext::Meter>(m, "Meter")
+      .def(py::init<bool>())
+      .def("scoreVsTrue", &fasttext::Meter::scoreVsTrue)
+      .def(
+          "precisionRecallCurveLabel",
+          py::overload_cast<int32_t>(
+              &fasttext::Meter::precisionRecallCurve, py::const_))
+      .def(
+          "precisionRecallCurve",
+          py::overload_cast<>(
+              &fasttext::Meter::precisionRecallCurve, py::const_))
+      .def(
+          "precisionAtRecallLabel",
+          py::overload_cast<int32_t, double>(
+              &fasttext::Meter::precisionAtRecall, py::const_))
+      .def(
+          "precisionAtRecall",
+          py::overload_cast<double>(
+              &fasttext::Meter::precisionAtRecall, py::const_))
+      .def(
+          "recallAtPrecisionLabel",
+          py::overload_cast<int32_t, double>(
+              &fasttext::Meter::recallAtPrecision, py::const_))
+      .def(
+          "recallAtPrecision",
+          py::overload_cast<double>(
+              &fasttext::Meter::recallAtPrecision, py::const_));
+
   py::class_<fasttext::FastText>(m, "fasttext")
       .def(py::init<>())
       .def("getArgs", &fasttext::FastText::getArgs)
@@ -231,20 +267,33 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
       .def(
           "test",
-          [](fasttext::FastText& m, 
-            const std::string filename, 
-            int32_t k,
-            fasttext::real threshold) {
+          [](fasttext::FastText& m,
+             const std::string& filename,
+             int32_t k,
+             fasttext::real threshold) {
             std::ifstream ifs(filename);
             if (!ifs.is_open()) {
               throw std::invalid_argument("Test file cannot be opened!");
             }
-            fasttext::Meter meter;
+            fasttext::Meter meter(false);
             m.test(ifs, k, threshold, meter);
             ifs.close();
             return std::tuple<int64_t, double, double>(
                 meter.nexamples(), meter.precision(), meter.recall());
           })
+      .def(
+          "getMeter",
+          [](fasttext::FastText& m, const std::string& filename, int32_t k) {
+            std::ifstream ifs(filename);
+            if (!ifs.is_open()) {
+              throw std::invalid_argument("Test file cannot be opened!");
+            }
+            fasttext::Meter meter(true);
+            m.test(ifs, k, 0.0, meter);
+            ifs.close();
+
+            return meter;
+          })
       .def(
           "getSentenceVector",
           [](fasttext::FastText& m,
@@ -397,7 +446,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
             if (!ifs.is_open()) {
               throw std::invalid_argument("Test file cannot be opened!");
             }
-            fasttext::Meter meter;
+            fasttext::Meter meter(false);
             m.test(ifs, k, threshold, meter);
             std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
             std::unordered_map<std::string, py::dict> returnedValue;
@@ -412,7 +461,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           })
       .def(
           "getWordId",
-          [](fasttext::FastText& m, const std::string word) {
+          [](fasttext::FastText& m, const std::string& word) {
             return m.getWordId(word);
           })
       .def(
@@ -420,6 +469,11 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m, const std::string word) {
             return m.getSubwordId(word);
           })
+      .def(
+          "getLabelId",
+          [](fasttext::FastText& m, const std::string& label) {
+            return m.getLabelId(label);
+          })
       .def(
           "getInputVector",
           [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {

+ 55 - 9
src/args.cc

@@ -12,6 +12,7 @@
 
 #include <iostream>
 #include <stdexcept>
+#include <string>
 #include <unordered_map>
 
 namespace fasttext {
@@ -90,8 +91,16 @@ std::string Args::metricToString(metric_name mn) const {
   switch (mn) {
     case metric_name::f1score:
       return "f1score";
-    case metric_name::labelf1score:
-      return "labelf1score";
+    case metric_name::f1scoreLabel:
+      return "f1scoreLabel";
+    case metric_name::precisionAtRecall:
+      return "precisionAtRecall";
+    case metric_name::precisionAtRecallLabel:
+      return "precisionAtRecallLabel";
+    case metric_name::recallAtPrecision:
+      return "recallAtPrecision";
+    case metric_name::recallAtPrecisionLabel:
+      return "recallAtPrecisionLabel";
   }
   return "Unknown metric name!"; // should never happen
 }
@@ -388,22 +397,59 @@ void Args::setManual(const std::string& argName) {
 
 metric_name Args::getAutotuneMetric() const {
   if (autotuneMetric.substr(0, 3) == "f1:") {
-    return metric_name::labelf1score;
+    return metric_name::f1scoreLabel;
   } else if (autotuneMetric == "f1") {
     return metric_name::f1score;
+  } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::precisionAtRecallLabel;
+    }
+    return metric_name::precisionAtRecall;
+  } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::recallAtPrecisionLabel;
+    }
+    return metric_name::recallAtPrecision;
   }
   throw std::runtime_error("Unknown metric : " + autotuneMetric);
 }
 
 std::string Args::getAutotuneMetricLabel() const {
-  if (getAutotuneMetric() == metric_name::labelf1score) {
-    std::string label = autotuneMetric.substr(3);
-    if (label.empty()) {
-      throw std::runtime_error("Empty metric label : " + autotuneMetric);
-    }
+  metric_name metric = getAutotuneMetric();
+  std::string label;
+  if (metric == metric_name::f1scoreLabel) {
+    label = autotuneMetric.substr(3);
+  } else if (
+      metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::recallAtPrecisionLabel) {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    label = autotuneMetric.substr(semicolon + 1);
+  } else {
     return label;
   }
-  return std::string();
+
+  if (label.empty()) {
+    throw std::runtime_error("Empty metric label : " + autotuneMetric);
+  }
+  return label;
+}
+
+double Args::getAutotuneMetricValue() const {
+  metric_name metric = getAutotuneMetric();
+  double value = 0.0;
+  if (metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::precisionAtRecall ||
+      metric == metric_name::recallAtPrecisionLabel ||
+      metric == metric_name::recallAtPrecision) {
+    size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+    size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+    const std::string valueStr =
+        autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+    value = std::stof(valueStr) / 100.0;
+  }
+  return value;
 }
 
 int64_t Args::getAutotuneModelSize() const {

+ 9 - 1
src/args.h

@@ -18,7 +18,14 @@ namespace fasttext {
 
 enum class model_name : int { cbow = 1, sg, sup };
 enum class loss_name : int { hs = 1, ns, softmax, ova };
-enum class metric_name : int { f1score = 1, labelf1score };
+enum class metric_name : int {
+  f1score = 1,
+  f1scoreLabel,
+  precisionAtRecall,
+  precisionAtRecallLabel,
+  recallAtPrecision,
+  recallAtPrecisionLabel
+};
 
 class Args {
  protected:
@@ -81,6 +88,7 @@ class Args {
   std::string lossToString(loss_name) const;
   metric_name getAutotuneMetric() const;
   std::string getAutotuneMetricLabel() const;
+  double getAutotuneMetricValue() const;
   int64_t getAutotuneModelSize() const;
 
   static constexpr double kUnlimitedModelSize = -1.0;

+ 20 - 7
src/autotune.cc

@@ -277,17 +277,28 @@ void Autotune::startTimer(const Args& args) {
 double Autotune::getMetricScore(
     Meter& meter,
     const metric_name& metricName,
+    const double metricValue,
     const std::string& metricLabel) const {
   double score = 0.0;
-  if (metricName == metric_name::f1score) {
-    score = meter.f1Score();
-  } else if (metricName == metric_name::labelf1score) {
-    int32_t labelId = fastText_->getDictionary()->getId(metricLabel);
+  int32_t labelId = -1;
+  if (!metricLabel.empty()) {
+    labelId = fastText_->getLabelId(metricLabel);
     if (labelId == -1) {
       throw std::runtime_error("Unknown autotune metric label");
     }
-    labelId = labelId - fastText_->getDictionary()->nwords();
+  }
+  if (metricName == metric_name::f1score) {
+    score = meter.f1Score();
+  } else if (metricName == metric_name::f1scoreLabel) {
     score = meter.f1Score(labelId);
+  } else if (metricName == metric_name::precisionAtRecall) {
+    score = meter.precisionAtRecall(metricValue);
+  } else if (metricName == metric_name::precisionAtRecallLabel) {
+    score = meter.precisionAtRecall(labelId, metricValue);
+  } else if (metricName == metric_name::recallAtPrecision) {
+    score = meter.recallAtPrecision(metricValue);
+  } else if (metricName == metric_name::recallAtPrecisionLabel) {
+    score = meter.recallAtPrecision(labelId, metricValue);
   } else {
     throw std::runtime_error("Unknown metric");
   }
@@ -397,14 +408,16 @@ void Autotune::train(const Args& autotuneArgs) {
       fastText_->train(trainArgs);
       bool sizeConstraintOK = quantize(trainArgs, autotuneArgs);
       if (sizeConstraintOK) {
-        Meter meter;
+        const auto& metricLabel = autotuneArgs.getAutotuneMetricLabel();
+        Meter meter(!metricLabel.empty());
         fastText_->test(
             validationFileStream, autotuneArgs.autotunePredictions, 0.0, meter);
 
         currentScore = getMetricScore(
             meter,
             autotuneArgs.getAutotuneMetric(),
-            autotuneArgs.getAutotuneMetricLabel());
+            autotuneArgs.getAutotuneMetricValue(),
+            metricLabel);
 
         if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
           bestTrainArgs = trainArgs;

+ 1 - 0
src/autotune.h

@@ -61,6 +61,7 @@ class Autotune {
   double getMetricScore(
       Meter& meter,
       const metric_name& metricName,
+      const double metricValue,
       const std::string& metricLabel) const;
   void printArgs(const Args& args, const Args& autotuneArgs);
   void printSkippedArgs(const Args& autotuneArgs);

+ 9 - 1
src/fasttext.cc

@@ -100,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
   return dict_->nwords() + h;
 }
 
+int32_t FastText::getLabelId(const std::string& label) const {
+  int32_t labelId = dict_->getId(label);
+  if (labelId != -1) {
+    labelId -= dict_->nwords();
+  }
+  return labelId;
+}
+
 void FastText::getWordVector(Vector& vec, const std::string& word) const {
   const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
   vec.zero();
@@ -413,7 +421,7 @@ void FastText::skipgram(
 
 std::tuple<int64_t, double, double>
 FastText::test(std::istream& in, int32_t k, real threshold) {
-  Meter meter;
+  Meter meter(false);
   test(in, k, threshold, meter);
 
   return std::tuple<int64_t, double, double>(

+ 2 - 0
src/fasttext.h

@@ -87,6 +87,8 @@ class FastText {
 
   int32_t getSubwordId(const std::string& subword) const;
 
+  int32_t getLabelId(const std::string& label) const;
+
   void getWordVector(Vector& vec, const std::string& word) const;
 
   void getSubwordVector(Vector& vec, const std::string& subword) const;

+ 1 - 1
src/main.cc

@@ -148,7 +148,7 @@ void test(const std::vector<std::string>& args) {
   FastText fasttext;
   fasttext.loadModel(model);
 
-  Meter meter;
+  Meter meter(false);
 
   if (input == "-") {
     fasttext.test(std::cin, k, threshold, meter);

+ 136 - 3
src/meter.cc

@@ -16,6 +16,9 @@
 
 namespace fasttext {
 
+constexpr int32_t kAllLabels = -1;
+constexpr real falseNegativeScore = -1.0;
+
 void Meter::log(
     const std::vector<int32_t>& labels,
     const Predictions& predictions) {
@@ -26,7 +29,7 @@ void Meter::log(
   for (const auto& prediction : predictions) {
     labelMetrics_[prediction.second].predicted++;
 
-    real score = std::exp(prediction.first);
+    real score = std::min(std::exp(prediction.first), 1.0f);
     real gold = 0.0;
     if (utils::contains(labels, prediction.second)) {
       labelMetrics_[prediction.second].predictedGold++;
@@ -36,8 +39,13 @@ void Meter::log(
     labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
   }
 
-  for (const auto& label : labels) {
-    labelMetrics_[label].gold++;
+  if (falseNegativeLabels_) {
+    for (const auto& label : labels) {
+      labelMetrics_[label].gold++;
+      if (!utils::containsSecond(predictions, label)) {
+        labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
+      }
+    }
   }
 }
 
@@ -78,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
   out << "R@" << k << "\t" << metrics_.recall() << std::endl;
 }
 
+std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
+    int32_t labelId) const {
+  std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
+
+  const auto& v = scoreVsTrue(labelId);
+  uint64_t truePositives = 0;
+  uint64_t falsePositives = 0;
+  double lastScore = falseNegativeScore - 1.0;
+
+  for (auto it = v.rbegin(); it != v.rend(); ++it) {
+    double score = it->first;
+    double gold = it->second;
+    if (score < 0) { // only reachable recall
+      break;
+    }
+    if (gold == 1.0) {
+      truePositives++;
+    } else {
+      falsePositives++;
+    }
+    if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
+      positiveCounts.back() = {truePositives, falsePositives};
+    } else {
+      positiveCounts.emplace_back(truePositives, falsePositives);
+    }
+    lastScore = score;
+  }
+
+  return positiveCounts;
+}
+
+double Meter::precisionAtRecall(double recallQuery) const {
+  return precisionAtRecall(kAllLabels, recallQuery);
+}
+
+double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
+  const auto& precisionRecall = precisionRecallCurve(labelId);
+  double bestPrecision = 0.0;
+  std::for_each(
+      precisionRecall.begin(),
+      precisionRecall.end(),
+      [&bestPrecision, recallQuery](const std::pair<double, double>& element) {
+        if (element.second >= recallQuery) {
+          bestPrecision = std::max(bestPrecision, element.first);
+        };
+      });
+  return bestPrecision;
+}
+
+double Meter::recallAtPrecision(double precisionQuery) const {
+  return recallAtPrecision(kAllLabels, precisionQuery);
+}
+
+double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
+  const auto& precisionRecall = precisionRecallCurve(labelId);
+  double bestRecall = 0.0;
+  std::for_each(
+      precisionRecall.begin(),
+      precisionRecall.end(),
+      [&bestRecall, precisionQuery](const std::pair<double, double>& element) {
+        if (element.first >= precisionQuery) {
+          bestRecall = std::max(bestRecall, element.second);
+        };
+      });
+  return bestRecall;
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
+  return precisionRecallCurve(kAllLabels);
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
+    int32_t labelId) const {
+  std::vector<std::pair<double, double>> precisionRecallCurve;
+  const auto& positiveCounts = getPositiveCounts(labelId);
+  if (positiveCounts.empty()) {
+    return precisionRecallCurve;
+  }
+
+  uint64_t golds =
+      (labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
+
+  auto fullRecall = std::lower_bound(
+      positiveCounts.begin(),
+      positiveCounts.end(),
+      golds,
+      utils::compareFirstLess);
+
+  if (fullRecall != positiveCounts.end()) {
+    fullRecall = std::next(fullRecall);
+  }
+
+  for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
+    double precision = 0.0;
+    double truePositives = it->first;
+    double falsePositives = it->second;
+    if (truePositives + falsePositives != 0.0) {
+      precision = truePositives / (truePositives + falsePositives);
+    }
+    double recall = golds != 0 ? (truePositives / double(golds))
+                               : std::numeric_limits<double>::quiet_NaN();
+    precisionRecallCurve.emplace_back(precision, recall);
+  }
+  precisionRecallCurve.emplace_back(1.0, 0.0);
+
+  return precisionRecallCurve;
+}
+
+std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
+  std::vector<std::pair<real, real>> ret;
+  if (labelId == kAllLabels) {
+    for (const auto& k : labelMetrics_) {
+      auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
+      ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
+    }
+  } else {
+    if (labelMetrics_.count(labelId)) {
+      ret = labelMetrics_.at(labelId).scoreVsTrue;
+    }
+  }
+  sort(ret.begin(), ret.end());
+
+  return ret;
+}
+
 } // namespace fasttext

+ 22 - 2
src/meter.h

@@ -24,7 +24,7 @@ class Meter {
     uint64_t predictedGold;
     mutable std::vector<std::pair<real, real>> scoreVsTrue;
 
-    Metrics() : gold(0), predicted(0), predictedGold(0) {}
+    Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
 
     double precision() const {
       if (predicted == 0) {
@@ -44,16 +44,35 @@ class Meter {
       }
       return 2 * predictedGold / double(predicted + gold);
     }
+
+    std::vector<std::pair<real, real>> getScoreVsTrue() {
+      return scoreVsTrue;
+    }
   };
+  std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
+      int32_t labelId) const;
 
  public:
-  Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
+  Meter() = delete;
+  explicit Meter(bool falseNegativeLabels)
+      : metrics_(),
+        nexamples_(0),
+        labelMetrics_(),
+        falseNegativeLabels_(falseNegativeLabels) {}
 
   void log(const std::vector<int32_t>& labels, const Predictions& predictions);
 
   double precision(int32_t);
   double recall(int32_t);
   double f1Score(int32_t);
+  std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
+  double precisionAtRecall(int32_t labelId, double recall) const;
+  double precisionAtRecall(double recall) const;
+  double recallAtPrecision(int32_t labelId, double recall) const;
+  double recallAtPrecision(double recall) const;
+  std::vector<std::pair<double, double>> precisionRecallCurve(
+      int32_t labelId) const;
+  std::vector<std::pair<double, double>> precisionRecallCurve() const;
   double precision() const;
   double recall() const;
   double f1Score() const;
@@ -66,6 +85,7 @@ class Meter {
   Metrics metrics_{};
   uint64_t nexamples_;
   std::unordered_map<int32_t, Metrics> labelMetrics_;
+  bool falseNegativeLabels_;
 };
 
 } // namespace fasttext

+ 4 - 0
src/utils.cc

@@ -44,6 +44,10 @@ std::ostream& operator<<(std::ostream& out, const ClockPrint& me) {
   return out;
 }
 
+bool compareFirstLess(const std::pair<double, double>& l, const double& r) {
+  return l.first < r;
+}
+
 } // namespace utils
 
 } // namespace fasttext

+ 14 - 0
src/utils.h

@@ -40,6 +40,18 @@ bool contains(const std::vector<T>& container, const T& value) {
       container.end();
 }
 
+template <typename T1, typename T2>
+bool containsSecond(
+    const std::vector<std::pair<T1, T2>>& container,
+    const T2& value) {
+  return std::find_if(
+             container.begin(),
+             container.end(),
+             [&value](const std::pair<T1, T2>& item) {
+               return item.second == value;
+             }) != container.end();
+}
+
 double getDuration(
     const std::chrono::steady_clock::time_point& start,
     const std::chrono::steady_clock::time_point& end);
@@ -53,6 +65,8 @@ class ClockPrint {
   int32_t duration_;
 };
 
+bool compareFirstLess(const std::pair<double, double>& l, const double& r);
+
 } // namespace utils
 
 } // namespace fasttext