Jelajahi Sumber

getNN and getAnalogies functions handle `onUnicodeError` argument

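Summary: Fix the issues from the previous pull request and refactor the duplicated prediction-to-Python-string conversion into a shared `castToPythonString` overload.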

Reviewed By: EdouardGrave

Differential Revision: D20478559

fbshipit-source-id: bc92b40257a74ee548b087740bd81af3886ab1d6
Onur Çelebi 6 tahun lalu
induk
melakukan
5a5b1e6410

+ 5 - 4
python/fasttext_module/fasttext/FastText.py

@@ -86,11 +86,12 @@ class _FastText(object):
         self.f.getSentenceVector(b, text)
         return np.array(b)
 
-    def get_nearest_neighbors(self, word, k=10):
-        return self.f.getNN(word, k)
+    def get_nearest_neighbors(self, word, k=10, on_unicode_error='strict'):
+        return self.f.getNN(word, k, on_unicode_error)
 
-    def get_analogies(self, wordA, wordB, wordC, k=10):
-        return self.f.getAnalogies(wordA, wordB, wordC, k)
+    def get_analogies(self, wordA, wordB, wordC, k=10,
+                      on_unicode_error='strict'):
+        return self.f.getAnalogies(wordA, wordB, wordC, k, on_unicode_error)
 
     def get_word_id(self, word):
         """

+ 24 - 24
python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

@@ -44,6 +44,20 @@ py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
   return handle_str;
 }
 
+std::vector<std::pair<fasttext::real, py::str>> castToPythonString(
+    const std::vector<std::pair<fasttext::real, std::string>>& predictions,
+    const char* onUnicodeError) {
+  std::vector<std::pair<fasttext::real, py::str>> transformedPredictions;
+
+  for (const auto& prediction : predictions) {
+    transformedPredictions.emplace_back(
+        prediction.first,
+        castToPythonString(prediction.second, onUnicodeError));
+  }
+
+  return transformedPredictions;
+}
+
 std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
     fasttext::FastText& m,
     const std::string text,
@@ -339,16 +353,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
             std::vector<std::pair<fasttext::real, std::string>> predictions;
             m.predictLine(ioss, predictions, k, threshold);
 
-            std::vector<std::pair<fasttext::real, py::str>>
-                transformedPredictions;
-
-            for (const auto& prediction : predictions) {
-              transformedPredictions.push_back(std::make_pair(
-                  prediction.first,
-                  castToPythonString(prediction.second, onUnicodeError)));
-            }
-
-            return transformedPredictions;
+            return castToPythonString(predictions, onUnicodeError);
           })
       .def(
           "multilinePredict",
@@ -427,20 +432,11 @@ PYBIND11_MODULE(fasttext_pybind, m) {
              const std::string word) { m.getWordVector(vec, word); })
       .def(
           "getNN",
-          [](fasttext::FastText& m, const std::string& word, int32_t k,
+          [](fasttext::FastText& m,
+             const std::string& word,
+             int32_t k,
              const char* onUnicodeError) {
-            std::vector<std::pair<float, std::string>> score_words = m.getNN(
-                word, k);
-            std::vector<std::pair<float, py::str>> output_list;
-            for (uint32_t i = 0; i < score_words.size(); i++) {
-               float score = score_words[i].first;
-               py::str word = castToPythonString(
-                   score_words[i].second, onUnicodeError);
-               std::pair<float, py::str> sw_pair = std::make_pair(score, word);
-               output_list.push_back(sw_pair);
-            }
-
-            return output_list;
+            return castToPythonString(m.getNN(word, k), onUnicodeError);
           })
       .def(
           "getAnalogies",
@@ -448,7 +444,11 @@ PYBIND11_MODULE(fasttext_pybind, m) {
              const std::string& wordA,
              const std::string& wordB,
              const std::string& wordC,
-             int32_t k) { return m.getAnalogies(k, wordA, wordB, wordC); })
+             int32_t k,
+             const char* onUnicodeError) {
+            return castToPythonString(
+                m.getAnalogies(k, wordA, wordB, wordC), onUnicodeError);
+          })
       .def(
           "getSubwords",
           [](fasttext::FastText& m,

+ 1 - 2
src/autotune.cc

@@ -406,8 +406,7 @@ void Autotune::train(const Args& autotuneArgs) {
             autotuneArgs.getAutotuneMetric(),
             autotuneArgs.getAutotuneMetricLabel());
 
-        if (bestScore_ == kUnknownBestScore ||
-            (currentScore > bestScore_)) {
+        if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
           bestTrainArgs = trainArgs;
           bestScore_ = currentScore;
           strategy_->updateBest(bestTrainArgs);

+ 0 - 1
src/autotune.h

@@ -73,7 +73,6 @@ class Autotune {
     TimeoutError() : std::runtime_error("Autotune timed out.") {}
   };
 
-
  public:
   Autotune() = delete;
   explicit Autotune(const std::shared_ptr<FastText>& fastText);