Prechádzať zdrojové kódy

added threshold arg to test function (#868)

Summary:
Need threshold arg in test function, in python binding
Pull Request resolved: https://github.com/facebookresearch/fastText/pull/868

Reviewed By: EdouardGrave

Differential Revision: D17602650

Pulled By: Celebio

fbshipit-source-id: 4ace2401aabbe7fdcdb4c6856098418642c46420
Raghav Jajodia 6 rokov pred
rodič
commit
022c1a7737

+ 2 - 2
python/fasttext_module/fasttext/FastText.py

@@ -242,9 +242,9 @@ class _FastText(object):
         """Save the model to the given path"""
         self.f.saveModel(path)
 
-    def test(self, path, k=1):
+    def test(self, path, k=1, threshold=0.0):
         """Evaluate supervised model using file given by path"""
-        return self.f.test(path, k)
+        return self.f.test(path, k, threshold)
 
     def test_label(self, path, k=1, threshold=0.0):
         """

+ 5 - 2
python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

@@ -217,13 +217,16 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
       .def(
           "test",
-          [](fasttext::FastText& m, const std::string filename, int32_t k) {
+          [](fasttext::FastText& m, 
+            const std::string filename, 
+            int32_t k,
+            fasttext::real threshold) {
             std::ifstream ifs(filename);
             if (!ifs.is_open()) {
               throw std::invalid_argument("Test file cannot be opened!");
             }
             fasttext::Meter meter;
-            m.test(ifs, k, 0.0, meter);
+            m.test(ifs, k, threshold, meter);
             ifs.close();
             return std::tuple<int64_t, double, double>(
                 meter.nexamples(), meter.precision(), meter.recall());