/** * Copyright (c) 2016-present, Facebook, Inc. * All rights reserved. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include using namespace emscripten; using namespace fasttext; struct Float32ArrayBridge { uintptr_t ptr; int size; }; void fillFloat32ArrayFromVector( const Float32ArrayBridge& vecFloat, const Vector& v) { float* buffer = reinterpret_cast(vecFloat.ptr); assert(vecFloat.size == v.size()); for (int i = 0; i < v.size(); i++) { buffer[i] = v[i]; } } std::vector> predict(FastText* fasttext, std::string text, int k, double threshold) { std::stringstream ioss(text + std::string("\n")); std::vector> predictions; fasttext->predictLine(ioss, predictions, k, threshold); return predictions; } void getWordVector( FastText* fasttext, const Float32ArrayBridge& vecFloat, std::string word) { assert(fasttext); Vector v(fasttext->getDimension()); fasttext->getWordVector(v, word); fillFloat32ArrayFromVector(vecFloat, v); } void getSentenceVector( FastText* fasttext, const Float32ArrayBridge& vecFloat, std::string text) { assert(fasttext); Vector v(fasttext->getDimension()); std::stringstream ioss(text); fasttext->getSentenceVector(ioss, v); fillFloat32ArrayFromVector(vecFloat, v); } std::pair, std::vector> getSubwords( FastText* fasttext, std::string word) { assert(fasttext); std::vector subwords; std::vector ngrams; std::shared_ptr d = fasttext->getDictionary(); d->getSubwords(word, ngrams, subwords); return std::pair, std::vector>( subwords, ngrams); } void getInputVector( FastText* fasttext, const Float32ArrayBridge& vecFloat, int32_t ind) { assert(fasttext); Vector v(fasttext->getDimension()); fasttext->getInputVector(v, ind); fillFloat32ArrayFromVector(vecFloat, v); } void train(FastText* fasttext, Args* args, emscripten::val jsCallback) { assert(args); assert(fasttext); fasttext->train( *args, [=](float progress, float loss, double wst, double lr, int64_t eta) { jsCallback(progress, loss, wst, lr, static_cast(eta)); }); } const DenseMatrix* getInputMatrix(FastText* fasttext) { assert(fasttext); std::shared_ptr mm = fasttext->getInputMatrix(); return mm.get(); } const DenseMatrix* getOutputMatrix(FastText* fasttext) { assert(fasttext); std::shared_ptr mm = fasttext->getOutputMatrix(); return mm.get(); } std::pair, std::vector> getTokens( const FastText& fasttext, const std::function getter, entry_type entryType) { std::vector tokens; std::vector retVocabFrequencies; std::shared_ptr d = fasttext.getDictionary(); std::vector vocabFrequencies = d->getCounts(entryType); for (int32_t i = 0; i < vocabFrequencies.size(); i++) { tokens.push_back(getter(*d, i)); retVocabFrequencies.push_back(vocabFrequencies[i]); } return std::pair, std::vector>( tokens, retVocabFrequencies); } std::pair, std::vector> getWords( FastText* fasttext) { assert(fasttext); return getTokens(*fasttext, &Dictionary::getWord, entry_type::word); } std::pair, std::vector> getLabels( FastText* fasttext) { assert(fasttext); return getTokens(*fasttext, &Dictionary::getLabel, entry_type::label); } std::pair, std::vector> getLine( FastText* fasttext, const std::string text) { assert(fasttext); std::shared_ptr d = fasttext->getDictionary(); std::stringstream ioss(text); std::string token; std::vector words; std::vector labels; while (d->readWord(ioss, token)) { uint32_t h = d->hash(token); int32_t wid = d->getId(token, h); entry_type type = wid < 0 ? d->getType(token) : d->getType(wid); if (type == entry_type::word) { words.push_back(token); } else if (type == entry_type::label && wid >= 0) { labels.push_back(token); } if (token == Dictionary::EOS) break; } return std::pair, std::vector>( words, labels); } Meter test( FastText* fasttext, const std::string& filename, int32_t k, float threshold) { assert(fasttext); std::ifstream ifs(filename); if (!ifs.is_open()) { throw std::invalid_argument("Test file cannot be opened!"); } Meter meter(false); fasttext->test(ifs, k, threshold, meter); ifs.close(); return meter; } EMSCRIPTEN_BINDINGS(fasttext) { class_("Args") .constructor<>() .property("input", &Args::input) .property("output", &Args::output) .property("lr", &Args::lr) .property("lrUpdateRate", &Args::lrUpdateRate) .property("dim", &Args::dim) .property("ws", &Args::ws) .property("epoch", &Args::epoch) .property("minCount", &Args::minCount) .property("minCountLabel", &Args::minCountLabel) .property("neg", &Args::neg) .property("wordNgrams", &Args::wordNgrams) .property("loss", &Args::loss) .property("model", &Args::model) .property("bucket", &Args::bucket) .property("minn", &Args::minn) .property("maxn", &Args::maxn) .property("thread", &Args::thread) .property("t", &Args::t) .property("label", &Args::label) .property("verbose", &Args::verbose) .property("pretrainedVectors", &Args::pretrainedVectors) .property("saveOutput", &Args::saveOutput) .property("seed", &Args::seed) .property("qout", &Args::qout) .property("retrain", &Args::retrain) .property("qnorm", &Args::qnorm) .property("cutoff", &Args::cutoff) .property("dsub", &Args::dsub) .property("qnorm", &Args::qnorm) .property("autotuneValidationFile", &Args::autotuneValidationFile) .property("autotuneMetric", &Args::autotuneMetric) .property("autotunePredictions", &Args::autotunePredictions) .property("autotuneDuration", &Args::autotuneDuration) .property("autotuneModelSize", &Args::autotuneModelSize); class_("FastText") .constructor<>() .function( "loadModel", select_overload(&FastText::loadModel)) .function( "getNN", select_overload>( const std::string& word, int32_t k)>(&FastText::getNN)) .function("getAnalogies", &FastText::getAnalogies) .function("getWordId", &FastText::getWordId) .function("getSubwordId", &FastText::getSubwordId) .function("getInputMatrix", &getInputMatrix, allow_raw_pointers()) .function("getOutputMatrix", &getOutputMatrix, allow_raw_pointers()) .function("getWords", &getWords, allow_raw_pointers()) .function("getLabels", &getLabels, allow_raw_pointers()) .function("getLine", &getLine, allow_raw_pointers()) .function("test", &test, allow_raw_pointers()) .function("predict", &predict, allow_raw_pointers()) .function("getWordVector", &getWordVector, allow_raw_pointers()) .function("getSentenceVector", &getSentenceVector, allow_raw_pointers()) .function("getSubwords", &getSubwords, allow_raw_pointers()) .function("getInputVector", &getInputVector, allow_raw_pointers()) .function("train", &train, allow_raw_pointers()) .function("saveModel", &FastText::saveModel) .property("isQuant", &FastText::isQuant) .property("args", &FastText::getArgs); class_("DenseMatrix") .constructor<>() // we return int32_t because "JS can't represent int64s" .function( "rows", optional_override( [](const DenseMatrix* self) -> int32_t { return self->rows(); }), allow_raw_pointers()) .function( "cols", optional_override( [](const DenseMatrix* self) -> int32_t { return self->cols(); }), allow_raw_pointers()) .function( "at", optional_override( [](const DenseMatrix* self, int32_t i, int32_t j) -> const float { return self->at(i, j); }), allow_raw_pointers()); class_("Meter") .constructor() .property( "precision", select_overload(&Meter::precision)) .property("recall", select_overload(&Meter::recall)) .property("f1Score", select_overload(&Meter::f1Score)) .function( "nexamples", optional_override( [](const Meter* self) -> int32_t { return self->nexamples(); }), allow_raw_pointers()); enum_("ModelName") .value("cbow", model_name::cbow) .value("skipgram", model_name::sg) .value("supervised", model_name::sup); enum_("LossName") .value("hs", loss_name::hs) .value("ns", loss_name::ns) .value("softmax", loss_name::softmax) .value("ova", loss_name::ova); emscripten::value_object("Float32ArrayBridge") .field("ptr", &Float32ArrayBridge::ptr) .field("size", &Float32ArrayBridge::size); emscripten::value_array>( "std::pair") .element(&std::pair::first) .element(&std::pair::second); emscripten::register_vector>( "std::vector>"); emscripten::value_array< std::pair, std::vector>>( "std::pair, std::vector>") .element( &std::pair, std::vector>::first) .element( &std::pair, std::vector>::second); emscripten::value_array< std::pair, std::vector>>( "std::pair, std::vector>") .element( &std::pair, std::vector>::first) .element(&std::pair, std::vector>:: second); emscripten::register_vector("std::vector"); emscripten::register_vector("std::vector"); emscripten::register_vector("std::vector"); }