fasttext_pybind.cc 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  #include <args.h>
  #include <autotune.h>
  #include <densematrix.h>
  #include <fasttext.h>
  #include <pybind11/numpy.h>
  #include <pybind11/pybind11.h>
  #include <pybind11/stl.h>
  #include <real.h>
  #include <vector.h>
  #include <cmath>
  #include <fstream>
  #include <iterator>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  21. using namespace pybind11::literals;
  22. namespace py = pybind11;
  23. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  24. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  25. if (!handle) {
  26. throw py::error_already_set();
  27. }
  28. // py::str's constructor from a PyObject assumes the string has been encoded
  29. // for python 2 and not encoded for python 3 :
  30. // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930
  31. #if PY_MAJOR_VERSION < 3
  32. PyObject* handle_encoded =
  33. PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError);
  34. Py_DECREF(handle);
  35. handle = handle_encoded;
  36. #endif
  37. py::str handle_str = py::str(handle);
  38. Py_DECREF(handle);
  39. return handle_str;
  40. }
  41. std::vector<std::pair<fasttext::real, py::str>> castToPythonString(
  42. const std::vector<std::pair<fasttext::real, std::string>>& predictions,
  43. const char* onUnicodeError) {
  44. std::vector<std::pair<fasttext::real, py::str>> transformedPredictions;
  45. for (const auto& prediction : predictions) {
  46. transformedPredictions.emplace_back(
  47. prediction.first,
  48. castToPythonString(prediction.second, onUnicodeError));
  49. }
  50. return transformedPredictions;
  51. }
  52. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  53. fasttext::FastText& m,
  54. const std::string text,
  55. const char* onUnicodeError) {
  56. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  57. std::stringstream ioss(text);
  58. std::string token;
  59. std::vector<py::str> words;
  60. std::vector<py::str> labels;
  61. while (d->readWord(ioss, token)) {
  62. uint32_t h = d->hash(token);
  63. int32_t wid = d->getId(token, h);
  64. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  65. if (type == fasttext::entry_type::word) {
  66. words.push_back(castToPythonString(token, onUnicodeError));
  67. // Labels must not be OOV!
  68. } else if (type == fasttext::entry_type::label && wid >= 0) {
  69. labels.push_back(castToPythonString(token, onUnicodeError));
  70. }
  71. if (token == fasttext::Dictionary::EOS)
  72. break;
  73. }
  74. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  75. }
  76. PYBIND11_MODULE(fasttext_pybind, m) {
  77. py::class_<fasttext::Args>(m, "args")
  78. .def(py::init<>())
  79. .def_readwrite("input", &fasttext::Args::input)
  80. .def_readwrite("output", &fasttext::Args::output)
  81. .def_readwrite("lr", &fasttext::Args::lr)
  82. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  83. .def_readwrite("dim", &fasttext::Args::dim)
  84. .def_readwrite("ws", &fasttext::Args::ws)
  85. .def_readwrite("epoch", &fasttext::Args::epoch)
  86. .def_readwrite("minCount", &fasttext::Args::minCount)
  87. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  88. .def_readwrite("neg", &fasttext::Args::neg)
  89. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  90. .def_readwrite("loss", &fasttext::Args::loss)
  91. .def_readwrite("model", &fasttext::Args::model)
  92. .def_readwrite("bucket", &fasttext::Args::bucket)
  93. .def_readwrite("minn", &fasttext::Args::minn)
  94. .def_readwrite("maxn", &fasttext::Args::maxn)
  95. .def_readwrite("thread", &fasttext::Args::thread)
  96. .def_readwrite("t", &fasttext::Args::t)
  97. .def_readwrite("label", &fasttext::Args::label)
  98. .def_readwrite("verbose", &fasttext::Args::verbose)
  99. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  100. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  101. .def_readwrite("seed", &fasttext::Args::seed)
  102. .def_readwrite("qout", &fasttext::Args::qout)
  103. .def_readwrite("retrain", &fasttext::Args::retrain)
  104. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  105. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  106. .def_readwrite("dsub", &fasttext::Args::dsub)
  107. .def_readwrite(
  108. "autotuneValidationFile", &fasttext::Args::autotuneValidationFile)
  109. .def_readwrite("autotuneMetric", &fasttext::Args::autotuneMetric)
  110. .def_readwrite(
  111. "autotunePredictions", &fasttext::Args::autotunePredictions)
  112. .def_readwrite("autotuneDuration", &fasttext::Args::autotuneDuration)
  113. .def_readwrite("autotuneModelSize", &fasttext::Args::autotuneModelSize)
  114. .def("setManual", [](fasttext::Args& m, const std::string& argName) {
  115. m.setManual(argName);
  116. });
  117. py::enum_<fasttext::model_name>(m, "model_name")
  118. .value("cbow", fasttext::model_name::cbow)
  119. .value("skipgram", fasttext::model_name::sg)
  120. .value("supervised", fasttext::model_name::sup)
  121. .export_values();
  122. py::enum_<fasttext::loss_name>(m, "loss_name")
  123. .value("hs", fasttext::loss_name::hs)
  124. .value("ns", fasttext::loss_name::ns)
  125. .value("softmax", fasttext::loss_name::softmax)
  126. .value("ova", fasttext::loss_name::ova)
  127. .export_values();
  128. py::enum_<fasttext::metric_name>(m, "metric_name")
  129. .value("f1score", fasttext::metric_name::f1score)
  130. .value("f1scoreLabel", fasttext::metric_name::f1scoreLabel)
  131. .value("precisionAtRecall", fasttext::metric_name::precisionAtRecall)
  132. .value(
  133. "precisionAtRecallLabel",
  134. fasttext::metric_name::precisionAtRecallLabel)
  135. .value("recallAtPrecision", fasttext::metric_name::recallAtPrecision)
  136. .value(
  137. "recallAtPrecisionLabel",
  138. fasttext::metric_name::recallAtPrecisionLabel)
  139. .export_values();
  140. m.def(
  141. "train",
  142. [](fasttext::FastText& ft, fasttext::Args& a) {
  143. if (a.hasAutotune()) {
  144. fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(
  145. &ft, [](fasttext::FastText*) {}));
  146. autotune.train(a);
  147. } else {
  148. ft.train(a);
  149. }
  150. },
  151. py::call_guard<py::gil_scoped_release>());
  152. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  153. .def(py::init<ssize_t>())
  154. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  155. return py::buffer_info(
  156. m.data(),
  157. sizeof(fasttext::real),
  158. py::format_descriptor<fasttext::real>::format(),
  159. 1,
  160. {m.size()},
  161. {sizeof(fasttext::real)});
  162. });
  163. py::class_<fasttext::DenseMatrix>(
  164. m, "DenseMatrix", py::buffer_protocol(), py::module_local())
  165. .def(py::init<>())
  166. .def(py::init<ssize_t, ssize_t>())
  167. .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
  168. return py::buffer_info(
  169. m.data(),
  170. sizeof(fasttext::real),
  171. py::format_descriptor<fasttext::real>::format(),
  172. 2,
  173. {m.size(0), m.size(1)},
  174. {sizeof(fasttext::real) * m.size(1),
  175. sizeof(fasttext::real) * (int64_t)1});
  176. });
  177. py::class_<fasttext::Meter>(m, "Meter")
  178. .def(py::init<bool>())
  179. .def("scoreVsTrue", &fasttext::Meter::scoreVsTrue)
  180. .def(
  181. "precisionRecallCurveLabel",
  182. (std::vector<std::pair<double, double>>(fasttext::Meter::*)(int32_t)
  183. const) &
  184. fasttext::Meter::precisionRecallCurve)
  185. .def(
  186. "precisionRecallCurve",
  187. (std::vector<std::pair<double, double>>(fasttext::Meter::*)() const) &
  188. fasttext::Meter::precisionRecallCurve)
  189. .def(
  190. "precisionAtRecallLabel",
  191. (double (fasttext::Meter::*)(int32_t, double) const) &
  192. fasttext::Meter::precisionAtRecall)
  193. .def(
  194. "precisionAtRecall",
  195. (double (fasttext::Meter::*)(double) const) &
  196. fasttext::Meter::precisionAtRecall)
  197. .def(
  198. "recallAtPrecisionLabel",
  199. (double (fasttext::Meter::*)(int32_t, double) const) &
  200. fasttext::Meter::recallAtPrecision)
  201. .def(
  202. "recallAtPrecision",
  203. (double (fasttext::Meter::*)(double) const) &
  204. fasttext::Meter::recallAtPrecision);
  205. py::class_<fasttext::FastText>(m, "fasttext")
  206. .def(py::init<>())
  207. .def("getArgs", &fasttext::FastText::getArgs)
  208. .def(
  209. "getInputMatrix",
  210. [](fasttext::FastText& m) {
  211. std::shared_ptr<const fasttext::DenseMatrix> mm =
  212. m.getInputMatrix();
  213. return mm.get();
  214. },
  215. pybind11::return_value_policy::reference)
  216. .def(
  217. "getOutputMatrix",
  218. [](fasttext::FastText& m) {
  219. std::shared_ptr<const fasttext::DenseMatrix> mm =
  220. m.getOutputMatrix();
  221. return mm.get();
  222. },
  223. pybind11::return_value_policy::reference)
  224. .def(
  225. "setMatrices",
  226. [](fasttext::FastText& m,
  227. py::buffer inputMatrixBuffer,
  228. py::buffer outputMatrixBuffer) {
  229. py::buffer_info inputMatrixInfo = inputMatrixBuffer.request();
  230. py::buffer_info outputMatrixInfo = outputMatrixBuffer.request();
  231. m.setMatrices(
  232. std::make_shared<fasttext::DenseMatrix>(
  233. inputMatrixInfo.shape[0],
  234. inputMatrixInfo.shape[1],
  235. static_cast<float*>(inputMatrixInfo.ptr)),
  236. std::make_shared<fasttext::DenseMatrix>(
  237. outputMatrixInfo.shape[0],
  238. outputMatrixInfo.shape[1],
  239. static_cast<float*>(outputMatrixInfo.ptr)));
  240. })
  241. .def(
  242. "loadModel",
  243. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  244. .def(
  245. "saveModel",
  246. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  247. .def(
  248. "test",
  249. [](fasttext::FastText& m,
  250. const std::string& filename,
  251. int32_t k,
  252. fasttext::real threshold) {
  253. std::ifstream ifs(filename);
  254. if (!ifs.is_open()) {
  255. throw std::invalid_argument("Test file cannot be opened!");
  256. }
  257. fasttext::Meter meter(false);
  258. m.test(ifs, k, threshold, meter);
  259. ifs.close();
  260. return std::tuple<int64_t, double, double>(
  261. meter.nexamples(), meter.precision(), meter.recall());
  262. })
  263. .def(
  264. "getMeter",
  265. [](fasttext::FastText& m, const std::string& filename, int32_t k) {
  266. std::ifstream ifs(filename);
  267. if (!ifs.is_open()) {
  268. throw std::invalid_argument("Test file cannot be opened!");
  269. }
  270. fasttext::Meter meter(true);
  271. m.test(ifs, k, 0.0, meter);
  272. ifs.close();
  273. return meter;
  274. })
  275. .def(
  276. "getSentenceVector",
  277. [](fasttext::FastText& m,
  278. fasttext::Vector& v,
  279. const std::string text) {
  280. std::stringstream ioss(text);
  281. m.getSentenceVector(ioss, v);
  282. })
  283. .def(
  284. "tokenize",
  285. [](fasttext::FastText& m, const std::string text) {
  286. std::vector<std::string> text_split;
  287. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  288. std::stringstream ioss(text);
  289. std::string token;
  290. while (!ioss.eof()) {
  291. while (d->readWord(ioss, token)) {
  292. text_split.push_back(token);
  293. }
  294. }
  295. return text_split;
  296. })
  297. .def("getLine", &getLineText)
  298. .def(
  299. "multilineGetLine",
  300. [](fasttext::FastText& m,
  301. const std::vector<std::string> lines,
  302. const char* onUnicodeError) {
  303. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  304. std::vector<std::vector<py::str>> all_words;
  305. std::vector<std::vector<py::str>> all_labels;
  306. for (const auto& text : lines) {
  307. auto pair = getLineText(m, text, onUnicodeError);
  308. all_words.push_back(pair.first);
  309. all_labels.push_back(pair.second);
  310. }
  311. return std::pair<
  312. std::vector<std::vector<py::str>>,
  313. std::vector<std::vector<py::str>>>(all_words, all_labels);
  314. })
  315. .def(
  316. "getVocab",
  317. [](fasttext::FastText& m, const char* onUnicodeError) {
  318. py::str s;
  319. std::vector<py::str> vocab_list;
  320. std::vector<int64_t> vocab_freq;
  321. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  322. vocab_freq = d->getCounts(fasttext::entry_type::word);
  323. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  324. vocab_list.push_back(
  325. castToPythonString(d->getWord(i), onUnicodeError));
  326. }
  327. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  328. vocab_list, vocab_freq);
  329. })
  330. .def(
  331. "getLabels",
  332. [](fasttext::FastText& m, const char* onUnicodeError) {
  333. std::vector<py::str> labels_list;
  334. std::vector<int64_t> labels_freq;
  335. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  336. labels_freq = d->getCounts(fasttext::entry_type::label);
  337. for (int32_t i = 0; i < labels_freq.size(); i++) {
  338. labels_list.push_back(
  339. castToPythonString(d->getLabel(i), onUnicodeError));
  340. }
  341. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  342. labels_list, labels_freq);
  343. })
  344. .def(
  345. "quantize",
  346. [](fasttext::FastText& m,
  347. const std::string input,
  348. bool qout,
  349. int32_t cutoff,
  350. bool retrain,
  351. int epoch,
  352. double lr,
  353. int thread,
  354. int verbose,
  355. int32_t dsub,
  356. bool qnorm) {
  357. fasttext::Args qa = fasttext::Args();
  358. qa.input = input;
  359. qa.qout = qout;
  360. qa.cutoff = cutoff;
  361. qa.retrain = retrain;
  362. qa.epoch = epoch;
  363. qa.lr = lr;
  364. qa.thread = thread;
  365. qa.verbose = verbose;
  366. qa.dsub = dsub;
  367. qa.qnorm = qnorm;
  368. m.quantize(qa);
  369. })
  370. .def(
  371. "predict",
  372. // NOTE: text needs to end in a newline
  373. // to exactly mimic the behavior of the cli
  374. [](fasttext::FastText& m,
  375. const std::string text,
  376. int32_t k,
  377. fasttext::real threshold,
  378. const char* onUnicodeError) {
  379. std::stringstream ioss(text);
  380. std::vector<std::pair<fasttext::real, std::string>> predictions;
  381. m.predictLine(ioss, predictions, k, threshold);
  382. return castToPythonString(predictions, onUnicodeError);
  383. })
  384. .def(
  385. "multilinePredict",
  386. // NOTE: text needs to end in a newline
  387. // to exactly mimic the behavior of the cli
  388. [](fasttext::FastText& m,
  389. const std::vector<std::string>& lines,
  390. int32_t k,
  391. fasttext::real threshold,
  392. const char* onUnicodeError) {
  393. std::vector<py::array_t<fasttext::real>> allProbabilities;
  394. std::vector<std::vector<py::str>> allLabels;
  395. std::vector<std::pair<fasttext::real, std::string>> predictions;
  396. for (const std::string& text : lines) {
  397. std::stringstream ioss(text);
  398. m.predictLine(ioss, predictions, k, threshold);
  399. std::vector<fasttext::real> probabilities;
  400. std::vector<py::str> labels;
  401. for (const auto& prediction : predictions) {
  402. probabilities.push_back(prediction.first);
  403. labels.push_back(
  404. castToPythonString(prediction.second, onUnicodeError));
  405. }
  406. allProbabilities.emplace_back(
  407. probabilities.size(), probabilities.data());
  408. allLabels.push_back(labels);
  409. }
  410. return make_pair(allLabels, allProbabilities);
  411. })
  412. .def(
  413. "testLabel",
  414. [](fasttext::FastText& m,
  415. const std::string filename,
  416. int32_t k,
  417. fasttext::real threshold) {
  418. std::ifstream ifs(filename);
  419. if (!ifs.is_open()) {
  420. throw std::invalid_argument("Test file cannot be opened!");
  421. }
  422. fasttext::Meter meter(false);
  423. m.test(ifs, k, threshold, meter);
  424. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  425. std::unordered_map<std::string, py::dict> returnedValue;
  426. for (int32_t i = 0; i < d->nlabels(); i++) {
  427. returnedValue[d->getLabel(i)] = py::dict(
  428. "precision"_a = meter.precision(i),
  429. "recall"_a = meter.recall(i),
  430. "f1score"_a = meter.f1Score(i));
  431. }
  432. return returnedValue;
  433. })
  434. .def(
  435. "getWordId",
  436. [](fasttext::FastText& m, const std::string& word) {
  437. return m.getWordId(word);
  438. })
  439. .def(
  440. "getSubwordId",
  441. [](fasttext::FastText& m, const std::string word) {
  442. return m.getSubwordId(word);
  443. })
  444. .def(
  445. "getLabelId",
  446. [](fasttext::FastText& m, const std::string& label) {
  447. return m.getLabelId(label);
  448. })
  449. .def(
  450. "getInputVector",
  451. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  452. m.getInputVector(vec, ind);
  453. })
  454. .def(
  455. "getWordVector",
  456. [](fasttext::FastText& m,
  457. fasttext::Vector& vec,
  458. const std::string word) { m.getWordVector(vec, word); })
  459. .def(
  460. "getNN",
  461. [](fasttext::FastText& m,
  462. const std::string& word,
  463. int32_t k,
  464. const char* onUnicodeError) {
  465. return castToPythonString(m.getNN(word, k), onUnicodeError);
  466. })
  467. .def(
  468. "getAnalogies",
  469. [](fasttext::FastText& m,
  470. const std::string& wordA,
  471. const std::string& wordB,
  472. const std::string& wordC,
  473. int32_t k,
  474. const char* onUnicodeError) {
  475. return castToPythonString(
  476. m.getAnalogies(k, wordA, wordB, wordC), onUnicodeError);
  477. })
  478. .def(
  479. "getSubwords",
  480. [](fasttext::FastText& m,
  481. const std::string word,
  482. const char* onUnicodeError) {
  483. std::vector<std::string> subwords;
  484. std::vector<int32_t> ngrams;
  485. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  486. d->getSubwords(word, ngrams, subwords);
  487. std::vector<py::str> transformedSubwords;
  488. for (const auto& subword : subwords) {
  489. transformedSubwords.push_back(
  490. castToPythonString(subword, onUnicodeError));
  491. }
  492. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  493. transformedSubwords, ngrams);
  494. })
  495. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  496. }