|
|
@@ -19,35 +19,42 @@
|
|
|
#include <stdexcept>
|
|
|
|
|
|
using namespace pybind11::literals;
|
|
|
+namespace py = pybind11;
|
|
|
+
|
|
|
+py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
|
|
|
+ PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
|
|
|
+ if (!handle) {
|
|
|
+ throw py::error_already_set();
|
|
|
+ }
|
|
|
+ return py::str(handle);
|
|
|
+}
|
|
|
|
|
|
-std::pair<std::vector<std::string>, std::vector<std::string>> getLineText(
|
|
|
+std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
|
|
|
fasttext::FastText& m,
|
|
|
- const std::string text) {
|
|
|
+ const std::string text,
|
|
|
+ const char* onUnicodeError) {
|
|
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
|
|
std::stringstream ioss(text);
|
|
|
std::string token;
|
|
|
- std::vector<std::string> words;
|
|
|
- std::vector<std::string> labels;
|
|
|
+ std::vector<py::str> words;
|
|
|
+ std::vector<py::str> labels;
|
|
|
while (d->readWord(ioss, token)) {
|
|
|
uint32_t h = d->hash(token);
|
|
|
int32_t wid = d->getId(token, h);
|
|
|
fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
|
|
|
|
|
|
if (type == fasttext::entry_type::word) {
|
|
|
- words.push_back(token);
|
|
|
+ words.push_back(castToPythonString(token, onUnicodeError));
|
|
|
// Labels must not be OOV!
|
|
|
} else if (type == fasttext::entry_type::label && wid >= 0) {
|
|
|
- labels.push_back(token);
|
|
|
+ labels.push_back(castToPythonString(token, onUnicodeError));
|
|
|
}
|
|
|
if (token == fasttext::Dictionary::EOS)
|
|
|
break;
|
|
|
}
|
|
|
- return std::pair<std::vector<std::string>, std::vector<std::string>>(
|
|
|
- words, labels);
|
|
|
+ return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
|
|
|
}
|
|
|
|
|
|
-namespace py = pybind11;
|
|
|
-
|
|
|
PYBIND11_MODULE(fasttext_pybind, m) {
|
|
|
py::class_<fasttext::Args>(m, "args")
|
|
|
.def(py::init<>())
|
|
|
@@ -184,48 +191,48 @@ PYBIND11_MODULE(fasttext_pybind, m) {
|
|
|
.def("getLine", &getLineText)
|
|
|
.def(
|
|
|
"multilineGetLine",
|
|
|
- [](fasttext::FastText& m, const std::vector<std::string> lines) {
|
|
|
+ [](fasttext::FastText& m,
|
|
|
+ const std::vector<std::string> lines,
|
|
|
+ const char* onUnicodeError) {
|
|
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
|
|
- std::vector<std::vector<std::string>> all_words;
|
|
|
- std::vector<std::vector<std::string>> all_labels;
|
|
|
- std::vector<std::string> words;
|
|
|
- std::vector<std::string> labels;
|
|
|
- std::string token;
|
|
|
+ std::vector<std::vector<py::str>> all_words;
|
|
|
+ std::vector<std::vector<py::str>> all_labels;
|
|
|
for (const auto& text : lines) {
|
|
|
- auto pair = getLineText(m, text);
|
|
|
+ auto pair = getLineText(m, text, onUnicodeError);
|
|
|
all_words.push_back(pair.first);
|
|
|
all_labels.push_back(pair.second);
|
|
|
}
|
|
|
return std::pair<
|
|
|
- std::vector<std::vector<std::string>>,
|
|
|
- std::vector<std::vector<std::string>>>(all_words, all_labels);
|
|
|
+ std::vector<std::vector<py::str>>,
|
|
|
+ std::vector<std::vector<py::str>>>(all_words, all_labels);
|
|
|
})
|
|
|
.def(
|
|
|
"getVocab",
|
|
|
- [](fasttext::FastText& m) {
|
|
|
- std::vector<std::string> vocab_list;
|
|
|
+ [](fasttext::FastText& m, const char* onUnicodeError) {
|
|
|
+ py::str s;
|
|
|
+ std::vector<py::str> vocab_list;
|
|
|
std::vector<int64_t> vocab_freq;
|
|
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
|
|
vocab_freq = d->getCounts(fasttext::entry_type::word);
|
|
|
- vocab_list.clear();
|
|
|
for (int32_t i = 0; i < vocab_freq.size(); i++) {
|
|
|
- vocab_list.push_back(d->getWord(i));
|
|
|
+ vocab_list.push_back(
|
|
|
+ castToPythonString(d->getWord(i), onUnicodeError));
|
|
|
}
|
|
|
- return std::pair<std::vector<std::string>, std::vector<int64_t>>(
|
|
|
+ return std::pair<std::vector<py::str>, std::vector<int64_t>>(
|
|
|
vocab_list, vocab_freq);
|
|
|
})
|
|
|
.def(
|
|
|
"getLabels",
|
|
|
- [](fasttext::FastText& m) {
|
|
|
- std::vector<std::string> labels_list;
|
|
|
+ [](fasttext::FastText& m, const char* onUnicodeError) {
|
|
|
+ std::vector<py::str> labels_list;
|
|
|
std::vector<int64_t> labels_freq;
|
|
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
|
|
labels_freq = d->getCounts(fasttext::entry_type::label);
|
|
|
- labels_list.clear();
|
|
|
for (int32_t i = 0; i < labels_freq.size(); i++) {
|
|
|
- labels_list.push_back(d->getLabel(i));
|
|
|
+ labels_list.push_back(
|
|
|
+ castToPythonString(d->getLabel(i), onUnicodeError));
|
|
|
}
|
|
|
- return std::pair<std::vector<std::string>, std::vector<int64_t>>(
|
|
|
+ return std::pair<std::vector<py::str>, std::vector<int64_t>>(
|
|
|
labels_list, labels_freq);
|
|
|
})
|
|
|
.def(
|
|
|
@@ -261,12 +268,22 @@ PYBIND11_MODULE(fasttext_pybind, m) {
|
|
|
[](fasttext::FastText& m,
|
|
|
const std::string text,
|
|
|
int32_t k,
|
|
|
- fasttext::real threshold) {
|
|
|
+ fasttext::real threshold,
|
|
|
+ const char* onUnicodeError) {
|
|
|
std::stringstream ioss(text);
|
|
|
std::vector<std::pair<fasttext::real, std::string>> predictions;
|
|
|
m.predictLine(ioss, predictions, k, threshold);
|
|
|
|
|
|
- return predictions;
|
|
|
+ std::vector<std::pair<fasttext::real, py::str>>
|
|
|
+ transformedPredictions;
|
|
|
+
|
|
|
+ for (const auto& prediction : predictions) {
|
|
|
+ transformedPredictions.push_back(std::make_pair(
|
|
|
+ prediction.first,
|
|
|
+ castToPythonString(prediction.second, onUnicodeError)));
|
|
|
+ }
|
|
|
+
|
|
|
+ return transformedPredictions;
|
|
|
})
|
|
|
.def(
|
|
|
"multilinePredict",
|
|
|
@@ -275,15 +292,23 @@ PYBIND11_MODULE(fasttext_pybind, m) {
|
|
|
[](fasttext::FastText& m,
|
|
|
const std::vector<std::string>& lines,
|
|
|
int32_t k,
|
|
|
- fasttext::real threshold) {
|
|
|
- std::vector<std::vector<std::pair<fasttext::real, std::string>>>
|
|
|
+ fasttext::real threshold,
|
|
|
+ const char* onUnicodeError) {
|
|
|
+ std::vector<std::vector<std::pair<fasttext::real, py::str>>>
|
|
|
allPredictions;
|
|
|
std::vector<std::pair<fasttext::real, std::string>> predictions;
|
|
|
|
|
|
for (const std::string& text : lines) {
|
|
|
std::stringstream ioss(text);
|
|
|
m.predictLine(ioss, predictions, k, threshold);
|
|
|
- allPredictions.push_back(predictions);
|
|
|
+ std::vector<std::pair<fasttext::real, py::str>>
|
|
|
+ transformedPredictions;
|
|
|
+ for (const auto& prediction : predictions) {
|
|
|
+ transformedPredictions.push_back(std::make_pair(
|
|
|
+ prediction.first,
|
|
|
+ castToPythonString(prediction.second, onUnicodeError)));
|
|
|
+ }
|
|
|
+ allPredictions.push_back(transformedPredictions);
|
|
|
}
|
|
|
return allPredictions;
|
|
|
})
|
|
|
@@ -332,13 +357,22 @@ PYBIND11_MODULE(fasttext_pybind, m) {
|
|
|
const std::string word) { m.getWordVector(vec, word); })
|
|
|
.def(
|
|
|
"getSubwords",
|
|
|
- [](fasttext::FastText& m, const std::string word) {
|
|
|
+ [](fasttext::FastText& m,
|
|
|
+ const std::string word,
|
|
|
+ const char* onUnicodeError) {
|
|
|
std::vector<std::string> subwords;
|
|
|
std::vector<int32_t> ngrams;
|
|
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
|
|
d->getSubwords(word, ngrams, subwords);
|
|
|
- return std::pair<std::vector<std::string>, std::vector<int32_t>>(
|
|
|
- subwords, ngrams);
|
|
|
+ std::vector<py::str> transformedSubwords;
|
|
|
+
|
|
|
+ for (const auto& subword : subwords) {
|
|
|
+ transformedSubwords.push_back(
|
|
|
+ castToPythonString(subword, onUnicodeError));
|
|
|
+ }
|
|
|
+
|
|
|
+ return std::pair<std::vector<py::str>, std::vector<int32_t>>(
|
|
|
+ transformedSubwords, ngrams);
|
|
|
})
|
|
|
.def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
|
|
|
}
|