Bladeren bron

adding unicode error handling on the python interface

Summary:
The issue was reported here : https://github.com/facebookresearch/fastText/issues/715
Now, we can replace the line :
```
words = f.get_words()
```
by
```
words = f.get_words(on_unicode_error='replace')
```
in bin_to_vec.py

The behaviour is similar to python's `decode` function : if there is an encoding issue, `strict`: it fails with an error, `replace`: replaces silently the malformed characters by the replacement character.

Reviewed By: EdouardGrave

Differential Revision: D14133996

fbshipit-source-id: 9c82fef69b6d5223e4e5d60516a53467d8786ffc
Onur Çelebi 7 jaren geleden
bovenliggende
commit
e13484bcb2
2 gewijzigde bestanden met toevoegingen van 83 en 49 verwijderingen
  1. 12 12
      python/fastText/FastText.py
  2. 71 37
      python/fastText/pybind/fasttext_pybind.cc

+ 12 - 12
python/fastText/FastText.py

@@ -81,11 +81,11 @@ class _FastText():
         """
         return self.f.getSubwordId(subword)
 
-    def get_subwords(self, word):
+    def get_subwords(self, word, on_unicode_error='strict'):
         """
         Given a word, get the subwords and their indicies.
         """
-        pair = self.f.getSubwords(word)
+        pair = self.f.getSubwords(word, on_unicode_error)
         return pair[0], np.array(pair[1])
 
     def get_input_vector(self, ind):
@@ -97,7 +97,7 @@ class _FastText():
         self.f.getInputVector(b, ind)
         return np.array(b)
 
-    def predict(self, text, k=1, threshold=0.0):
+    def predict(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
         """
         Given a string, get a list of labels and a list of
         corresponding probabilities. k controls the number
@@ -130,14 +130,14 @@ class _FastText():
 
         if type(text) == list:
             text = [check(entry) for entry in text]
-            predictions = self.f.multilinePredict(text, k, threshold)
+            predictions = self.f.multilinePredict(text, k, threshold, on_unicode_error)
             dt = np.dtype([('probability', 'float64'), ('label', 'object')])
             result_as_pair = np.array(predictions, dtype=dt)
 
             return result_as_pair['label'].tolist(), result_as_pair['probability']
         else:
             text = check(text)
-            predictions = self.f.predict(text, k, threshold)
+            predictions = self.f.predict(text, k, threshold, on_unicode_error)
             probs, labels = zip(*predictions)
 
             return labels, np.array(probs, copy=False)
@@ -160,20 +160,20 @@ class _FastText():
             raise ValueError("Can't get quantized Matrix")
         return np.array(self.f.getOutputMatrix())
 
-    def get_words(self, include_freq=False):
+    def get_words(self, include_freq=False, on_unicode_error='strict'):
         """
         Get the entire list of words of the dictionary optionally
         including the frequency of the individual words. This
         does not include any subwords. For that please consult
         the function get_subwords.
         """
-        pair = self.f.getVocab()
+        pair = self.f.getVocab(on_unicode_error)
         if include_freq:
             return (pair[0], np.array(pair[1]))
         else:
             return pair[0]
 
-    def get_labels(self, include_freq=False):
+    def get_labels(self, include_freq=False, on_unicode_error='strict'):
         """
         Get the entire list of labels of the dictionary optionally
         including the frequency of the individual labels. Unsupervised
@@ -183,7 +183,7 @@ class _FastText():
         """
         a = self.f.getArgs()
         if a.model == model_name.supervised:
-            pair = self.f.getLabels()
+            pair = self.f.getLabels(on_unicode_error)
             if include_freq:
                 return (pair[0], np.array(pair[1]))
             else:
@@ -191,7 +191,7 @@ class _FastText():
         else:
             return self.get_words(include_freq)
 
