fasttext_pybind.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  #include <args.h>
  #include <densematrix.h>
  #include <fasttext.h>
  #include <pybind11/pybind11.h>
  #include <pybind11/stl.h>
  #include <real.h>
  #include <vector.h>

  #include <cmath>
  #include <fstream>
  #include <iterator>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  using namespace pybind11::literals;
  namespace py = pybind11;
  21. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  22. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  23. if (!handle) {
  24. throw py::error_already_set();
  25. }
  26. // py::str's constructor from a PyObject assumes the string has been encoded
  27. // for python 2 and not encoded for python 3 :
  28. // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930
  29. #if PY_MAJOR_VERSION < 3
  30. handle = PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError);
  31. #endif
  32. return py::str(handle);
  33. }
  34. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  35. fasttext::FastText& m,
  36. const std::string text,
  37. const char* onUnicodeError) {
  38. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  39. std::stringstream ioss(text);
  40. std::string token;
  41. std::vector<py::str> words;
  42. std::vector<py::str> labels;
  43. while (d->readWord(ioss, token)) {
  44. uint32_t h = d->hash(token);
  45. int32_t wid = d->getId(token, h);
  46. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  47. if (type == fasttext::entry_type::word) {
  48. words.push_back(castToPythonString(token, onUnicodeError));
  49. // Labels must not be OOV!
  50. } else if (type == fasttext::entry_type::label && wid >= 0) {
  51. labels.push_back(castToPythonString(token, onUnicodeError));
  52. }
  53. if (token == fasttext::Dictionary::EOS)
  54. break;
  55. }
  56. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  57. }
  58. PYBIND11_MODULE(fasttext_pybind, m) {
  59. py::class_<fasttext::Args>(m, "args")
  60. .def(py::init<>())
  61. .def_readwrite("input", &fasttext::Args::input)
  62. .def_readwrite("output", &fasttext::Args::output)
  63. .def_readwrite("lr", &fasttext::Args::lr)
  64. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  65. .def_readwrite("dim", &fasttext::Args::dim)
  66. .def_readwrite("ws", &fasttext::Args::ws)
  67. .def_readwrite("epoch", &fasttext::Args::epoch)
  68. .def_readwrite("minCount", &fasttext::Args::minCount)
  69. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  70. .def_readwrite("neg", &fasttext::Args::neg)
  71. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  72. .def_readwrite("loss", &fasttext::Args::loss)
  73. .def_readwrite("model", &fasttext::Args::model)
  74. .def_readwrite("bucket", &fasttext::Args::bucket)
  75. .def_readwrite("minn", &fasttext::Args::minn)
  76. .def_readwrite("maxn", &fasttext::Args::maxn)
  77. .def_readwrite("thread", &fasttext::Args::thread)
  78. .def_readwrite("t", &fasttext::Args::t)
  79. .def_readwrite("label", &fasttext::Args::label)
  80. .def_readwrite("verbose", &fasttext::Args::verbose)
  81. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  82. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  83. .def_readwrite("qout", &fasttext::Args::qout)
  84. .def_readwrite("retrain", &fasttext::Args::retrain)
  85. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  86. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  87. .def_readwrite("dsub", &fasttext::Args::dsub);
  88. py::enum_<fasttext::model_name>(m, "model_name")
  89. .value("cbow", fasttext::model_name::cbow)
  90. .value("skipgram", fasttext::model_name::sg)
  91. .value("supervised", fasttext::model_name::sup)
  92. .export_values();
  93. py::enum_<fasttext::loss_name>(m, "loss_name")
  94. .value("hs", fasttext::loss_name::hs)
  95. .value("ns", fasttext::loss_name::ns)
  96. .value("softmax", fasttext::loss_name::softmax)
  97. .value("ova", fasttext::loss_name::ova)
  98. .export_values();
  99. m.def(
  100. "train",
  101. [](fasttext::FastText& ft, fasttext::Args& a) { ft.train(a); },
  102. py::call_guard<py::gil_scoped_release>());
  103. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  104. .def(py::init<ssize_t>())
  105. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  106. return py::buffer_info(
  107. m.data(),
  108. sizeof(fasttext::real),
  109. py::format_descriptor<fasttext::real>::format(),
  110. 1,
  111. {m.size()},
  112. {sizeof(fasttext::real)});
  113. });
  114. py::class_<fasttext::DenseMatrix>(
  115. m, "DenseMatrix", py::buffer_protocol(), py::module_local())
  116. .def(py::init<>())
  117. .def(py::init<ssize_t, ssize_t>())
  118. .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
  119. return py::buffer_info(
  120. m.data(),
  121. sizeof(fasttext::real),
  122. py::format_descriptor<fasttext::real>::format(),
  123. 2,
  124. {m.size(0), m.size(1)},
  125. {sizeof(fasttext::real) * m.size(1),
  126. sizeof(fasttext::real) * (int64_t)1});
  127. });
  128. py::class_<fasttext::FastText>(m, "fasttext")
  129. .def(py::init<>())
  130. .def("getArgs", &fasttext::FastText::getArgs)
  131. .def(
  132. "getInputMatrix",
  133. [](fasttext::FastText& m) {
  134. std::shared_ptr<const fasttext::DenseMatrix> mm =
  135. m.getInputMatrix();
  136. return *mm.get();
  137. })
  138. .def(
  139. "getOutputMatrix",
  140. [](fasttext::FastText& m) {
  141. std::shared_ptr<const fasttext::DenseMatrix> mm =
  142. m.getOutputMatrix();
  143. return *mm.get();
  144. })
  145. .def(
  146. "loadModel",
  147. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  148. .def(
  149. "saveModel",
  150. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  151. .def(
  152. "test",
  153. [](fasttext::FastText& m, const std::string filename, int32_t k) {
  154. std::ifstream ifs(filename);
  155. if (!ifs.is_open()) {
  156. throw std::invalid_argument("Test file cannot be opened!");
  157. }
  158. fasttext::Meter meter;
  159. m.test(ifs, k, 0.0, meter);
  160. ifs.close();
  161. return std::tuple<int64_t, double, double>(
  162. meter.nexamples(), meter.precision(), meter.recall());
  163. })
  164. .def(
  165. "getSentenceVector",
  166. [](fasttext::FastText& m,
  167. fasttext::Vector& v,
  168. const std::string text) {
  169. std::stringstream ioss(text);
  170. m.getSentenceVector(ioss, v);
  171. })
  172. .def(
  173. "tokenize",
  174. [](fasttext::FastText& m, const std::string text) {
  175. std::vector<std::string> text_split;
  176. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  177. std::stringstream ioss(text);
  178. std::string token;
  179. while (!ioss.eof()) {
  180. while (d->readWord(ioss, token)) {
  181. text_split.push_back(token);
  182. }
  183. }
  184. return text_split;
  185. })
  186. .def("getLine", &getLineText)
  187. .def(
  188. "multilineGetLine",
  189. [](fasttext::FastText& m,
  190. const std::vector<std::string> lines,
  191. const char* onUnicodeError) {
  192. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  193. std::vector<std::vector<py::str>> all_words;
  194. std::vector<std::vector<py::str>> all_labels;
  195. for (const auto& text : lines) {
  196. auto pair = getLineText(m, text, onUnicodeError);
  197. all_words.push_back(pair.first);
  198. all_labels.push_back(pair.second);
  199. }
  200. return std::pair<
  201. std::vector<std::vector<py::str>>,
  202. std::vector<std::vector<py::str>>>(all_words, all_labels);
  203. })
  204. .def(
  205. "getVocab",
  206. [](fasttext::FastText& m, const char* onUnicodeError) {
  207. py::str s;
  208. std::vector<py::str> vocab_list;
  209. std::vector<int64_t> vocab_freq;
  210. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  211. vocab_freq = d->getCounts(fasttext::entry_type::word);
  212. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  213. vocab_list.push_back(
  214. castToPythonString(d->getWord(i), onUnicodeError));
  215. }
  216. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  217. vocab_list, vocab_freq);
  218. })
  219. .def(
  220. "getLabels",
  221. [](fasttext::FastText& m, const char* onUnicodeError) {
  222. std::vector<py::str> labels_list;
  223. std::vector<int64_t> labels_freq;
  224. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  225. labels_freq = d->getCounts(fasttext::entry_type::label);
  226. for (int32_t i = 0; i < labels_freq.size(); i++) {
  227. labels_list.push_back(
  228. castToPythonString(d->getLabel(i), onUnicodeError));
  229. }
  230. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  231. labels_list, labels_freq);
  232. })
  233. .def(
  234. "quantize",
  235. [](fasttext::FastText& m,
  236. const std::string input,
  237. bool qout,
  238. int32_t cutoff,
  239. bool retrain,
  240. int epoch,
  241. double lr,
  242. int thread,
  243. int verbose,
  244. int32_t dsub,
  245. bool qnorm) {
  246. fasttext::Args qa = fasttext::Args();
  247. qa.input = input;
  248. qa.qout = qout;
  249. qa.cutoff = cutoff;
  250. qa.retrain = retrain;
  251. qa.epoch = epoch;
  252. qa.lr = lr;
  253. qa.thread = thread;
  254. qa.verbose = verbose;
  255. qa.dsub = dsub;
  256. qa.qnorm = qnorm;
  257. m.quantize(qa);
  258. })
  259. .def(
  260. "predict",
  261. // NOTE: text needs to end in a newline
  262. // to exactly mimic the behavior of the cli
  263. [](fasttext::FastText& m,
  264. const std::string text,
  265. int32_t k,
  266. fasttext::real threshold,
  267. const char* onUnicodeError) {
  268. std::stringstream ioss(text);
  269. std::vector<std::pair<fasttext::real, std::string>> predictions;
  270. m.predictLine(ioss, predictions, k, threshold);
  271. std::vector<std::pair<fasttext::real, py::str>>
  272. transformedPredictions;
  273. for (const auto& prediction : predictions) {
  274. transformedPredictions.push_back(std::make_pair(
  275. prediction.first,
  276. castToPythonString(prediction.second, onUnicodeError)));
  277. }
  278. return transformedPredictions;
  279. })
  280. .def(
  281. "multilinePredict",
  282. // NOTE: text needs to end in a newline
  283. // to exactly mimic the behavior of the cli
  284. [](fasttext::FastText& m,
  285. const std::vector<std::string>& lines,
  286. int32_t k,
  287. fasttext::real threshold,
  288. const char* onUnicodeError) {
  289. std::vector<std::vector<std::pair<fasttext::real, py::str>>>
  290. allPredictions;
  291. std::vector<std::pair<fasttext::real, std::string>> predictions;
  292. for (const std::string& text : lines) {
  293. std::stringstream ioss(text);
  294. m.predictLine(ioss, predictions, k, threshold);
  295. std::vector<std::pair<fasttext::real, py::str>>
  296. transformedPredictions;
  297. for (const auto& prediction : predictions) {
  298. transformedPredictions.push_back(std::make_pair(
  299. prediction.first,
  300. castToPythonString(prediction.second, onUnicodeError)));
  301. }
  302. allPredictions.push_back(transformedPredictions);
  303. }
  304. return allPredictions;
  305. })
  306. .def(
  307. "testLabel",
  308. [](fasttext::FastText& m,
  309. const std::string filename,
  310. int32_t k,
  311. fasttext::real threshold) {
  312. std::ifstream ifs(filename);
  313. if (!ifs.is_open()) {
  314. throw std::invalid_argument("Test file cannot be opened!");
  315. }
  316. fasttext::Meter meter;
  317. m.test(ifs, k, threshold, meter);
  318. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  319. std::unordered_map<std::string, py::dict> returnedValue;
  320. for (int32_t i = 0; i < d->nlabels(); i++) {
  321. returnedValue[d->getLabel(i)] = py::dict(
  322. "precision"_a = meter.precision(i),
  323. "recall"_a = meter.recall(i),
  324. "f1score"_a = meter.f1Score(i));
  325. }
  326. return returnedValue;
  327. })
  328. .def(
  329. "getWordId",
  330. [](fasttext::FastText& m, const std::string word) {
  331. return m.getWordId(word);
  332. })
  333. .def(
  334. "getSubwordId",
  335. [](fasttext::FastText& m, const std::string word) {
  336. return m.getSubwordId(word);
  337. })
  338. .def(
  339. "getInputVector",
  340. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  341. m.getInputVector(vec, ind);
  342. })
  343. .def(
  344. "getWordVector",
  345. [](fasttext::FastText& m,
  346. fasttext::Vector& vec,
  347. const std::string word) { m.getWordVector(vec, word); })
  348. .def(
  349. "getSubwords",
  350. [](fasttext::FastText& m,
  351. const std::string word,
  352. const char* onUnicodeError) {
  353. std::vector<std::string> subwords;
  354. std::vector<int32_t> ngrams;
  355. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  356. d->getSubwords(word, ngrams, subwords);
  357. std::vector<py::str> transformedSubwords;
  358. for (const auto& subword : subwords) {
  359. transformedSubwords.push_back(
  360. castToPythonString(subword, onUnicodeError));
  361. }
  362. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  363. transformedSubwords, ngrams);
  364. })
  365. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  366. }