1
0

// fasttext_pybind.cc (19 KB)
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  #include <args.h>
  #include <autotune.h>
  #include <densematrix.h>
  #include <fasttext.h>
  #include <pybind11/numpy.h>
  #include <pybind11/pybind11.h>
  #include <pybind11/stl.h>
  #include <real.h>
  #include <vector.h>
  #include <cmath>
  #include <fstream>
  #include <iterator>
  #include <sstream>
  #include <stdexcept>
  #include <utility>
  21. using namespace pybind11::literals;
  22. namespace py = pybind11;
  23. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  24. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  25. if (!handle) {
  26. throw py::error_already_set();
  27. }
  28. // py::str's constructor from a PyObject assumes the string has been encoded
  29. // for python 2 and not encoded for python 3 :
  30. // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930
  31. #if PY_MAJOR_VERSION < 3
  32. PyObject* handle_encoded =
  33. PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError);
  34. Py_DECREF(handle);
  35. handle = handle_encoded;
  36. #endif
  37. py::str handle_str = py::str(handle);
  38. Py_DECREF(handle);
  39. return handle_str;
  40. }
  41. std::vector<std::pair<fasttext::real, py::str>> castToPythonString(
  42. const std::vector<std::pair<fasttext::real, std::string>>& predictions,
  43. const char* onUnicodeError) {
  44. std::vector<std::pair<fasttext::real, py::str>> transformedPredictions;
  45. for (const auto& prediction : predictions) {
  46. transformedPredictions.emplace_back(
  47. prediction.first,
  48. castToPythonString(prediction.second, onUnicodeError));
  49. }
  50. return transformedPredictions;
  51. }
  52. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  53. fasttext::FastText& m,
  54. const std::string text,
  55. const char* onUnicodeError) {
  56. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  57. std::stringstream ioss(text);
  58. std::string token;
  59. std::vector<py::str> words;
  60. std::vector<py::str> labels;
  61. while (d->readWord(ioss, token)) {
  62. uint32_t h = d->hash(token);
  63. int32_t wid = d->getId(token, h);
  64. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  65. if (type == fasttext::entry_type::word) {
  66. words.push_back(castToPythonString(token, onUnicodeError));
  67. // Labels must not be OOV!
  68. } else if (type == fasttext::entry_type::label && wid >= 0) {
  69. labels.push_back(castToPythonString(token, onUnicodeError));
  70. }
  71. if (token == fasttext::Dictionary::EOS)
  72. break;
  73. }
  74. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  75. }
  76. PYBIND11_MODULE(fasttext_pybind, m) {
  77. py::class_<fasttext::Args>(m, "args")
  78. .def(py::init<>())
  79. .def_readwrite("input", &fasttext::Args::input)
  80. .def_readwrite("output", &fasttext::Args::output)
  81. .def_readwrite("lr", &fasttext::Args::lr)
  82. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  83. .def_readwrite("dim", &fasttext::Args::dim)
  84. .def_readwrite("ws", &fasttext::Args::ws)
  85. .def_readwrite("epoch", &fasttext::Args::epoch)
  86. .def_readwrite("minCount", &fasttext::Args::minCount)
  87. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  88. .def_readwrite("neg", &fasttext::Args::neg)
  89. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  90. .def_readwrite("loss", &fasttext::Args::loss)
  91. .def_readwrite("model", &fasttext::Args::model)
  92. .def_readwrite("bucket", &fasttext::Args::bucket)
  93. .def_readwrite("minn", &fasttext::Args::minn)
  94. .def_readwrite("maxn", &fasttext::Args::maxn)
  95. .def_readwrite("thread", &fasttext::Args::thread)
  96. .def_readwrite("t", &fasttext::Args::t)
  97. .def_readwrite("label", &fasttext::Args::label)
  98. .def_readwrite("verbose", &fasttext::Args::verbose)
  99. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  100. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  101. .def_readwrite("seed", &fasttext::Args::seed)
  102. .def_readwrite("qout", &fasttext::Args::qout)
  103. .def_readwrite("retrain", &fasttext::Args::retrain)
  104. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  105. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  106. .def_readwrite("dsub", &fasttext::Args::dsub)
  107. .def_readwrite(
  108. "autotuneValidationFile", &fasttext::Args::autotuneValidationFile)
  109. .def_readwrite("autotuneMetric", &fasttext::Args::autotuneMetric)
  110. .def_readwrite(
  111. "autotunePredictions", &fasttext::Args::autotunePredictions)
  112. .def_readwrite("autotuneDuration", &fasttext::Args::autotuneDuration)
  113. .def_readwrite("autotuneModelSize", &fasttext::Args::autotuneModelSize)
  114. .def("setManual", [](fasttext::Args& m, const std::string& argName) {
  115. m.setManual(argName);
  116. });
  117. py::enum_<fasttext::model_name>(m, "model_name")
  118. .value("cbow", fasttext::model_name::cbow)
  119. .value("skipgram", fasttext::model_name::sg)
  120. .value("supervised", fasttext::model_name::sup)
  121. .export_values();
  122. py::enum_<fasttext::loss_name>(m, "loss_name")
  123. .value("hs", fasttext::loss_name::hs)
  124. .value("ns", fasttext::loss_name::ns)
  125. .value("softmax", fasttext::loss_name::softmax)
  126. .value("ova", fasttext::loss_name::ova)
  127. .export_values();
  128. py::enum_<fasttext::metric_name>(m, "metric_name")
  129. .value("f1score", fasttext::metric_name::f1score)
  130. .value("f1scoreLabel", fasttext::metric_name::f1scoreLabel)
  131. .value("precisionAtRecall", fasttext::metric_name::precisionAtRecall)
  132. .value(
  133. "precisionAtRecallLabel",
  134. fasttext::metric_name::precisionAtRecallLabel)
  135. .value("recallAtPrecision", fasttext::metric_name::recallAtPrecision)
  136. .value(
  137. "recallAtPrecisionLabel",
  138. fasttext::metric_name::recallAtPrecisionLabel)
  139. .export_values();
  140. m.def(
  141. "train",
  142. [](fasttext::FastText& ft, fasttext::Args& a) {
  143. if (a.hasAutotune()) {
  144. fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(
  145. &ft, [](fasttext::FastText*) {}));
  146. autotune.train(a);
  147. } else {
  148. ft.train(a);
  149. }
  150. },
  151. py::call_guard<py::gil_scoped_release>());
  152. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  153. .def(py::init<ssize_t>())
  154. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  155. return py::buffer_info(
  156. m.data(),
  157. sizeof(fasttext::real),
  158. py::format_descriptor<fasttext::real>::format(),
  159. 1,
  160. {m.size()},
  161. {sizeof(fasttext::real)});
  162. });
  163. py::class_<fasttext::DenseMatrix>(
  164. m, "DenseMatrix", py::buffer_protocol(), py::module_local())
  165. .def(py::init<>())
  166. .def(py::init<ssize_t, ssize_t>())
  167. .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
  168. return py::buffer_info(
  169. m.data(),
  170. sizeof(fasttext::real),
  171. py::format_descriptor<fasttext::real>::format(),
  172. 2,
  173. {m.size(0), m.size(1)},
  174. {sizeof(fasttext::real) * m.size(1),
  175. sizeof(fasttext::real) * (int64_t)1});
  176. });
  177. py::class_<fasttext::Meter>(m, "Meter")
  178. .def(py::init<bool>())
  179. .def("scoreVsTrue", &fasttext::Meter::scoreVsTrue)
  180. .def(
  181. "precisionRecallCurveLabel",
  182. py::overload_cast<int32_t>(
  183. &fasttext::Meter::precisionRecallCurve, py::const_))
  184. .def(
  185. "precisionRecallCurve",
  186. py::overload_cast<>(
  187. &fasttext::Meter::precisionRecallCurve, py::const_))
  188. .def(
  189. "precisionAtRecallLabel",
  190. py::overload_cast<int32_t, double>(
  191. &fasttext::Meter::precisionAtRecall, py::const_))
  192. .def(
  193. "precisionAtRecall",
  194. py::overload_cast<double>(
  195. &fasttext::Meter::precisionAtRecall, py::const_))
  196. .def(
  197. "recallAtPrecisionLabel",
  198. py::overload_cast<int32_t, double>(
  199. &fasttext::Meter::recallAtPrecision, py::const_))
  200. .def(
  201. "recallAtPrecision",
  202. py::overload_cast<double>(
  203. &fasttext::Meter::recallAtPrecision, py::const_));
  204. py::class_<fasttext::FastText>(m, "fasttext")
  205. .def(py::init<>())
  206. .def("getArgs", &fasttext::FastText::getArgs)
  207. .def(
  208. "getInputMatrix",
  209. [](fasttext::FastText& m) {
  210. std::shared_ptr<const fasttext::DenseMatrix> mm =
  211. m.getInputMatrix();
  212. return mm.get();
  213. },
  214. pybind11::return_value_policy::reference)
  215. .def(
  216. "getOutputMatrix",
  217. [](fasttext::FastText& m) {
  218. std::shared_ptr<const fasttext::DenseMatrix> mm =
  219. m.getOutputMatrix();
  220. return mm.get();
  221. },
  222. pybind11::return_value_policy::reference)
  223. .def(
  224. "setMatrices",
  225. [](fasttext::FastText& m,
  226. py::buffer inputMatrixBuffer,
  227. py::buffer outputMatrixBuffer) {
  228. py::buffer_info inputMatrixInfo = inputMatrixBuffer.request();
  229. py::buffer_info outputMatrixInfo = outputMatrixBuffer.request();
  230. m.setMatrices(
  231. std::make_shared<fasttext::DenseMatrix>(
  232. inputMatrixInfo.shape[0],
  233. inputMatrixInfo.shape[1],
  234. static_cast<float*>(inputMatrixInfo.ptr)),
  235. std::make_shared<fasttext::DenseMatrix>(
  236. outputMatrixInfo.shape[0],
  237. outputMatrixInfo.shape[1],
  238. static_cast<float*>(outputMatrixInfo.ptr)));
  239. })
  240. .def(
  241. "loadModel",
  242. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  243. .def(
  244. "saveModel",
  245. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  246. .def(
  247. "test",
  248. [](fasttext::FastText& m,
  249. const std::string& filename,
  250. int32_t k,
  251. fasttext::real threshold) {
  252. std::ifstream ifs(filename);
  253. if (!ifs.is_open()) {
  254. throw std::invalid_argument("Test file cannot be opened!");
  255. }
  256. fasttext::Meter meter(false);
  257. m.test(ifs, k, threshold, meter);
  258. ifs.close();
  259. return std::tuple<int64_t, double, double>(
  260. meter.nexamples(), meter.precision(), meter.recall());
  261. })
  262. .def(
  263. "getMeter",
  264. [](fasttext::FastText& m, const std::string& filename, int32_t k) {
  265. std::ifstream ifs(filename);
  266. if (!ifs.is_open()) {
  267. throw std::invalid_argument("Test file cannot be opened!");
  268. }
  269. fasttext::Meter meter(true);
  270. m.test(ifs, k, 0.0, meter);
  271. ifs.close();
  272. return meter;
  273. })
  274. .def(
  275. "getSentenceVector",
  276. [](fasttext::FastText& m,
  277. fasttext::Vector& v,
  278. const std::string text) {
  279. std::stringstream ioss(text);
  280. m.getSentenceVector(ioss, v);
  281. })
  282. .def(
  283. "tokenize",
  284. [](fasttext::FastText& m, const std::string text) {
  285. std::vector<std::string> text_split;
  286. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  287. std::stringstream ioss(text);
  288. std::string token;
  289. while (!ioss.eof()) {
  290. while (d->readWord(ioss, token)) {
  291. text_split.push_back(token);
  292. }
  293. }
  294. return text_split;
  295. })
  296. .def("getLine", &getLineText)
  297. .def(
  298. "multilineGetLine",
  299. [](fasttext::FastText& m,
  300. const std::vector<std::string> lines,
  301. const char* onUnicodeError) {
  302. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  303. std::vector<std::vector<py::str>> all_words;
  304. std::vector<std::vector<py::str>> all_labels;
  305. for (const auto& text : lines) {
  306. auto pair = getLineText(m, text, onUnicodeError);
  307. all_words.push_back(pair.first);
  308. all_labels.push_back(pair.second);
  309. }
  310. return std::pair<
  311. std::vector<std::vector<py::str>>,
  312. std::vector<std::vector<py::str>>>(all_words, all_labels);
  313. })
  314. .def(
  315. "getVocab",
  316. [](fasttext::FastText& m, const char* onUnicodeError) {
  317. py::str s;
  318. std::vector<py::str> vocab_list;
  319. std::vector<int64_t> vocab_freq;
  320. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  321. vocab_freq = d->getCounts(fasttext::entry_type::word);
  322. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  323. vocab_list.push_back(
  324. castToPythonString(d->getWord(i), onUnicodeError));
  325. }
  326. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  327. vocab_list, vocab_freq);
  328. })
  329. .def(
  330. "getLabels",
  331. [](fasttext::FastText& m, const char* onUnicodeError) {
  332. std::vector<py::str> labels_list;
  333. std::vector<int64_t> labels_freq;
  334. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  335. labels_freq = d->getCounts(fasttext::entry_type::label);
  336. for (int32_t i = 0; i < labels_freq.size(); i++) {
  337. labels_list.push_back(
  338. castToPythonString(d->getLabel(i), onUnicodeError));
  339. }
  340. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  341. labels_list, labels_freq);
  342. })
  343. .def(
  344. "quantize",
  345. [](fasttext::FastText& m,
  346. const std::string input,
  347. bool qout,
  348. int32_t cutoff,
  349. bool retrain,
  350. int epoch,
  351. double lr,
  352. int thread,
  353. int verbose,
  354. int32_t dsub,
  355. bool qnorm) {
  356. fasttext::Args qa = fasttext::Args();
  357. qa.input = input;
  358. qa.qout = qout;
  359. qa.cutoff = cutoff;
  360. qa.retrain = retrain;
  361. qa.epoch = epoch;
  362. qa.lr = lr;
  363. qa.thread = thread;
  364. qa.verbose = verbose;
  365. qa.dsub = dsub;
  366. qa.qnorm = qnorm;
  367. m.quantize(qa);
  368. })
  369. .def(
  370. "predict",
  371. // NOTE: text needs to end in a newline
  372. // to exactly mimic the behavior of the cli
  373. [](fasttext::FastText& m,
  374. const std::string text,
  375. int32_t k,
  376. fasttext::real threshold,
  377. const char* onUnicodeError) {
  378. std::stringstream ioss(text);
  379. std::vector<std::pair<fasttext::real, std::string>> predictions;
  380. m.predictLine(ioss, predictions, k, threshold);
  381. return castToPythonString(predictions, onUnicodeError);
  382. })
  383. .def(
  384. "multilinePredict",
  385. // NOTE: text needs to end in a newline
  386. // to exactly mimic the behavior of the cli
  387. [](fasttext::FastText& m,
  388. const std::vector<std::string>& lines,
  389. int32_t k,
  390. fasttext::real threshold,
  391. const char* onUnicodeError) {
  392. std::vector<py::array_t<fasttext::real>> allProbabilities;
  393. std::vector<std::vector<py::str>> allLabels;
  394. std::vector<std::pair<fasttext::real, std::string>> predictions;
  395. for (const std::string& text : lines) {
  396. std::stringstream ioss(text);
  397. m.predictLine(ioss, predictions, k, threshold);
  398. std::vector<fasttext::real> probabilities;
  399. std::vector<py::str> labels;
  400. for (const auto& prediction : predictions) {
  401. probabilities.push_back(prediction.first);
  402. labels.push_back(
  403. castToPythonString(prediction.second, onUnicodeError));
  404. }
  405. allProbabilities.emplace_back(
  406. probabilities.size(), probabilities.data());
  407. allLabels.push_back(labels);
  408. }
  409. return make_pair(allLabels, allProbabilities);
  410. })
  411. .def(
  412. "testLabel",
  413. [](fasttext::FastText& m,
  414. const std::string filename,
  415. int32_t k,
  416. fasttext::real threshold) {
  417. std::ifstream ifs(filename);
  418. if (!ifs.is_open()) {
  419. throw std::invalid_argument("Test file cannot be opened!");
  420. }
  421. fasttext::Meter meter(false);
  422. m.test(ifs, k, threshold, meter);
  423. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  424. std::unordered_map<std::string, py::dict> returnedValue;
  425. for (int32_t i = 0; i < d->nlabels(); i++) {
  426. returnedValue[d->getLabel(i)] = py::dict(
  427. "precision"_a = meter.precision(i),
  428. "recall"_a = meter.recall(i),
  429. "f1score"_a = meter.f1Score(i));
  430. }
  431. return returnedValue;
  432. })
  433. .def(
  434. "getWordId",
  435. [](fasttext::FastText& m, const std::string& word) {
  436. return m.getWordId(word);
  437. })
  438. .def(
  439. "getSubwordId",
  440. [](fasttext::FastText& m, const std::string word) {
  441. return m.getSubwordId(word);
  442. })
  443. .def(
  444. "getLabelId",
  445. [](fasttext::FastText& m, const std::string& label) {
  446. return m.getLabelId(label);
  447. })
  448. .def(
  449. "getInputVector",
  450. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  451. m.getInputVector(vec, ind);
  452. })
  453. .def(
  454. "getWordVector",
  455. [](fasttext::FastText& m,
  456. fasttext::Vector& vec,
  457. const std::string word) { m.getWordVector(vec, word); })
  458. .def(
  459. "getNN",
  460. [](fasttext::FastText& m,
  461. const std::string& word,
  462. int32_t k,
  463. const char* onUnicodeError) {
  464. return castToPythonString(m.getNN(word, k), onUnicodeError);
  465. })
  466. .def(
  467. "getAnalogies",
  468. [](fasttext::FastText& m,
  469. const std::string& wordA,
  470. const std::string& wordB,
  471. const std::string& wordC,
  472. int32_t k,
  473. const char* onUnicodeError) {
  474. return castToPythonString(
  475. m.getAnalogies(k, wordA, wordB, wordC), onUnicodeError);
  476. })
  477. .def(
  478. "getSubwords",
  479. [](fasttext::FastText& m,
  480. const std::string word,
  481. const char* onUnicodeError) {
  482. std::vector<std::string> subwords;
  483. std::vector<int32_t> ngrams;
  484. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  485. d->getSubwords(word, ngrams, subwords);
  486. std::vector<py::str> transformedSubwords;
  487. for (const auto& subword : subwords) {
  488. transformedSubwords.push_back(
  489. castToPythonString(subword, onUnicodeError));
  490. }
  491. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  492. transformedSubwords, ngrams);
  493. })
  494. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  495. }