1
0

fasttext_pybind.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
#include <args.h>
#include <fasttext.h>
#include <matrix.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <real.h>
#include <vector.h>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
  19. using namespace pybind11::literals;
  20. namespace py = pybind11;
  21. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  22. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  23. if (!handle) {
  24. throw py::error_already_set();
  25. }
  26. return py::str(handle);
  27. }
  28. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  29. fasttext::FastText& m,
  30. const std::string text,
  31. const char* onUnicodeError) {
  32. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  33. std::stringstream ioss(text);
  34. std::string token;
  35. std::vector<py::str> words;
  36. std::vector<py::str> labels;
  37. while (d->readWord(ioss, token)) {
  38. uint32_t h = d->hash(token);
  39. int32_t wid = d->getId(token, h);
  40. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  41. if (type == fasttext::entry_type::word) {
  42. words.push_back(castToPythonString(token, onUnicodeError));
  43. // Labels must not be OOV!
  44. } else if (type == fasttext::entry_type::label && wid >= 0) {
  45. labels.push_back(castToPythonString(token, onUnicodeError));
  46. }
  47. if (token == fasttext::Dictionary::EOS)
  48. break;
  49. }
  50. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  51. }
  52. PYBIND11_MODULE(fasttext_pybind, m) {
  53. py::class_<fasttext::Args>(m, "args")
  54. .def(py::init<>())
  55. .def_readwrite("input", &fasttext::Args::input)
  56. .def_readwrite("output", &fasttext::Args::output)
  57. .def_readwrite("lr", &fasttext::Args::lr)
  58. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  59. .def_readwrite("dim", &fasttext::Args::dim)
  60. .def_readwrite("ws", &fasttext::Args::ws)
  61. .def_readwrite("epoch", &fasttext::Args::epoch)
  62. .def_readwrite("minCount", &fasttext::Args::minCount)
  63. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  64. .def_readwrite("neg", &fasttext::Args::neg)
  65. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  66. .def_readwrite("loss", &fasttext::Args::loss)
  67. .def_readwrite("model", &fasttext::Args::model)
  68. .def_readwrite("bucket", &fasttext::Args::bucket)
  69. .def_readwrite("minn", &fasttext::Args::minn)
  70. .def_readwrite("maxn", &fasttext::Args::maxn)
  71. .def_readwrite("thread", &fasttext::Args::thread)
  72. .def_readwrite("t", &fasttext::Args::t)
  73. .def_readwrite("label", &fasttext::Args::label)
  74. .def_readwrite("verbose", &fasttext::Args::verbose)
  75. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  76. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  77. .def_readwrite("qout", &fasttext::Args::qout)
  78. .def_readwrite("retrain", &fasttext::Args::retrain)
  79. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  80. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  81. .def_readwrite("dsub", &fasttext::Args::dsub);
  82. py::enum_<fasttext::model_name>(m, "model_name")
  83. .value("cbow", fasttext::model_name::cbow)
  84. .value("skipgram", fasttext::model_name::sg)
  85. .value("supervised", fasttext::model_name::sup)
  86. .export_values();
  87. py::enum_<fasttext::loss_name>(m, "loss_name")
  88. .value("hs", fasttext::loss_name::hs)
  89. .value("ns", fasttext::loss_name::ns)
  90. .value("softmax", fasttext::loss_name::softmax)
  91. .value("ova", fasttext::loss_name::ova)
  92. .export_values();
  93. m.def(
  94. "train",
  95. [](fasttext::FastText& ft, fasttext::Args& a) { ft.train(a); },
  96. py::call_guard<py::gil_scoped_release>());
  97. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  98. .def(py::init<ssize_t>())
  99. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  100. return py::buffer_info(
  101. m.data(),
  102. sizeof(fasttext::real),
  103. py::format_descriptor<fasttext::real>::format(),
  104. 1,
  105. {m.size()},
  106. {sizeof(fasttext::real)});
  107. });
  108. py::class_<fasttext::Matrix>(
  109. m, "Matrix", py::buffer_protocol(), py::module_local())
  110. .def(py::init<>())
  111. .def(py::init<ssize_t, ssize_t>())
  112. .def_buffer([](fasttext::Matrix& m) -> py::buffer_info {
  113. return py::buffer_info(
  114. m.data(),
  115. sizeof(fasttext::real),
  116. py::format_descriptor<fasttext::real>::format(),
  117. 2,
  118. {m.size(0), m.size(1)},
  119. {sizeof(fasttext::real) * m.size(1),
  120. sizeof(fasttext::real) * (int64_t)1});
  121. });
  122. py::class_<fasttext::FastText>(m, "fasttext")
  123. .def(py::init<>())
  124. .def("getArgs", &fasttext::FastText::getArgs)
  125. .def(
  126. "getInputMatrix",
  127. [](fasttext::FastText& m) {
  128. std::shared_ptr<const fasttext::Matrix> mm = m.getInputMatrix();
  129. return *mm.get();
  130. })
  131. .def(
  132. "getOutputMatrix",
  133. [](fasttext::FastText& m) {
  134. std::shared_ptr<const fasttext::Matrix> mm = m.getOutputMatrix();
  135. return *mm.get();
  136. })
  137. .def(
  138. "loadModel",
  139. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  140. .def(
  141. "saveModel",
  142. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  143. .def(
  144. "test",
  145. [](fasttext::FastText& m, const std::string filename, int32_t k) {
  146. std::ifstream ifs(filename);
  147. if (!ifs.is_open()) {
  148. throw std::invalid_argument("Test file cannot be opened!");
  149. }
  150. fasttext::Meter meter;
  151. m.test(ifs, k, 0.0, meter);
  152. ifs.close();
  153. return std::tuple<int64_t, double, double>(
  154. meter.nexamples(), meter.precision(), meter.recall());
  155. })
  156. .def(
  157. "getSentenceVector",
  158. [](fasttext::FastText& m,
  159. fasttext::Vector& v,
  160. const std::string text) {
  161. std::stringstream ioss(text);
  162. m.getSentenceVector(ioss, v);
  163. })
  164. .def(
  165. "tokenize",
  166. [](fasttext::FastText& m, const std::string text) {
  167. std::vector<std::string> text_split;
  168. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  169. std::stringstream ioss(text);
  170. std::string token;
  171. while (!ioss.eof()) {
  172. while (d->readWord(ioss, token)) {
  173. text_split.push_back(token);
  174. }
  175. }
  176. return text_split;
  177. })
  178. .def("getLine", &getLineText)
  179. .def(
  180. "multilineGetLine",
  181. [](fasttext::FastText& m,
  182. const std::vector<std::string> lines,
  183. const char* onUnicodeError) {
  184. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  185. std::vector<std::vector<py::str>> all_words;
  186. std::vector<std::vector<py::str>> all_labels;
  187. for (const auto& text : lines) {
  188. auto pair = getLineText(m, text, onUnicodeError);
  189. all_words.push_back(pair.first);
  190. all_labels.push_back(pair.second);
  191. }
  192. return std::pair<
  193. std::vector<std::vector<py::str>>,
  194. std::vector<std::vector<py::str>>>(all_words, all_labels);
  195. })
  196. .def(
  197. "getVocab",
  198. [](fasttext::FastText& m, const char* onUnicodeError) {
  199. py::str s;
  200. std::vector<py::str> vocab_list;
  201. std::vector<int64_t> vocab_freq;
  202. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  203. vocab_freq = d->getCounts(fasttext::entry_type::word);
  204. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  205. vocab_list.push_back(
  206. castToPythonString(d->getWord(i), onUnicodeError));
  207. }
  208. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  209. vocab_list, vocab_freq);
  210. })
  211. .def(
  212. "getLabels",
  213. [](fasttext::FastText& m, const char* onUnicodeError) {
  214. std::vector<py::str> labels_list;
  215. std::vector<int64_t> labels_freq;
  216. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  217. labels_freq = d->getCounts(fasttext::entry_type::label);
  218. for (int32_t i = 0; i < labels_freq.size(); i++) {
  219. labels_list.push_back(
  220. castToPythonString(d->getLabel(i), onUnicodeError));
  221. }
  222. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  223. labels_list, labels_freq);
  224. })
  225. .def(
  226. "quantize",
  227. [](fasttext::FastText& m,
  228. const std::string input,
  229. bool qout,
  230. int32_t cutoff,
  231. bool retrain,
  232. int epoch,
  233. double lr,
  234. int thread,
  235. int verbose,
  236. int32_t dsub,
  237. bool qnorm) {
  238. fasttext::Args qa = fasttext::Args();
  239. qa.input = input;
  240. qa.qout = qout;
  241. qa.cutoff = cutoff;
  242. qa.retrain = retrain;
  243. qa.epoch = epoch;
  244. qa.lr = lr;
  245. qa.thread = thread;
  246. qa.verbose = verbose;
  247. qa.dsub = dsub;
  248. qa.qnorm = qnorm;
  249. m.quantize(qa);
  250. })
  251. .def(
  252. "predict",
  253. // NOTE: text needs to end in a newline
  254. // to exactly mimic the behavior of the cli
  255. [](fasttext::FastText& m,
  256. const std::string text,
  257. int32_t k,
  258. fasttext::real threshold,
  259. const char* onUnicodeError) {
  260. std::stringstream ioss(text);
  261. std::vector<std::pair<fasttext::real, std::string>> predictions;
  262. m.predictLine(ioss, predictions, k, threshold);
  263. std::vector<std::pair<fasttext::real, py::str>>
  264. transformedPredictions;
  265. for (const auto& prediction : predictions) {
  266. transformedPredictions.push_back(std::make_pair(
  267. prediction.first,
  268. castToPythonString(prediction.second, onUnicodeError)));
  269. }
  270. return transformedPredictions;
  271. })
  272. .def(
  273. "multilinePredict",
  274. // NOTE: text needs to end in a newline
  275. // to exactly mimic the behavior of the cli
  276. [](fasttext::FastText& m,
  277. const std::vector<std::string>& lines,
  278. int32_t k,
  279. fasttext::real threshold,
  280. const char* onUnicodeError) {
  281. std::vector<std::vector<std::pair<fasttext::real, py::str>>>
  282. allPredictions;
  283. std::vector<std::pair<fasttext::real, std::string>> predictions;
  284. for (const std::string& text : lines) {
  285. std::stringstream ioss(text);
  286. m.predictLine(ioss, predictions, k, threshold);
  287. std::vector<std::pair<fasttext::real, py::str>>
  288. transformedPredictions;
  289. for (const auto& prediction : predictions) {
  290. transformedPredictions.push_back(std::make_pair(
  291. prediction.first,
  292. castToPythonString(prediction.second, onUnicodeError)));
  293. }
  294. allPredictions.push_back(transformedPredictions);
  295. }
  296. return allPredictions;
  297. })
  298. .def(
  299. "testLabel",
  300. [](fasttext::FastText& m,
  301. const std::string filename,
  302. int32_t k,
  303. fasttext::real threshold) {
  304. std::ifstream ifs(filename);
  305. if (!ifs.is_open()) {
  306. throw std::invalid_argument("Test file cannot be opened!");
  307. }
  308. fasttext::Meter meter;
  309. m.test(ifs, k, threshold, meter);
  310. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  311. std::unordered_map<std::string, py::dict> returnedValue;
  312. for (int32_t i = 0; i < d->nlabels(); i++) {
  313. returnedValue[d->getLabel(i)] = py::dict(
  314. "precision"_a = meter.precision(i),
  315. "recall"_a = meter.recall(i),
  316. "f1score"_a = meter.f1Score(i));
  317. }
  318. return returnedValue;
  319. })
  320. .def(
  321. "getWordId",
  322. [](fasttext::FastText& m, const std::string word) {
  323. return m.getWordId(word);
  324. })
  325. .def(
  326. "getSubwordId",
  327. [](fasttext::FastText& m, const std::string word) {
  328. return m.getSubwordId(word);
  329. })
  330. .def(
  331. "getInputVector",
  332. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  333. m.getInputVector(vec, ind);
  334. })
  335. .def(
  336. "getWordVector",
  337. [](fasttext::FastText& m,
  338. fasttext::Vector& vec,
  339. const std::string word) { m.getWordVector(vec, word); })
  340. .def(
  341. "getSubwords",
  342. [](fasttext::FastText& m,
  343. const std::string word,
  344. const char* onUnicodeError) {
  345. std::vector<std::string> subwords;
  346. std::vector<int32_t> ngrams;
  347. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  348. d->getSubwords(word, ngrams, subwords);
  349. std::vector<py::str> transformedSubwords;
  350. for (const auto& subword : subwords) {
  351. transformedSubwords.push_back(
  352. castToPythonString(subword, onUnicodeError));
  353. }
  354. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  355. transformedSubwords, ngrams);
  356. })
  357. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  358. }