Parcourir la source

Use threshold for prediction instead of top-k

Summary: [fasttext] Use threshold for prediction instead of top-k

Reviewed By: EdouardGrave

Differential Revision: D5644298

fbshipit-source-id: cc5ddaf9202d2bf0f9ed7a63ebdd3cc5a397965d
Changhan Wang il y a 8 ans
Parent
commit
d72255386b
7 fichiers modifiés avec 86 ajouts et 46 suppressions
  1. 8 4
      python/fastText/FastText.py
  2. 8 4
      python/fastText/pybind/fasttext_pybind.cc
  3. 20 9
      src/fasttext.cc
  4. 4 3
      src/fasttext.h
  5. 21 11
      src/main.cc
  6. 21 11
      src/model.cc
  7. 4 4
      src/model.h

+ 8 - 4
python/fastText/FastText.py

@@ -97,13 +97,17 @@ class _FastText():
         self.f.getInputVector(b, ind)
         return np.array(b)
 
-    def predict(self, text, k=1):
+    def predict(self, text, k=1, threshold=0.0):
         """
         Given a string, get a list of labels and a list of
         corresponding probabilities. k controls the number
         of returned labels. A choice of 5 will return the 5
         most probable labels. By default this returns only
-        the most likely label and probability.
+        the most likely label and probability. threshold filters
+        the returned labels by a threshold on probability. A
+        choice of 0.5 will return labels with at least 0.5
+        probability. k and threshold will be applied together to
+        determine the returned labels.
 
         This function assumes it is given
         a single line of text. We split words on whitespace (space,
@@ -126,11 +130,11 @@ class _FastText():
 
         if type(text) == list:
             text = [check(entry) for entry in text]
-            all_probs, all_labels = self.f.multilinePredict(text, k)
+            all_probs, all_labels = self.f.multilinePredict(text, k, threshold)
             return all_labels, np.array(all_probs, copy=False)
         else:
             text = check(text)
-            pairs = self.f.predict(text, k)
+            pairs = self.f.predict(text, k, threshold)
             probs, labels = zip(*pairs)
             return labels, np.array(probs, copy=False)
 

+ 8 - 4
python/fastText/pybind/fasttext_pybind.cc

@@ -253,10 +253,13 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           "predict",
           // NOTE: text needs to end in a newline
           // to exactly mimic the behavior of the cli
-          [](fasttext::FastText& m, const std::string& text, int32_t k) {
+          [](fasttext::FastText& m,
+             const std::string text,
+             int32_t k,
+             fasttext::real threshold) {
             std::vector<std::pair<fasttext::real, std::string>> predictions;
             std::stringstream ioss(text);
-            m.predict(ioss, k, predictions);
+            m.predict(ioss, k, predictions, threshold);
             for (auto& pair : predictions) {
               pair.first = std::exp(pair.first);
             }
@@ -268,7 +271,8 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           // to exactly mimic the behavior of the cli
           [](fasttext::FastText& m,
              const std::vector<std::string>& lines,
-             int32_t k) {
+             int32_t k,
+             fasttext::real threshold) {
             std::pair<
                 std::vector<std::vector<fasttext::real>>,
                 std::vector<std::vector<std::string>>>
@@ -277,7 +281,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
             for (const std::string& text : lines) {
               std::stringstream ioss(text);
               predictions.clear();
-              m.predict(ioss, k, predictions);
+              m.predict(ioss, k, predictions, threshold);
               all_predictions.first.push_back(std::vector<fasttext::real>());
               all_predictions.second.push_back(std::vector<std::string>());
               for (auto& pair : predictions) {

+ 20 - 9
src/fasttext.cc

@@ -366,8 +366,9 @@ void FastText::skipgram(Model& model, real lr,
 
 std::tuple<int64_t, double, double> FastText::test(
     std::istream& in,
-    int32_t k) {
-  int32_t nexamples = 0, nlabels = 0;
+    int32_t k,
+    real threshold) {
+  int32_t nexamples = 0, nlabels = 0, npredictions = 0;
   double precision = 0.0;
   std::vector<int32_t> line, labels;
 
@@ -375,7 +376,7 @@ std::tuple<int64_t, double, double> FastText::test(
     dict_->getLine(in, line, labels);
     if (labels.size() > 0 && line.size() > 0) {
       std::vector<std::pair<real, int32_t>> modelPredictions;
-      model_->predict(line, k, modelPredictions);
+      model_->predict(line, k, threshold, modelPredictions);
       for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
         if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
           precision += 1.0;
@@ -383,14 +384,19 @@ std::tuple<int64_t, double, double> FastText::test(
       }
       nexamples++;
       nlabels += labels.size();
+      npredictions += modelPredictions.size();
     }
   }
   return std::tuple<int64_t, double, double>(
-      nexamples, precision / (k * nexamples), precision / nlabels);
+      nexamples, precision / npredictions, precision / nlabels);
 }
 
-void FastText::predict(std::istream& in, int32_t k,
-                       std::vector<std::pair<real,std::string>>& predictions) const {
+void FastText::predict(
+  std::istream& in,
+  int32_t k,
+  std::vector<std::pair<real,std::string>>& predictions,
+  real threshold
+) const {
   std::vector<int32_t> words, labels;
   predictions.clear();
   dict_->getLine(in, words, labels);
@@ -399,17 +405,22 @@ void FastText::predict(std::istream& in, int32_t k,
   Vector hidden(args_->dim);
   Vector output(dict_->nlabels());
   std::vector<std::pair<real,int32_t>> modelPredictions;
-  model_->predict(words, k, modelPredictions, hidden, output);
+  model_->predict(words, k, threshold, modelPredictions, hidden, output);
   for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
     predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second)));
   }
 }
 
-void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
+void FastText::predict(
+  std::istream& in,
+  int32_t k,
+  bool print_prob,
+  real threshold
+) {
   std::vector<std::pair<real,std::string>> predictions;
   while (in.peek() != EOF) {
     predictions.clear();
-    predict(in, k, predictions);
+    predict(in, k, predictions, threshold);
     if (predictions.empty()) {
       std::cout << std::endl;
       continue;

+ 4 - 3
src/fasttext.h

@@ -93,12 +93,13 @@ class FastText {
   std::vector<int32_t> selectEmbeddings(int32_t) const;
   void getSentenceVector(std::istream&, Vector&);
   void quantize(const Args);
-  std::tuple<int64_t, double, double> test(std::istream&, int32_t);
-  void predict(std::istream&, int32_t, bool);
+  std::tuple<int64_t, double, double> test(std::istream&, int32_t, real = 0.0);
+  void predict(std::istream&, int32_t, bool, real = 0.0);
   void predict(
       std::istream&,
       int32_t,
-      std::vector<std::pair<real, std::string>>&) const;
+      std::vector<std::pair<real, std::string>>&,
+      real = 0.0) const;
   void ngramVectors(std::string);
   void precomputeWordVectors(Matrix&);
   void findNN(

+ 21 - 11
src/main.cc

@@ -43,19 +43,21 @@ void printQuantizeUsage() {
 
 void printTestUsage() {
   std::cerr
-    << "usage: fasttext test <model> <test-data> [<k>]\n\n"
+    << "usage: fasttext test <model> <test-data> [<k>] [<th>]\n\n"
     << "  <model>      model filename\n"
     << "  <test-data>  test data filename (if -, read from stdin)\n"
     << "  <k>          (optional; 1 by default) predict top k labels\n"
+    << "  <th>         (optional; 0.0 by default) probability threshold\n"
     << std::endl;
 }
 
 void printPredictUsage() {
   std::cerr
-    << "usage: fasttext predict[-prob] <model> <test-data> [<k>]\n\n"
+    << "usage: fasttext predict[-prob] <model> <test-data> [<k>] [<th>]\n\n"
     << "  <model>      model filename\n"
     << "  <test-data>  test data filename (if -, read from stdin)\n"
     << "  <k>          (optional; 1 by default) predict top k labels\n"
+    << "  <th>         (optional; 0.0 by default) probability threshold\n"
     << std::endl;
 }
 
@@ -122,13 +124,17 @@ void printDumpUsage() {
 }
 
 void test(const std::vector<std::string>& args) {
-  if (args.size() < 4 || args.size() > 5) {
+  if (args.size() < 4 || args.size() > 6) {
     printTestUsage();
     exit(EXIT_FAILURE);
   }
   int32_t k = 1;
-  if (args.size() >= 5) {
+  real threshold = 0.0;
+  if (args.size() > 4) {
     k = std::stoi(args[4]);
+    if (args.size() == 6) {
+      threshold = std::stof(args[5]);
+    }
   }
 
   FastText fasttext;
@@ -137,14 +143,14 @@ void test(const std::vector<std::string>& args) {
   std::tuple<int64_t, double, double> result;
   std::string infile = args[3];
   if (infile == "-") {
-    result = fasttext.test(std::cin, k);
+    result = fasttext.test(std::cin, k, threshold);
   } else {
     std::ifstream ifs(infile);
     if (!ifs.is_open()) {
       std::cerr << "Test file cannot be opened!" << std::endl;
       exit(EXIT_FAILURE);
     }
-    result = fasttext.test(ifs, k);
+    result = fasttext.test(ifs, k, threshold);
     ifs.close();
   }
   std::cout << "N" << "\t" << std::get<0>(result) << std::endl;
@@ -155,13 +161,17 @@ void test(const std::vector<std::string>& args) {
 }
 
 void predict(const std::vector<std::string>& args) {
-  if (args.size() < 4 || args.size() > 5) {
+  if (args.size() < 4 || args.size() > 6) {
     printPredictUsage();
     exit(EXIT_FAILURE);
   }
   int32_t k = 1;
-  if (args.size() >= 5) {
+  real threshold = 0.0;
+  if (args.size() > 4) {
     k = std::stoi(args[4]);
+    if (args.size() == 6) {
+      threshold = std::stof(args[5]);
+    }
   }
 
   bool print_prob = args[1] == "predict-prob";
@@ -170,14 +180,14 @@ void predict(const std::vector<std::string>& args) {
 
   std::string infile(args[3]);
   if (infile == "-") {
-    fasttext.predict(std::cin, k, print_prob);
+    fasttext.predict(std::cin, k, print_prob, threshold);
   } else {
     std::ifstream ifs(infile);
     if (!ifs.is_open()) {
       std::cerr << "Input file cannot be opened!" << std::endl;
       exit(EXIT_FAILURE);
     }
-    fasttext.predict(ifs, k, print_prob);
+    fasttext.predict(ifs, k, print_prob, threshold);
     ifs.close();
   }
 
@@ -346,7 +356,7 @@ int main(int argc, char** argv) {
     nn(args);
   } else if (command == "analogies") {
     analogies(args);
-  } else if (command == "predict" || command == "predict-prob" ) {
+  } else if (command == "predict" || command == "predict-prob") {
     predict(args);
   } else if (command == "dump") {
     dump(args);

+ 21 - 11
src/model.cc

@@ -142,7 +142,7 @@ bool Model::comparePairs(const std::pair<real, int32_t> &l,
   return l.first > r.first;
 }
 
-void Model::predict(const std::vector<int32_t>& input, int32_t k,
+void Model::predict(const std::vector<int32_t>& input, int32_t k, real threshold,
                     std::vector<std::pair<real, int32_t>>& heap,
                     Vector& hidden, Vector& output) const {
   if (k <= 0) {
@@ -154,22 +154,31 @@ void Model::predict(const std::vector<int32_t>& input, int32_t k,
   heap.reserve(k + 1);
   computeHidden(input, hidden);
   if (args_->loss == loss_name::hs) {
-    dfs(k, 2 * osz_ - 2, 0.0, heap, hidden);
+    dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, hidden);
   } else {
-    findKBest(k, heap, hidden, output);
+    findKBest(k, threshold, heap, hidden, output);
   }
   std::sort_heap(heap.begin(), heap.end(), comparePairs);
 }
 
-void Model::predict(const std::vector<int32_t>& input, int32_t k,
-                    std::vector<std::pair<real, int32_t>>& heap) {
-  predict(input, k, heap, hidden_, output_);
+void Model::predict(
+  const std::vector<int32_t>& input,
+  int32_t k,
+  real threshold,
+  std::vector<std::pair<real, int32_t>>& heap
+) {
+  predict(input, k, threshold, heap, hidden_, output_);
 }
 
-void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
-                      Vector& hidden, Vector& output) const {
+void Model::findKBest(
+  int32_t k,
+  real threshold,
+  std::vector<std::pair<real, int32_t>>& heap,
+  Vector& hidden, Vector& output
+) const {
   computeOutputSoftmax(hidden, output);
   for (int32_t i = 0; i < osz_; i++) {
+    if (output[i] < threshold) continue;
     if (heap.size() == k && std_log(output[i]) < heap.front().first) {
       continue;
     }
@@ -182,9 +191,10 @@ void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
   }
 }
 
-void Model::dfs(int32_t k, int32_t node, real score,
+void Model::dfs(int32_t k, real threshold, int32_t node, real score,
                 std::vector<std::pair<real, int32_t>>& heap,
                 Vector& hidden) const {
+  if (score < std_log(threshold)) return;
   if (heap.size() == k && score < heap.front().first) {
     return;
   }
@@ -207,8 +217,8 @@ void Model::dfs(int32_t k, int32_t node, real score,
   }
   f = 1. / (1 + std::exp(-f));
 
-  dfs(k, tree[node].left, score + std_log(1.0 - f), heap, hidden);
-  dfs(k, tree[node].right, score + std_log(f), heap, hidden);
+  dfs(k, threshold, tree[node].left, score + std_log(1.0 - f), heap, hidden);
+  dfs(k, threshold, tree[node].right, score + std_log(f), heap, hidden);
 }
 
 void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {

+ 4 - 4
src/model.h

@@ -72,15 +72,15 @@ class Model {
     real hierarchicalSoftmax(int32_t, real);
     real softmax(int32_t, real);
 
-    void predict(const std::vector<int32_t>&, int32_t,
+    void predict(const std::vector<int32_t>&, int32_t, real,
                  std::vector<std::pair<real, int32_t>>&,
                  Vector&, Vector&) const;
-    void predict(const std::vector<int32_t>&, int32_t,
+    void predict(const std::vector<int32_t>&, int32_t, real,
                  std::vector<std::pair<real, int32_t>>&);
-    void dfs(int32_t, int32_t, real,
+    void dfs(int32_t, real, int32_t, real,
              std::vector<std::pair<real, int32_t>>&,
              Vector&) const;
-    void findKBest(int32_t, std::vector<std::pair<real, int32_t>>&,
+    void findKBest(int32_t, real, std::vector<std::pair<real, int32_t>>&,
                    Vector&, Vector&) const;
     void update(const std::vector<int32_t>&, int32_t, real);
     void computeHidden(const std::vector<int32_t>&, Vector&) const;