fasttext_wasm.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
#include <emscripten.h>
#include <emscripten/bind.h>
#include <fasttext.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <sstream>
#include <string>
#include <vector>
using namespace emscripten;
using namespace fasttext;
  17. struct Float32ArrayBridge {
  18. uintptr_t ptr;
  19. int size;
  20. };
  21. void fillFloat32ArrayFromVector(
  22. const Float32ArrayBridge& vecFloat,
  23. const Vector& v) {
  24. float* buffer = reinterpret_cast<float*>(vecFloat.ptr);
  25. assert(vecFloat.size == v.size());
  26. for (int i = 0; i < v.size(); i++) {
  27. buffer[i] = v[i];
  28. }
  29. }
  30. std::vector<std::pair<float, std::string>>
  31. predict(FastText* fasttext, std::string text, int k, double threshold) {
  32. std::stringstream ioss(text + std::string("\n"));
  33. std::vector<std::pair<float, std::string>> predictions;
  34. fasttext->predictLine(ioss, predictions, k, threshold);
  35. return predictions;
  36. }
  37. void getWordVector(
  38. FastText* fasttext,
  39. const Float32ArrayBridge& vecFloat,
  40. std::string word) {
  41. assert(fasttext);
  42. Vector v(fasttext->getDimension());
  43. fasttext->getWordVector(v, word);
  44. fillFloat32ArrayFromVector(vecFloat, v);
  45. }
  46. void getSentenceVector(
  47. FastText* fasttext,
  48. const Float32ArrayBridge& vecFloat,
  49. std::string text) {
  50. assert(fasttext);
  51. Vector v(fasttext->getDimension());
  52. std::stringstream ioss(text);
  53. fasttext->getSentenceVector(ioss, v);
  54. fillFloat32ArrayFromVector(vecFloat, v);
  55. }
  56. std::pair<std::vector<std::string>, std::vector<int32_t>> getSubwords(
  57. FastText* fasttext,
  58. std::string word) {
  59. assert(fasttext);
  60. std::vector<std::string> subwords;
  61. std::vector<int32_t> ngrams;
  62. std::shared_ptr<const Dictionary> d = fasttext->getDictionary();
  63. d->getSubwords(word, ngrams, subwords);
  64. return std::pair<std::vector<std::string>, std::vector<int32_t>>(
  65. subwords, ngrams);
  66. }
  67. void getInputVector(
  68. FastText* fasttext,
  69. const Float32ArrayBridge& vecFloat,
  70. int32_t ind) {
  71. assert(fasttext);
  72. Vector v(fasttext->getDimension());
  73. fasttext->getInputVector(v, ind);
  74. fillFloat32ArrayFromVector(vecFloat, v);
  75. }
  76. void train(FastText* fasttext, Args* args, emscripten::val jsCallback) {
  77. assert(args);
  78. assert(fasttext);
  79. fasttext->train(
  80. *args,
  81. [=](float progress, float loss, double wst, double lr, int64_t eta) {
  82. jsCallback(progress, loss, wst, lr, static_cast<int32_t>(eta));
  83. });
  84. }
  85. const DenseMatrix* getInputMatrix(FastText* fasttext) {
  86. assert(fasttext);
  87. std::shared_ptr<const DenseMatrix> mm = fasttext->getInputMatrix();
  88. return mm.get();
  89. }
  90. const DenseMatrix* getOutputMatrix(FastText* fasttext) {
  91. assert(fasttext);
  92. std::shared_ptr<const DenseMatrix> mm = fasttext->getOutputMatrix();
  93. return mm.get();
  94. }
  95. std::pair<std::vector<std::string>, std::vector<int32_t>> getTokens(
  96. const FastText& fasttext,
  97. const std::function<std::string(const Dictionary&, int32_t)> getter,
  98. entry_type entryType) {
  99. std::vector<std::string> tokens;
  100. std::vector<int32_t> retVocabFrequencies;
  101. std::shared_ptr<const Dictionary> d = fasttext.getDictionary();
  102. std::vector<int64_t> vocabFrequencies = d->getCounts(entryType);
  103. for (int32_t i = 0; i < vocabFrequencies.size(); i++) {
  104. tokens.push_back(getter(*d, i));
  105. retVocabFrequencies.push_back(vocabFrequencies[i]);
  106. }
  107. return std::pair<std::vector<std::string>, std::vector<int32_t>>(
  108. tokens, retVocabFrequencies);
  109. }
  110. std::pair<std::vector<std::string>, std::vector<int32_t>> getWords(
  111. FastText* fasttext) {
  112. assert(fasttext);
  113. return getTokens(*fasttext, &Dictionary::getWord, entry_type::word);
  114. }
  115. std::pair<std::vector<std::string>, std::vector<int32_t>> getLabels(
  116. FastText* fasttext) {
  117. assert(fasttext);
  118. return getTokens(*fasttext, &Dictionary::getLabel, entry_type::label);
  119. }
  120. std::pair<std::vector<std::string>, std::vector<std::string>> getLine(
  121. FastText* fasttext,
  122. const std::string text) {
  123. assert(fasttext);
  124. std::shared_ptr<const Dictionary> d = fasttext->getDictionary();
  125. std::stringstream ioss(text);
  126. std::string token;
  127. std::vector<std::string> words;
  128. std::vector<std::string> labels;
  129. while (d->readWord(ioss, token)) {
  130. uint32_t h = d->hash(token);
  131. int32_t wid = d->getId(token, h);
  132. entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
  133. if (type == entry_type::word) {
  134. words.push_back(token);
  135. } else if (type == entry_type::label && wid >= 0) {
  136. labels.push_back(token);
  137. }
  138. if (token == Dictionary::EOS)
  139. break;
  140. }
  141. return std::pair<std::vector<std::string>, std::vector<std::string>>(
  142. words, labels);
  143. }
  144. Meter test(
  145. FastText* fasttext,
  146. const std::string& filename,
  147. int32_t k,
  148. float threshold) {
  149. assert(fasttext);
  150. std::ifstream ifs(filename);
  151. if (!ifs.is_open()) {
  152. throw std::invalid_argument("Test file cannot be opened!");
  153. }
  154. Meter meter;
  155. fasttext->test(ifs, k, threshold, meter);
  156. ifs.close();
  157. return meter;
  158. }
  159. EMSCRIPTEN_BINDINGS(fasttext) {
  160. class_<Args>("Args")
  161. .constructor<>()
  162. .property("input", &Args::input)
  163. .property("output", &Args::output)
  164. .property("lr", &Args::lr)
  165. .property("lrUpdateRate", &Args::lrUpdateRate)
  166. .property("dim", &Args::dim)
  167. .property("ws", &Args::ws)
  168. .property("epoch", &Args::epoch)
  169. .property("minCount", &Args::minCount)
  170. .property("minCountLabel", &Args::minCountLabel)
  171. .property("neg", &Args::neg)
  172. .property("wordNgrams", &Args::wordNgrams)
  173. .property("loss", &Args::loss)
  174. .property("model", &Args::model)
  175. .property("bucket", &Args::bucket)
  176. .property("minn", &Args::minn)
  177. .property("maxn", &Args::maxn)
  178. .property("thread", &Args::thread)
  179. .property("t", &Args::t)
  180. .property("label", &Args::label)
  181. .property("verbose", &Args::verbose)
  182. .property("pretrainedVectors", &Args::pretrainedVectors)
  183. .property("saveOutput", &Args::saveOutput)
  184. .property("seed", &Args::seed)
  185. .property("qout", &Args::qout)
  186. .property("retrain", &Args::retrain)
  187. .property("qnorm", &Args::qnorm)
  188. .property("cutoff", &Args::cutoff)
  189. .property("dsub", &Args::dsub)
  190. .property("qnorm", &Args::qnorm)
  191. .property("autotuneValidationFile", &Args::autotuneValidationFile)
  192. .property("autotuneMetric", &Args::autotuneMetric)
  193. .property("autotunePredictions", &Args::autotunePredictions)
  194. .property("autotuneDuration", &Args::autotuneDuration)
  195. .property("autotuneModelSize", &Args::autotuneModelSize);
  196. class_<FastText>("FastText")
  197. .constructor<>()
  198. .function(
  199. "loadModel",
  200. select_overload<void(const std::string&)>(&FastText::loadModel))
  201. .function(
  202. "getNN",
  203. select_overload<std::vector<std::pair<real, std::string>>(
  204. const std::string& word, int32_t k)>(&FastText::getNN))
  205. .function("getAnalogies", &FastText::getAnalogies)
  206. .function("getWordId", &FastText::getWordId)
  207. .function("getSubwordId", &FastText::getSubwordId)
  208. .function("getInputMatrix", &getInputMatrix, allow_raw_pointers())
  209. .function("getOutputMatrix", &getOutputMatrix, allow_raw_pointers())
  210. .function("getWords", &getWords, allow_raw_pointers())
  211. .function("getLabels", &getLabels, allow_raw_pointers())
  212. .function("getLine", &getLine, allow_raw_pointers())
  213. .function("test", &test, allow_raw_pointers())
  214. .function("predict", &predict, allow_raw_pointers())
  215. .function("getWordVector", &getWordVector, allow_raw_pointers())
  216. .function("getSentenceVector", &getSentenceVector, allow_raw_pointers())
  217. .function("getSubwords", &getSubwords, allow_raw_pointers())
  218. .function("getInputVector", &getInputVector, allow_raw_pointers())
  219. .function("train", &train, allow_raw_pointers())
  220. .function("saveModel", &FastText::saveModel)
  221. .property("isQuant", &FastText::isQuant)
  222. .property("args", &FastText::getArgs);
  223. class_<DenseMatrix>("DenseMatrix")
  224. .constructor<>()
  225. // we return int32_t because "JS can't represent int64s"
  226. .function(
  227. "rows",
  228. optional_override(
  229. [](const DenseMatrix* self) -> int32_t { return self->rows(); }),
  230. allow_raw_pointers())
  231. .function(
  232. "cols",
  233. optional_override(
  234. [](const DenseMatrix* self) -> int32_t { return self->cols(); }),
  235. allow_raw_pointers())
  236. .function(
  237. "at",
  238. optional_override(
  239. [](const DenseMatrix* self, int32_t i, int32_t j) -> const float {
  240. return self->at(i, j);
  241. }),
  242. allow_raw_pointers());
  243. class_<Meter>("Meter")
  244. .constructor<>()
  245. .property(
  246. "precision", select_overload<double(void) const>(&Meter::precision))
  247. .property("recall", select_overload<double(void) const>(&Meter::recall))
  248. .property("f1Score", select_overload<double(void) const>(&Meter::f1Score))
  249. .function(
  250. "nexamples",
  251. optional_override(
  252. [](const Meter* self) -> int32_t { return self->nexamples(); }),
  253. allow_raw_pointers());
  254. enum_<model_name>("ModelName")
  255. .value("cbow", model_name::cbow)
  256. .value("skipgram", model_name::sg)
  257. .value("supervised", model_name::sup);
  258. enum_<loss_name>("LossName")
  259. .value("hs", loss_name::hs)
  260. .value("ns", loss_name::ns)
  261. .value("softmax", loss_name::softmax)
  262. .value("ova", loss_name::ova);
  263. emscripten::value_object<Float32ArrayBridge>("Float32ArrayBridge")
  264. .field("ptr", &Float32ArrayBridge::ptr)
  265. .field("size", &Float32ArrayBridge::size);
  266. emscripten::value_array<std::pair<float, std::string>>(
  267. "std::pair<float, std::string>")
  268. .element(&std::pair<float, std::string>::first)
  269. .element(&std::pair<float, std::string>::second);
  270. emscripten::register_vector<std::pair<float, std::string>>(
  271. "std::vector<std::pair<float, std::string>>");
  272. emscripten::value_array<
  273. std::pair<std::vector<std::string>, std::vector<int32_t>>>(
  274. "std::pair<std::vector<std::string>, std::vector<int32_t>>")
  275. .element(
  276. &std::pair<std::vector<std::string>, std::vector<int32_t>>::first)
  277. .element(
  278. &std::pair<std::vector<std::string>, std::vector<int32_t>>::second);
  279. emscripten::value_array<
  280. std::pair<std::vector<std::string>, std::vector<std::string>>>(
  281. "std::pair<std::vector<std::string>, std::vector<std::string>>")
  282. .element(
  283. &std::pair<std::vector<std::string>, std::vector<std::string>>::first)
  284. .element(&std::pair<std::vector<std::string>, std::vector<std::string>>::
  285. second);
  286. emscripten::register_vector<float>("std::vector<float>");
  287. emscripten::register_vector<int32_t>("std::vector<int32_t>");
  288. emscripten::register_vector<std::string>("std::vector<std::string>");
  289. }