fasttext_pybind.cc 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  8. #include <args.h>
  9. #include <autotune.h>
  10. #include <densematrix.h>
  11. #include <fasttext.h>
  12. #include <pybind11/numpy.h>
  13. #include <pybind11/pybind11.h>
  14. #include <pybind11/stl.h>
  15. #include <real.h>
  16. #include <vector.h>
  17. #include <cmath>
  18. #include <iterator>
  19. #include <sstream>
  20. #include <stdexcept>
  21. using namespace pybind11::literals;
  22. namespace py = pybind11;
  23. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  24. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  25. if (!handle) {
  26. throw py::error_already_set();
  27. }
  28. // py::str's constructor from a PyObject assumes the string has been encoded
  29. // for python 2 and not encoded for python 3 :
  30. // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930
  31. #if PY_MAJOR_VERSION < 3
  32. PyObject* handle_encoded =
  33. PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError);
  34. Py_DECREF(handle);
  35. handle = handle_encoded;
  36. #endif
  37. py::str handle_str = py::str(handle);
  38. Py_DECREF(handle);
  39. return handle_str;
  40. }
  41. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  42. fasttext::FastText& m,
  43. const std::string text,
  44. const char* onUnicodeError) {
  45. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  46. std::stringstream ioss(text);
  47. std::string token;
  48. std::vector<py::str> words;
  49. std::vector<py::str> labels;
  50. while (d->readWord(ioss, token)) {
  51. uint32_t h = d->hash(token);
  52. int32_t wid = d->getId(token, h);
  53. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  54. if (type == fasttext::entry_type::word) {
  55. words.push_back(castToPythonString(token, onUnicodeError));
  56. // Labels must not be OOV!
  57. } else if (type == fasttext::entry_type::label && wid >= 0) {
  58. labels.push_back(castToPythonString(token, onUnicodeError));
  59. }
  60. if (token == fasttext::Dictionary::EOS)
  61. break;
  62. }
  63. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  64. }
  65. PYBIND11_MODULE(fasttext_pybind, m) {
  66. py::class_<fasttext::Args>(m, "args")
  67. .def(py::init<>())
  68. .def_readwrite("input", &fasttext::Args::input)
  69. .def_readwrite("output", &fasttext::Args::output)
  70. .def_readwrite("lr", &fasttext::Args::lr)
  71. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  72. .def_readwrite("dim", &fasttext::Args::dim)
  73. .def_readwrite("ws", &fasttext::Args::ws)
  74. .def_readwrite("epoch", &fasttext::Args::epoch)
  75. .def_readwrite("minCount", &fasttext::Args::minCount)
  76. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  77. .def_readwrite("neg", &fasttext::Args::neg)
  78. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  79. .def_readwrite("loss", &fasttext::Args::loss)
  80. .def_readwrite("model", &fasttext::Args::model)
  81. .def_readwrite("bucket", &fasttext::Args::bucket)
  82. .def_readwrite("minn", &fasttext::Args::minn)
  83. .def_readwrite("maxn", &fasttext::Args::maxn)
  84. .def_readwrite("thread", &fasttext::Args::thread)
  85. .def_readwrite("t", &fasttext::Args::t)
  86. .def_readwrite("label", &fasttext::Args::label)
  87. .def_readwrite("verbose", &fasttext::Args::verbose)
  88. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  89. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  90. .def_readwrite("seed", &fasttext::Args::seed)
  91. .def_readwrite("qout", &fasttext::Args::qout)
  92. .def_readwrite("retrain", &fasttext::Args::retrain)
  93. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  94. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  95. .def_readwrite("dsub", &fasttext::Args::dsub)
  96. .def_readwrite(
  97. "autotuneValidationFile", &fasttext::Args::autotuneValidationFile)
  98. .def_readwrite("autotuneMetric", &fasttext::Args::autotuneMetric)
  99. .def_readwrite(
  100. "autotunePredictions", &fasttext::Args::autotunePredictions)
  101. .def_readwrite("autotuneDuration", &fasttext::Args::autotuneDuration)
  102. .def_readwrite("autotuneModelSize", &fasttext::Args::autotuneModelSize)
  103. .def("setManual", [](fasttext::Args& m, const std::string& argName) {
  104. m.setManual(argName);
  105. });
  106. py::enum_<fasttext::model_name>(m, "model_name")
  107. .value("cbow", fasttext::model_name::cbow)
  108. .value("skipgram", fasttext::model_name::sg)
  109. .value("supervised", fasttext::model_name::sup)
  110. .export_values();
  111. py::enum_<fasttext::loss_name>(m, "loss_name")
  112. .value("hs", fasttext::loss_name::hs)
  113. .value("ns", fasttext::loss_name::ns)
  114. .value("softmax", fasttext::loss_name::softmax)
  115. .value("ova", fasttext::loss_name::ova)
  116. .export_values();
  117. py::enum_<fasttext::metric_name>(m, "metric_name")
  118. .value("f1score", fasttext::metric_name::f1score)
  119. .value("labelf1score", fasttext::metric_name::labelf1score)
  120. .export_values();
  121. m.def(
  122. "train",
  123. [](fasttext::FastText& ft, fasttext::Args& a) {
  124. if (a.hasAutotune()) {
  125. fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(
  126. &ft, [](fasttext::FastText*) {}));
  127. autotune.train(a);
  128. } else {
  129. ft.train(a);
  130. }
  131. },
  132. py::call_guard<py::gil_scoped_release>());
  133. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  134. .def(py::init<ssize_t>())
  135. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  136. return py::buffer_info(
  137. m.data(),
  138. sizeof(fasttext::real),
  139. py::format_descriptor<fasttext::real>::format(),
  140. 1,
  141. {m.size()},
  142. {sizeof(fasttext::real)});
  143. });
  144. py::class_<fasttext::DenseMatrix>(
  145. m, "DenseMatrix", py::buffer_protocol(), py::module_local())
  146. .def(py::init<>())
  147. .def(py::init<ssize_t, ssize_t>())
  148. .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
  149. return py::buffer_info(
  150. m.data(),
  151. sizeof(fasttext::real),
  152. py::format_descriptor<fasttext::real>::format(),
  153. 2,
  154. {m.size(0), m.size(1)},
  155. {sizeof(fasttext::real) * m.size(1),
  156. sizeof(fasttext::real) * (int64_t)1});
  157. });
  158. py::class_<fasttext::FastText>(m, "fasttext")
  159. .def(py::init<>())
  160. .def("getArgs", &fasttext::FastText::getArgs)
  161. .def(
  162. "getInputMatrix",
  163. [](fasttext::FastText& m) {
  164. std::shared_ptr<const fasttext::DenseMatrix> mm =
  165. m.getInputMatrix();
  166. return mm.get();
  167. },
  168. pybind11::return_value_policy::reference)
  169. .def(
  170. "getOutputMatrix",
  171. [](fasttext::FastText& m) {
  172. std::shared_ptr<const fasttext::DenseMatrix> mm =
  173. m.getOutputMatrix();
  174. return mm.get();
  175. },
  176. pybind11::return_value_policy::reference)
  177. .def(
  178. "setMatrices",
  179. [](fasttext::FastText& m,
  180. py::buffer inputMatrixBuffer,
  181. py::buffer outputMatrixBuffer) {
  182. py::buffer_info inputMatrixInfo = inputMatrixBuffer.request();
  183. py::buffer_info outputMatrixInfo = outputMatrixBuffer.request();
  184. m.setMatrices(
  185. std::make_shared<fasttext::DenseMatrix>(
  186. inputMatrixInfo.shape[0],
  187. inputMatrixInfo.shape[1],
  188. static_cast<float*>(inputMatrixInfo.ptr)),
  189. std::make_shared<fasttext::DenseMatrix>(
  190. outputMatrixInfo.shape[0],
  191. outputMatrixInfo.shape[1],
  192. static_cast<float*>(outputMatrixInfo.ptr)));
  193. })
  194. .def(
  195. "loadModel",
  196. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  197. .def(
  198. "saveModel",
  199. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  200. .def(
  201. "test",
  202. [](fasttext::FastText& m,
  203. const std::string filename,
  204. int32_t k,
  205. fasttext::real threshold) {
  206. std::ifstream ifs(filename);
  207. if (!ifs.is_open()) {
  208. throw std::invalid_argument("Test file cannot be opened!");
  209. }
  210. fasttext::Meter meter;
  211. m.test(ifs, k, threshold, meter);
  212. ifs.close();
  213. return std::tuple<int64_t, double, double>(
  214. meter.nexamples(), meter.precision(), meter.recall());
  215. })
  216. .def(
  217. "getSentenceVector",
  218. [](fasttext::FastText& m,
  219. fasttext::Vector& v,
  220. const std::string text) {
  221. std::stringstream ioss(text);
  222. m.getSentenceVector(ioss, v);
  223. })
  224. .def(
  225. "tokenize",
  226. [](fasttext::FastText& m, const std::string text) {
  227. std::vector<std::string> text_split;
  228. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  229. std::stringstream ioss(text);
  230. std::string token;
  231. while (!ioss.eof()) {
  232. while (d->readWord(ioss, token)) {
  233. text_split.push_back(token);
  234. }
  235. }
  236. return text_split;
  237. })
  238. .def("getLine", &getLineText)
  239. .def(
  240. "multilineGetLine",
  241. [](fasttext::FastText& m,
  242. const std::vector<std::string> lines,
  243. const char* onUnicodeError) {
  244. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  245. std::vector<std::vector<py::str>> all_words;
  246. std::vector<std::vector<py::str>> all_labels;
  247. for (const auto& text : lines) {
  248. auto pair = getLineText(m, text, onUnicodeError);
  249. all_words.push_back(pair.first);
  250. all_labels.push_back(pair.second);
  251. }
  252. return std::pair<
  253. std::vector<std::vector<py::str>>,
  254. std::vector<std::vector<py::str>>>(all_words, all_labels);
  255. })
  256. .def(
  257. "getVocab",
  258. [](fasttext::FastText& m, const char* onUnicodeError) {
  259. py::str s;
  260. std::vector<py::str> vocab_list;
  261. std::vector<int64_t> vocab_freq;
  262. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  263. vocab_freq = d->getCounts(fasttext::entry_type::word);
  264. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  265. vocab_list.push_back(
  266. castToPythonString(d->getWord(i), onUnicodeError));
  267. }
  268. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  269. vocab_list, vocab_freq);
  270. })
  271. .def(
  272. "getLabels",
  273. [](fasttext::FastText& m, const char* onUnicodeError) {
  274. std::vector<py::str> labels_list;
  275. std::vector<int64_t> labels_freq;
  276. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  277. labels_freq = d->getCounts(fasttext::entry_type::label);
  278. for (int32_t i = 0; i < labels_freq.size(); i++) {
  279. labels_list.push_back(
  280. castToPythonString(d->getLabel(i), onUnicodeError));
  281. }
  282. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  283. labels_list, labels_freq);
  284. })
  285. .def(
  286. "quantize",
  287. [](fasttext::FastText& m,
  288. const std::string input,
  289. bool qout,
  290. int32_t cutoff,
  291. bool retrain,
  292. int epoch,
  293. double lr,
  294. int thread,
  295. int verbose,
  296. int32_t dsub,
  297. bool qnorm) {
  298. fasttext::Args qa = fasttext::Args();
  299. qa.input = input;
  300. qa.qout = qout;
  301. qa.cutoff = cutoff;
  302. qa.retrain = retrain;
  303. qa.epoch = epoch;
  304. qa.lr = lr;
  305. qa.thread = thread;
  306. qa.verbose = verbose;
  307. qa.dsub = dsub;
  308. qa.qnorm = qnorm;
  309. m.quantize(qa);
  310. })
  311. .def(
  312. "predict",
  313. // NOTE: text needs to end in a newline
  314. // to exactly mimic the behavior of the cli
  315. [](fasttext::FastText& m,
  316. const std::string text,
  317. int32_t k,
  318. fasttext::real threshold,
  319. const char* onUnicodeError) {
  320. std::stringstream ioss(text);
  321. std::vector<std::pair<fasttext::real, std::string>> predictions;
  322. m.predictLine(ioss, predictions, k, threshold);
  323. std::vector<std::pair<fasttext::real, py::str>>
  324. transformedPredictions;
  325. for (const auto& prediction : predictions) {
  326. transformedPredictions.push_back(std::make_pair(
  327. prediction.first,
  328. castToPythonString(prediction.second, onUnicodeError)));
  329. }
  330. return transformedPredictions;
  331. })
  332. .def(
  333. "multilinePredict",
  334. // NOTE: text needs to end in a newline
  335. // to exactly mimic the behavior of the cli
  336. [](fasttext::FastText& m,
  337. const std::vector<std::string>& lines,
  338. int32_t k,
  339. fasttext::real threshold,
  340. const char* onUnicodeError) {
  341. std::vector<py::array_t<fasttext::real>> allProbabilities;
  342. std::vector<std::vector<py::str>> allLabels;
  343. std::vector<std::pair<fasttext::real, std::string>> predictions;
  344. for (const std::string& text : lines) {
  345. std::stringstream ioss(text);
  346. m.predictLine(ioss, predictions, k, threshold);
  347. std::vector<fasttext::real> probabilities;
  348. std::vector<py::str> labels;
  349. for (const auto& prediction : predictions) {
  350. probabilities.push_back(prediction.first);
  351. labels.push_back(
  352. castToPythonString(prediction.second, onUnicodeError));
  353. }
  354. allProbabilities.emplace_back(
  355. probabilities.size(), probabilities.data());
  356. allLabels.push_back(labels);
  357. }
  358. return make_pair(allLabels, allProbabilities);
  359. })
  360. .def(
  361. "testLabel",
  362. [](fasttext::FastText& m,
  363. const std::string filename,
  364. int32_t k,
  365. fasttext::real threshold) {
  366. std::ifstream ifs(filename);
  367. if (!ifs.is_open()) {
  368. throw std::invalid_argument("Test file cannot be opened!");
  369. }
  370. fasttext::Meter meter;
  371. m.test(ifs, k, threshold, meter);
  372. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  373. std::unordered_map<std::string, py::dict> returnedValue;
  374. for (int32_t i = 0; i < d->nlabels(); i++) {
  375. returnedValue[d->getLabel(i)] = py::dict(
  376. "precision"_a = meter.precision(i),
  377. "recall"_a = meter.recall(i),
  378. "f1score"_a = meter.f1Score(i));
  379. }
  380. return returnedValue;
  381. })
  382. .def(
  383. "getWordId",
  384. [](fasttext::FastText& m, const std::string word) {
  385. return m.getWordId(word);
  386. })
  387. .def(
  388. "getSubwordId",
  389. [](fasttext::FastText& m, const std::string word) {
  390. return m.getSubwordId(word);
  391. })
  392. .def(
  393. "getInputVector",
  394. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  395. m.getInputVector(vec, ind);
  396. })
  397. .def(
  398. "getWordVector",
  399. [](fasttext::FastText& m,
  400. fasttext::Vector& vec,
  401. const std::string word) { m.getWordVector(vec, word); })
  402. .def(
  403. "getNN",
  404. [](fasttext::FastText& m, const std::string& word, int32_t k) {
  405. return m.getNN(word, k);
  406. })
  407. .def(
  408. "getAnalogies",
  409. [](fasttext::FastText& m,
  410. const std::string& wordA,
  411. const std::string& wordB,
  412. const std::string& wordC,
  413. int32_t k) { return m.getAnalogies(k, wordA, wordB, wordC); })
  414. .def(
  415. "getSubwords",
  416. [](fasttext::FastText& m,
  417. const std::string word,
  418. const char* onUnicodeError) {
  419. std::vector<std::string> subwords;
  420. std::vector<int32_t> ngrams;
  421. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  422. d->getSubwords(word, ngrams, subwords);
  423. std::vector<py::str> transformedSubwords;
  424. for (const auto& subword : subwords) {
  425. transformedSubwords.push_back(
  426. castToPythonString(subword, onUnicodeError));
  427. }
  428. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  429. transformedSubwords, ngrams);
  430. })
  431. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  432. }