fasttext_pybind.cc 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  #include <args.h>
  #include <autotune.h>
  #include <densematrix.h>
  #include <fasttext.h>
  #include <pybind11/numpy.h>
  #include <pybind11/pybind11.h>
  #include <pybind11/stl.h>
  #include <real.h>
  #include <vector.h>

  #include <cmath>
  #include <cstdint>
  #include <fstream>
  #include <iterator>
  #include <memory>
  #include <sstream>
  #include <stdexcept>
  #include <string>
  #include <tuple>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  using namespace pybind11::literals;
  namespace py = pybind11;
  23. py::str castToPythonString(const std::string& s, const char* onUnicodeError) {
  24. PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError);
  25. if (!handle) {
  26. throw py::error_already_set();
  27. }
  28. // py::str's constructor from a PyObject assumes the string has been encoded
  29. // for python 2 and not encoded for python 3 :
  30. // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930
  31. #if PY_MAJOR_VERSION < 3
  32. PyObject* handle_encoded =
  33. PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError);
  34. Py_DECREF(handle);
  35. handle = handle_encoded;
  36. #endif
  37. py::str handle_str = py::str(handle);
  38. Py_DECREF(handle);
  39. return handle_str;
  40. }
  41. std::pair<std::vector<py::str>, std::vector<py::str>> getLineText(
  42. fasttext::FastText& m,
  43. const std::string text,
  44. const char* onUnicodeError) {
  45. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  46. std::stringstream ioss(text);
  47. std::string token;
  48. std::vector<py::str> words;
  49. std::vector<py::str> labels;
  50. while (d->readWord(ioss, token)) {
  51. uint32_t h = d->hash(token);
  52. int32_t wid = d->getId(token, h);
  53. fasttext::entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  54. if (type == fasttext::entry_type::word) {
  55. words.push_back(castToPythonString(token, onUnicodeError));
  56. // Labels must not be OOV!
  57. } else if (type == fasttext::entry_type::label && wid >= 0) {
  58. labels.push_back(castToPythonString(token, onUnicodeError));
  59. }
  60. if (token == fasttext::Dictionary::EOS)
  61. break;
  62. }
  63. return std::pair<std::vector<py::str>, std::vector<py::str>>(words, labels);
  64. }
  65. PYBIND11_MODULE(fasttext_pybind, m) {
  66. py::class_<fasttext::Args>(m, "args")
  67. .def(py::init<>())
  68. .def_readwrite("input", &fasttext::Args::input)
  69. .def_readwrite("output", &fasttext::Args::output)
  70. .def_readwrite("lr", &fasttext::Args::lr)
  71. .def_readwrite("lrUpdateRate", &fasttext::Args::lrUpdateRate)
  72. .def_readwrite("dim", &fasttext::Args::dim)
  73. .def_readwrite("ws", &fasttext::Args::ws)
  74. .def_readwrite("epoch", &fasttext::Args::epoch)
  75. .def_readwrite("minCount", &fasttext::Args::minCount)
  76. .def_readwrite("minCountLabel", &fasttext::Args::minCountLabel)
  77. .def_readwrite("neg", &fasttext::Args::neg)
  78. .def_readwrite("wordNgrams", &fasttext::Args::wordNgrams)
  79. .def_readwrite("loss", &fasttext::Args::loss)
  80. .def_readwrite("model", &fasttext::Args::model)
  81. .def_readwrite("bucket", &fasttext::Args::bucket)
  82. .def_readwrite("minn", &fasttext::Args::minn)
  83. .def_readwrite("maxn", &fasttext::Args::maxn)
  84. .def_readwrite("thread", &fasttext::Args::thread)
  85. .def_readwrite("t", &fasttext::Args::t)
  86. .def_readwrite("label", &fasttext::Args::label)
  87. .def_readwrite("verbose", &fasttext::Args::verbose)
  88. .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
  89. .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
  90. .def_readwrite("seed", &fasttext::Args::seed)
  91. .def_readwrite("qout", &fasttext::Args::qout)
  92. .def_readwrite("retrain", &fasttext::Args::retrain)
  93. .def_readwrite("qnorm", &fasttext::Args::qnorm)
  94. .def_readwrite("cutoff", &fasttext::Args::cutoff)
  95. .def_readwrite("dsub", &fasttext::Args::dsub)
  96. .def_readwrite(
  97. "autotuneValidationFile", &fasttext::Args::autotuneValidationFile)
  98. .def_readwrite("autotuneMetric", &fasttext::Args::autotuneMetric)
  99. .def_readwrite(
  100. "autotunePredictions", &fasttext::Args::autotunePredictions)
  101. .def_readwrite("autotuneDuration", &fasttext::Args::autotuneDuration)
  102. .def_readwrite("autotuneModelSize", &fasttext::Args::autotuneModelSize)
  103. .def("setManual", [](fasttext::Args& m, const std::string& argName) {
  104. m.setManual(argName);
  105. });
  106. py::enum_<fasttext::model_name>(m, "model_name")
  107. .value("cbow", fasttext::model_name::cbow)
  108. .value("skipgram", fasttext::model_name::sg)
  109. .value("supervised", fasttext::model_name::sup)
  110. .export_values();
  111. py::enum_<fasttext::loss_name>(m, "loss_name")
  112. .value("hs", fasttext::loss_name::hs)
  113. .value("ns", fasttext::loss_name::ns)
  114. .value("softmax", fasttext::loss_name::softmax)
  115. .value("ova", fasttext::loss_name::ova)
  116. .export_values();
  117. py::enum_<fasttext::metric_name>(m, "metric_name")
  118. .value("f1score", fasttext::metric_name::f1score)
  119. .value("labelf1score", fasttext::metric_name::labelf1score)
  120. .export_values();
  121. m.def(
  122. "train",
  123. [](fasttext::FastText& ft, fasttext::Args& a) {
  124. if (a.hasAutotune()) {
  125. fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(
  126. &ft, [](fasttext::FastText*) {}));
  127. autotune.train(a);
  128. } else {
  129. ft.train(a);
  130. }
  131. },
  132. py::call_guard<py::gil_scoped_release>());
  133. py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())
  134. .def(py::init<ssize_t>())
  135. .def_buffer([](fasttext::Vector& m) -> py::buffer_info {
  136. return py::buffer_info(
  137. m.data(),
  138. sizeof(fasttext::real),
  139. py::format_descriptor<fasttext::real>::format(),
  140. 1,
  141. {m.size()},
  142. {sizeof(fasttext::real)});
  143. });
  144. py::class_<fasttext::DenseMatrix>(
  145. m, "DenseMatrix", py::buffer_protocol(), py::module_local())
  146. .def(py::init<>())
  147. .def(py::init<ssize_t, ssize_t>())
  148. .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
  149. return py::buffer_info(
  150. m.data(),
  151. sizeof(fasttext::real),
  152. py::format_descriptor<fasttext::real>::format(),
  153. 2,
  154. {m.size(0), m.size(1)},
  155. {sizeof(fasttext::real) * m.size(1),
  156. sizeof(fasttext::real) * (int64_t)1});
  157. });
  158. py::class_<fasttext::FastText>(m, "fasttext")
  159. .def(py::init<>())
  160. .def("getArgs", &fasttext::FastText::getArgs)
  161. .def(
  162. "getInputMatrix",
  163. [](fasttext::FastText& m) {
  164. std::shared_ptr<const fasttext::DenseMatrix> mm =
  165. m.getInputMatrix();
  166. return mm.get();
  167. },
  168. pybind11::return_value_policy::reference)
  169. .def(
  170. "getOutputMatrix",
  171. [](fasttext::FastText& m) {
  172. std::shared_ptr<const fasttext::DenseMatrix> mm =
  173. m.getOutputMatrix();
  174. return mm.get();
  175. },
  176. pybind11::return_value_policy::reference)
  177. .def(
  178. "setMatrices",
  179. [](fasttext::FastText& m,
  180. py::buffer inputMatrixBuffer,
  181. py::buffer outputMatrixBuffer) {
  182. py::buffer_info inputMatrixInfo = inputMatrixBuffer.request();
  183. py::buffer_info outputMatrixInfo = outputMatrixBuffer.request();
  184. m.setMatrices(
  185. std::make_shared<fasttext::DenseMatrix>(
  186. inputMatrixInfo.shape[0],
  187. inputMatrixInfo.shape[1],
  188. static_cast<float*>(inputMatrixInfo.ptr)),
  189. std::make_shared<fasttext::DenseMatrix>(
  190. outputMatrixInfo.shape[0],
  191. outputMatrixInfo.shape[1],
  192. static_cast<float*>(outputMatrixInfo.ptr)));
  193. })
  194. .def(
  195. "loadModel",
  196. [](fasttext::FastText& m, std::string s) { m.loadModel(s); })
  197. .def(
  198. "saveModel",
  199. [](fasttext::FastText& m, std::string s) { m.saveModel(s); })
  200. .def(
  201. "test",
  202. [](fasttext::FastText& m, const std::string filename, int32_t k) {
  203. std::ifstream ifs(filename);
  204. if (!ifs.is_open()) {
  205. throw std::invalid_argument("Test file cannot be opened!");
  206. }
  207. fasttext::Meter meter;
  208. m.test(ifs, k, 0.0, meter);
  209. ifs.close();
  210. return std::tuple<int64_t, double, double>(
  211. meter.nexamples(), meter.precision(), meter.recall());
  212. })
  213. .def(
  214. "getSentenceVector",
  215. [](fasttext::FastText& m,
  216. fasttext::Vector& v,
  217. const std::string text) {
  218. std::stringstream ioss(text);
  219. m.getSentenceVector(ioss, v);
  220. })
  221. .def(
  222. "tokenize",
  223. [](fasttext::FastText& m, const std::string text) {
  224. std::vector<std::string> text_split;
  225. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  226. std::stringstream ioss(text);
  227. std::string token;
  228. while (!ioss.eof()) {
  229. while (d->readWord(ioss, token)) {
  230. text_split.push_back(token);
  231. }
  232. }
  233. return text_split;
  234. })
  235. .def("getLine", &getLineText)
  236. .def(
  237. "multilineGetLine",
  238. [](fasttext::FastText& m,
  239. const std::vector<std::string> lines,
  240. const char* onUnicodeError) {
  241. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  242. std::vector<std::vector<py::str>> all_words;
  243. std::vector<std::vector<py::str>> all_labels;
  244. for (const auto& text : lines) {
  245. auto pair = getLineText(m, text, onUnicodeError);
  246. all_words.push_back(pair.first);
  247. all_labels.push_back(pair.second);
  248. }
  249. return std::pair<
  250. std::vector<std::vector<py::str>>,
  251. std::vector<std::vector<py::str>>>(all_words, all_labels);
  252. })
  253. .def(
  254. "getVocab",
  255. [](fasttext::FastText& m, const char* onUnicodeError) {
  256. py::str s;
  257. std::vector<py::str> vocab_list;
  258. std::vector<int64_t> vocab_freq;
  259. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  260. vocab_freq = d->getCounts(fasttext::entry_type::word);
  261. for (int32_t i = 0; i < vocab_freq.size(); i++) {
  262. vocab_list.push_back(
  263. castToPythonString(d->getWord(i), onUnicodeError));
  264. }
  265. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  266. vocab_list, vocab_freq);
  267. })
  268. .def(
  269. "getLabels",
  270. [](fasttext::FastText& m, const char* onUnicodeError) {
  271. std::vector<py::str> labels_list;
  272. std::vector<int64_t> labels_freq;
  273. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  274. labels_freq = d->getCounts(fasttext::entry_type::label);
  275. for (int32_t i = 0; i < labels_freq.size(); i++) {
  276. labels_list.push_back(
  277. castToPythonString(d->getLabel(i), onUnicodeError));
  278. }
  279. return std::pair<std::vector<py::str>, std::vector<int64_t>>(
  280. labels_list, labels_freq);
  281. })
  282. .def(
  283. "quantize",
  284. [](fasttext::FastText& m,
  285. const std::string input,
  286. bool qout,
  287. int32_t cutoff,
  288. bool retrain,
  289. int epoch,
  290. double lr,
  291. int thread,
  292. int verbose,
  293. int32_t dsub,
  294. bool qnorm) {
  295. fasttext::Args qa = fasttext::Args();
  296. qa.input = input;
  297. qa.qout = qout;
  298. qa.cutoff = cutoff;
  299. qa.retrain = retrain;
  300. qa.epoch = epoch;
  301. qa.lr = lr;
  302. qa.thread = thread;
  303. qa.verbose = verbose;
  304. qa.dsub = dsub;
  305. qa.qnorm = qnorm;
  306. m.quantize(qa);
  307. })
  308. .def(
  309. "predict",
  310. // NOTE: text needs to end in a newline
  311. // to exactly mimic the behavior of the cli
  312. [](fasttext::FastText& m,
  313. const std::string text,
  314. int32_t k,
  315. fasttext::real threshold,
  316. const char* onUnicodeError) {
  317. std::stringstream ioss(text);
  318. std::vector<std::pair<fasttext::real, std::string>> predictions;
  319. m.predictLine(ioss, predictions, k, threshold);
  320. std::vector<std::pair<fasttext::real, py::str>>
  321. transformedPredictions;
  322. for (const auto& prediction : predictions) {
  323. transformedPredictions.push_back(std::make_pair(
  324. prediction.first,
  325. castToPythonString(prediction.second, onUnicodeError)));
  326. }
  327. return transformedPredictions;
  328. })
  329. .def(
  330. "multilinePredict",
  331. // NOTE: text needs to end in a newline
  332. // to exactly mimic the behavior of the cli
  333. [](fasttext::FastText& m,
  334. const std::vector<std::string>& lines,
  335. int32_t k,
  336. fasttext::real threshold,
  337. const char* onUnicodeError) {
  338. std::vector<py::array_t<fasttext::real>> allProbabilities;
  339. std::vector<std::vector<py::str>> allLabels;
  340. std::vector<std::pair<fasttext::real, std::string>> predictions;
  341. for (const std::string& text : lines) {
  342. std::stringstream ioss(text);
  343. m.predictLine(ioss, predictions, k, threshold);
  344. std::vector<fasttext::real> probabilities;
  345. std::vector<py::str> labels;
  346. for (const auto& prediction : predictions) {
  347. probabilities.push_back(prediction.first);
  348. labels.push_back(
  349. castToPythonString(prediction.second, onUnicodeError));
  350. }
  351. allProbabilities.emplace_back(
  352. probabilities.size(), probabilities.data());
  353. allLabels.push_back(labels);
  354. }
  355. return make_pair(allLabels, allProbabilities);
  356. })
  357. .def(
  358. "testLabel",
  359. [](fasttext::FastText& m,
  360. const std::string filename,
  361. int32_t k,
  362. fasttext::real threshold) {
  363. std::ifstream ifs(filename);
  364. if (!ifs.is_open()) {
  365. throw std::invalid_argument("Test file cannot be opened!");
  366. }
  367. fasttext::Meter meter;
  368. m.test(ifs, k, threshold, meter);
  369. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  370. std::unordered_map<std::string, py::dict> returnedValue;
  371. for (int32_t i = 0; i < d->nlabels(); i++) {
  372. returnedValue[d->getLabel(i)] = py::dict(
  373. "precision"_a = meter.precision(i),
  374. "recall"_a = meter.recall(i),
  375. "f1score"_a = meter.f1Score(i));
  376. }
  377. return returnedValue;
  378. })
  379. .def(
  380. "getWordId",
  381. [](fasttext::FastText& m, const std::string word) {
  382. return m.getWordId(word);
  383. })
  384. .def(
  385. "getSubwordId",
  386. [](fasttext::FastText& m, const std::string word) {
  387. return m.getSubwordId(word);
  388. })
  389. .def(
  390. "getInputVector",
  391. [](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
  392. m.getInputVector(vec, ind);
  393. })
  394. .def(
  395. "getWordVector",
  396. [](fasttext::FastText& m,
  397. fasttext::Vector& vec,
  398. const std::string word) { m.getWordVector(vec, word); })
  399. .def(
  400. "getNN",
  401. [](fasttext::FastText& m, const std::string& word, int32_t k) {
  402. return m.getNN(word, k);
  403. })
  404. .def(
  405. "getAnalogies",
  406. [](fasttext::FastText& m,
  407. const std::string& wordA,
  408. const std::string& wordB,
  409. const std::string& wordC,
  410. int32_t k) { return m.getAnalogies(k, wordA, wordB, wordC); })
  411. .def(
  412. "getSubwords",
  413. [](fasttext::FastText& m,
  414. const std::string word,
  415. const char* onUnicodeError) {
  416. std::vector<std::string> subwords;
  417. std::vector<int32_t> ngrams;
  418. std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  419. d->getSubwords(word, ngrams, subwords);
  420. std::vector<py::str> transformedSubwords;
  421. for (const auto& subword : subwords) {
  422. transformedSubwords.push_back(
  423. castToPythonString(subword, onUnicodeError));
  424. }
  425. return std::pair<std::vector<py::str>, std::vector<int32_t>>(
  426. transformedSubwords, ngrams);
  427. })
  428. .def("isQuant", [](fasttext::FastText& m) { return m.isQuant(); });
  429. }