| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- /**
- * Copyright (c) 2016-present, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
#include <emscripten.h>
#include <emscripten/bind.h>
#include <fasttext.h>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
- using namespace emscripten;
- using namespace fasttext;
/**
 * Descriptor of a float buffer passed across the embind boundary (it is
 * registered as a value_object in the bindings section of this file).
 *
 * `ptr` holds the buffer address as an integer; fillFloat32ArrayFromVector()
 * reinterprets it as a `float*`. `size` is the element count that function
 * compares against the source vector's length.
 *
 * Both members default to zero so that a default-constructed bridge never
 * carries an indeterminate address or length.
 */
struct Float32ArrayBridge {
  uintptr_t ptr = 0;
  int size = 0;
};
- void fillFloat32ArrayFromVector(
- const Float32ArrayBridge& vecFloat,
- const Vector& v) {
- float* buffer = reinterpret_cast<float*>(vecFloat.ptr);
- assert(vecFloat.size == v.size());
- for (int i = 0; i < v.size(); i++) {
- buffer[i] = v[i];
- }
- }
- std::vector<std::pair<float, std::string>>
- predict(FastText* fasttext, std::string text, int k, double threshold) {
- std::stringstream ioss(text + std::string("\n"));
- std::vector<std::pair<float, std::string>> predictions;
- fasttext->predictLine(ioss, predictions, k, threshold);
- return predictions;
- }
- void getWordVector(
- FastText* fasttext,
- const Float32ArrayBridge& vecFloat,
- std::string word) {
- assert(fasttext);
- Vector v(fasttext->getDimension());
- fasttext->getWordVector(v, word);
- fillFloat32ArrayFromVector(vecFloat, v);
- }
- void getSentenceVector(
- FastText* fasttext,
- const Float32ArrayBridge& vecFloat,
- std::string text) {
- assert(fasttext);
- Vector v(fasttext->getDimension());
- std::stringstream ioss(text);
- fasttext->getSentenceVector(ioss, v);
- fillFloat32ArrayFromVector(vecFloat, v);
- }
- std::pair<std::vector<std::string>, std::vector<int32_t>> getSubwords(
- FastText* fasttext,
- std::string word) {
- assert(fasttext);
- std::vector<std::string> subwords;
- std::vector<int32_t> ngrams;
- std::shared_ptr<const Dictionary> d = fasttext->getDictionary();
- d->getSubwords(word, ngrams, subwords);
- return std::pair<std::vector<std::string>, std::vector<int32_t>>(
- subwords, ngrams);
- }
- void getInputVector(
- FastText* fasttext,
- const Float32ArrayBridge& vecFloat,
- int32_t ind) {
- assert(fasttext);
- Vector v(fasttext->getDimension());
- fasttext->getInputVector(v, ind);
- fillFloat32ArrayFromVector(vecFloat, v);
- }
- void train(FastText* fasttext, Args* args, emscripten::val jsCallback) {
- assert(args);
- assert(fasttext);
- fasttext->train(
- *args,
- [=](float progress, float loss, double wst, double lr, int64_t eta) {
- jsCallback(progress, loss, wst, lr, static_cast<int32_t>(eta));
- });
- }
- const DenseMatrix* getInputMatrix(FastText* fasttext) {
- assert(fasttext);
- std::shared_ptr<const DenseMatrix> mm = fasttext->getInputMatrix();
- return mm.get();
- }
- const DenseMatrix* getOutputMatrix(FastText* fasttext) {
- assert(fasttext);
- std::shared_ptr<const DenseMatrix> mm = fasttext->getOutputMatrix();
- return mm.get();
- }
- std::pair<std::vector<std::string>, std::vector<int32_t>> getTokens(
- const FastText& fasttext,
- const std::function<std::string(const Dictionary&, int32_t)> getter,
- entry_type entryType) {
- std::vector<std::string> tokens;
- std::vector<int32_t> retVocabFrequencies;
- std::shared_ptr<const Dictionary> d = fasttext.getDictionary();
- std::vector<int64_t> vocabFrequencies = d->getCounts(entryType);
- for (int32_t i = 0; i < vocabFrequencies.size(); i++) {
- tokens.push_back(getter(*d, i));
- retVocabFrequencies.push_back(vocabFrequencies[i]);
- }
- return std::pair<std::vector<std::string>, std::vector<int32_t>>(
- tokens, retVocabFrequencies);
- }
- std::pair<std::vector<std::string>, std::vector<int32_t>> getWords(
- FastText* fasttext) {
- assert(fasttext);
- return getTokens(*fasttext, &Dictionary::getWord, entry_type::word);
- }
- std::pair<std::vector<std::string>, std::vector<int32_t>> getLabels(
- FastText* fasttext) {
- assert(fasttext);
- return getTokens(*fasttext, &Dictionary::getLabel, entry_type::label);
- }
- std::pair<std::vector<std::string>, std::vector<std::string>> getLine(
- FastText* fasttext,
- const std::string text) {
- assert(fasttext);
- std::shared_ptr<const Dictionary> d = fasttext->getDictionary();
- std::stringstream ioss(text);
- std::string token;
- std::vector<std::string> words;
- std::vector<std::string> labels;
- while (d->readWord(ioss, token)) {
- uint32_t h = d->hash(token);
- int32_t wid = d->getId(token, h);
- entry_type type = wid < 0 ? d->getType(token) : d->getType(wid);
- if (type == entry_type::word) {
- words.push_back(token);
- } else if (type == entry_type::label && wid >= 0) {
- labels.push_back(token);
- }
- if (token == Dictionary::EOS)
- break;
- }
- return std::pair<std::vector<std::string>, std::vector<std::string>>(
- words, labels);
- }
- Meter test(
- FastText* fasttext,
- const std::string& filename,
- int32_t k,
- float threshold) {
- assert(fasttext);
- std::ifstream ifs(filename);
- if (!ifs.is_open()) {
- throw std::invalid_argument("Test file cannot be opened!");
- }
- Meter meter(false);
- fasttext->test(ifs, k, threshold, meter);
- ifs.close();
- return meter;
- }
- EMSCRIPTEN_BINDINGS(fasttext) {
- class_<Args>("Args")
- .constructor<>()
- .property("input", &Args::input)
- .property("output", &Args::output)
- .property("lr", &Args::lr)
- .property("lrUpdateRate", &Args::lrUpdateRate)
- .property("dim", &Args::dim)
- .property("ws", &Args::ws)
- .property("epoch", &Args::epoch)
- .property("minCount", &Args::minCount)
- .property("minCountLabel", &Args::minCountLabel)
- .property("neg", &Args::neg)
- .property("wordNgrams", &Args::wordNgrams)
- .property("loss", &Args::loss)
- .property("model", &Args::model)
- .property("bucket", &Args::bucket)
- .property("minn", &Args::minn)
- .property("maxn", &Args::maxn)
- .property("thread", &Args::thread)
- .property("t", &Args::t)
- .property("label", &Args::label)
- .property("verbose", &Args::verbose)
- .property("pretrainedVectors", &Args::pretrainedVectors)
- .property("saveOutput", &Args::saveOutput)
- .property("seed", &Args::seed)
- .property("qout", &Args::qout)
- .property("retrain", &Args::retrain)
- .property("qnorm", &Args::qnorm)
- .property("cutoff", &Args::cutoff)
- .property("dsub", &Args::dsub)
- .property("qnorm", &Args::qnorm)
- .property("autotuneValidationFile", &Args::autotuneValidationFile)
- .property("autotuneMetric", &Args::autotuneMetric)
- .property("autotunePredictions", &Args::autotunePredictions)
- .property("autotuneDuration", &Args::autotuneDuration)
- .property("autotuneModelSize", &Args::autotuneModelSize);
- class_<FastText>("FastText")
- .constructor<>()
- .function(
- "loadModel",
- select_overload<void(const std::string&)>(&FastText::loadModel))
- .function(
- "getNN",
- select_overload<std::vector<std::pair<real, std::string>>(
- const std::string& word, int32_t k)>(&FastText::getNN))
- .function("getAnalogies", &FastText::getAnalogies)
- .function("getWordId", &FastText::getWordId)
- .function("getSubwordId", &FastText::getSubwordId)
- .function("getInputMatrix", &getInputMatrix, allow_raw_pointers())
- .function("getOutputMatrix", &getOutputMatrix, allow_raw_pointers())
- .function("getWords", &getWords, allow_raw_pointers())
- .function("getLabels", &getLabels, allow_raw_pointers())
- .function("getLine", &getLine, allow_raw_pointers())
- .function("test", &test, allow_raw_pointers())
- .function("predict", &predict, allow_raw_pointers())
- .function("getWordVector", &getWordVector, allow_raw_pointers())
- .function("getSentenceVector", &getSentenceVector, allow_raw_pointers())
- .function("getSubwords", &getSubwords, allow_raw_pointers())
- .function("getInputVector", &getInputVector, allow_raw_pointers())
- .function("train", &train, allow_raw_pointers())
- .function("saveModel", &FastText::saveModel)
- .property("isQuant", &FastText::isQuant)
- .property("args", &FastText::getArgs);
- class_<DenseMatrix>("DenseMatrix")
- .constructor<>()
- // we return int32_t because "JS can't represent int64s"
- .function(
- "rows",
- optional_override(
- [](const DenseMatrix* self) -> int32_t { return self->rows(); }),
- allow_raw_pointers())
- .function(
- "cols",
- optional_override(
- [](const DenseMatrix* self) -> int32_t { return self->cols(); }),
- allow_raw_pointers())
- .function(
- "at",
- optional_override(
- [](const DenseMatrix* self, int32_t i, int32_t j) -> const float {
- return self->at(i, j);
- }),
- allow_raw_pointers());
- class_<Meter>("Meter")
- .constructor<bool>()
- .property(
- "precision", select_overload<double(void) const>(&Meter::precision))
- .property("recall", select_overload<double(void) const>(&Meter::recall))
- .property("f1Score", select_overload<double(void) const>(&Meter::f1Score))
- .function(
- "nexamples",
- optional_override(
- [](const Meter* self) -> int32_t { return self->nexamples(); }),
- allow_raw_pointers());
- enum_<model_name>("ModelName")
- .value("cbow", model_name::cbow)
- .value("skipgram", model_name::sg)
- .value("supervised", model_name::sup);
- enum_<loss_name>("LossName")
- .value("hs", loss_name::hs)
- .value("ns", loss_name::ns)
- .value("softmax", loss_name::softmax)
- .value("ova", loss_name::ova);
- emscripten::value_object<Float32ArrayBridge>("Float32ArrayBridge")
- .field("ptr", &Float32ArrayBridge::ptr)
- .field("size", &Float32ArrayBridge::size);
- emscripten::value_array<std::pair<float, std::string>>(
- "std::pair<float, std::string>")
- .element(&std::pair<float, std::string>::first)
- .element(&std::pair<float, std::string>::second);
- emscripten::register_vector<std::pair<float, std::string>>(
- "std::vector<std::pair<float, std::string>>");
- emscripten::value_array<
- std::pair<std::vector<std::string>, std::vector<int32_t>>>(
- "std::pair<std::vector<std::string>, std::vector<int32_t>>")
- .element(
- &std::pair<std::vector<std::string>, std::vector<int32_t>>::first)
- .element(
- &std::pair<std::vector<std::string>, std::vector<int32_t>>::second);
- emscripten::value_array<
- std::pair<std::vector<std::string>, std::vector<std::string>>>(
- "std::pair<std::vector<std::string>, std::vector<std::string>>")
- .element(
- &std::pair<std::vector<std::string>, std::vector<std::string>>::first)
- .element(&std::pair<std::vector<std::string>, std::vector<std::string>>::
- second);
- emscripten::register_vector<float>("std::vector<float>");
- emscripten::register_vector<int32_t>("std::vector<int32_t>");
- emscripten::register_vector<std::string>("std::vector<std::string>");
- }
|