-    def get_line(self, text):
+    def get_line(self, text, on_unicode_error='strict'):
         """
         Split a line of text into words and labels. Labels must start with
         the prefix used to create the model (__label__ by default).
@@ -207,10 +207,10 @@ class _FastText():
 
         if type(text) == list:
             text = [check(entry) for entry in text]
-            return self.f.multilineGetLine(text)
+            return self.f.multilineGetLine(text, on_unicode_error)
         else:
             text = check(text)
-            return self.f.getLine(text)
+            return self.f.getLine(text, on_unicode_error)
 
     def save_model(self, path):
         """Save the model to the given path"""

+ 71 - 37
python/fastText/pybind/fasttext_pybind.cc

@@ -19,35 +19,42 @@
 #include <stdexcept>
 
 using namespace pybind11::literals;
+namespace py = pybind11;
+
+py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
+  PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
+  if (!handle) {
+    throw py::error_already_set();
+  }
+  return py::str(handle);
+}
 
-std::pair<std::vector<std::string>, std::vector<std::string>> getLineText(
+std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
     fasttext::FastText& m,
-    const std::string text) {
+    const std::string text,
+    const char* onUnicodeError) {
   std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
   std::stringstream ioss(text);
   std::string token;
-  std::vector<std::string> words;
-  std::vector<std::string> labels;
+  std::vector<py::str> words;
+  std::vector<py::str> labels;
   while (d->readWord(ioss, token)) {
     uint32_t h = d->hash(token);
     int32_t wid = d->getId(token, h);
     fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
 
     if (type == fasttext::entry_type::word) {
-      words.push_back(token);
+      words.push_back(castToPythonString(token, onUnicodeError));
       // Labels must not be OOV!
     } else if (type == fasttext::entry_type::label && wid >= 0) {
-      labels.push_back(token);
+      labels.push_back(castToPythonString(token, onUnicodeError));
     }
     if (token == fasttext::Dictionary::EOS)
       break;
   }
-  return std::pair<std::vector<std::string>, std::vector<std::string>>(
-      words, labels);
+  return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
 }
 
-namespace py = pybind11;
-
 PYBIND11_MODULE(fasttext_pybind, m) {
   py::class_<fasttext::Args>(m, "args")
       .def(py::init<>())
@@ -184,48 +191,48 @@ PYBIND11_MODULE(fasttext_pybind, m) {
       .def("getLine", &getLineText)
       .def(
           "multilineGetLine",
-          [](fasttext::FastText& m, const std::vector<std::string> lines) {
+          [](fasttext::FastText& m,
+             const std::vector<std::string> lines,
+             const char* onUnicodeError) {
             std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
-            std::vector<std::vector<std::string>> all_words;
-            std::vector<std::vector<std::string>> all_labels;
-            std::vector<std::string> words;
-            std::vector<std::string> labels;
-            std::string token;
+            std::vector<std::vector<py::str>> all_words;
+            std::vector<std::vector<py::str>> all_labels;
             for (const auto& text : lines) {
-              auto pair = getLineText(m, text);
+              auto pair = getLineText(m, text, onUnicodeError);
               all_words.push_back(pair.first);
               all_labels.push_back(pair.second);
             }
             return std::pair<
-                std::vector<std::vector<std::string>>,
-                std::vector<std::vector<std::string>>>(all_words, all_labels);
+                std::vector<std::vector<py::str>>,
+                std::vector<std::vector<py::str>>>(all_words, all_labels);
           })
       .def(
           "getVocab",
-          [](fasttext::FastText& m) {
-            std::vector<std::string> vocab_list;
+          [](fasttext::FastText& m, const char* onUnicodeError) {
+            py::str s;
+            std::vector<py::str> vocab_list;
             std::vector<int64_t> vocab_freq;
             std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
             vocab_freq = d->getCounts(fasttext::entry_type::word);
-            vocab_list.clear();
             for (int32_t i = 0; i < vocab_freq.size(); i++) {
-              vocab_list.push_back(d->getWord(i));
+              vocab_list.push_back(
+                  castToPythonString(d->getWord(i), onUnicodeError));
             }
-            return std::pair<std::vector<std::string>, std::vector<int64_t>>(
+            return std::pair<std::vector<py::str>, std::vector<int64_t>>(
                 vocab_list, vocab_freq);
           })
       .def(
           "getLabels",
-          [](fasttext::FastText& m) {
-            std::vector<std::string> labels_list;
+          [](fasttext::FastText& m, const char* onUnicodeError) {
+            std::vector<py::str> labels_list;
             std::vector<int64_t> labels_freq;
             std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
             labels_freq = d->getCounts(fasttext::entry_type::label);
-            labels_list.clear();
             for (int32_t i = 0; i < labels_freq.size(); i++) {
-              labels_list.push_back(d->getLabel(i));
+              labels_list.push_back(
+                  castToPythonString(d->getLabel(i), onUnicodeError));
             }
-            return std::pair<std::vector<std::string>, std::vector<int64_t>>(
+            return std::pair<std::vector<py::str>, std::vector<int64_t>>(
                 labels_list, labels_freq);
           })
       .def(
@@ -261,12 +268,22 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m,
              const std::string text,
              int32_t k,
-             fasttext::real threshold) {
+             fasttext::real threshold,
+             const char* onUnicodeError) {
             std::stringstream ioss(text);
             std::vector<std::pair<fasttext::real, std::string>> predictions;
             m.predictLine(ioss, predictions, k, threshold);
 
-            return predictions;
+            std::vector<std::pair<fasttext::real, py::str>>
+                transformedPredictions;
+
+            for (const auto& prediction : predictions) {
+              transformedPredictions.push_back(std::make_pair(
+                  prediction.first,
+                  castToPythonString(prediction.second, onUnicodeError)));
+            }
+
+            return transformedPredictions;
           })
       .def(
           "multilinePredict",
@@ -275,15 +292,23 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m,
              const std::vector<std::string>& lines,
              int32_t k,
-             fasttext::real threshold) {
-            std::vector<std::vector<std::pair<fasttext::real, std::string>>>
+             fasttext::real threshold,
+             const char* onUnicodeError) {
+            std::vector<std::vector<std::pair<fasttext::real, py::str>>>
                 allPredictions;
             std::vector<std::pair<fasttext::real, std::string>> predictions;
 
             for (const std::string& text : lines) {
               std::stringstream ioss(text);
               m.predictLine(ioss, predictions, k, threshold);
-              allPredictions.push_back(predictions);
+              std::vector<std::pair<fasttext::real, py::str>>
+                  transformedPredictions;
+              for (const auto& prediction : predictions) {
+                transformedPredictions.push_back(std::make_pair(
+                    prediction.first,
+                    castToPythonString(prediction.second, onUnicodeError)));
+              }
+              allPredictions.push_back(transformedPredictions);
             }
             return allPredictions;
           })
@@ -332,13 +357,22 @@ PYBIND11_MODULE(fasttext_pybind, m) {
              const std::string word) { m.getWordVector(vec, word); })
       .def(
           "getSubwords",
-          [](fasttext::FastText& m, const std::string word) {
+          [](fasttext::FastText& m,
+             const std::string word,
+             const char* onUnicodeError) {
             std::vector<std::string> subwords;
             std::vector<int32_t> ngrams;
             std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
             d->getSubwords(word, ngrams, subwords);
-            return std::pair<std::vector<std::string>, std::vector<int32_t>>(
-                subwords, ngrams);
+            std::vector<py::str> transformedSubwords;
+
+            for (const auto& subword : subwords) {
+              transformedSubwords.push_back(
+                  castToPythonString(subword, onUnicodeError));
+            }
+
+            return std::pair<std::vector<py::str>, std::vector<int32_t>>(
+                transformedSubwords, ngrams);
           })
       .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
 }