|
|
@@ -24,7 +24,11 @@ namespace fasttext {
|
|
|
constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */
|
|
|
constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314;
|
|
|
|
|
|
-FastText::FastText() : quant_(false) {}
|
|
|
+bool comparePairs(
|
|
|
+ const std::pair<real, std::string>& l,
|
|
|
+ const std::pair<real, std::string>& r);
|
|
|
+
|
|
|
+FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
|
|
|
|
|
|
void FastText::addInputVector(Vector& vec, int32_t ind) const {
|
|
|
if (quant_) {
|
|
|
@@ -81,11 +85,11 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
|
|
|
addInputVector(vec, h);
|
|
|
}
|
|
|
|
|
|
-void FastText::saveVectors() {
|
|
|
- std::ofstream ofs(args_->output + ".vec");
|
|
|
+void FastText::saveVectors(const std::string& filename) {
|
|
|
+ std::ofstream ofs(filename);
|
|
|
if (!ofs.is_open()) {
|
|
|
throw std::invalid_argument(
|
|
|
- args_->output + ".vec" + " cannot be opened for saving vectors!");
|
|
|
+ filename + " cannot be opened for saving vectors!");
|
|
|
}
|
|
|
ofs << dict_->nwords() << " " << args_->dim << std::endl;
|
|
|
Vector vec(args_->dim);
|
|
|
@@ -97,11 +101,15 @@ void FastText::saveVectors() {
|
|
|
ofs.close();
|
|
|
}
|
|
|
|
|
|
-void FastText::saveOutput() {
|
|
|
- std::ofstream ofs(args_->output + ".output");
|
|
|
+void FastText::saveVectors() {
|
|
|
+ saveVectors(args_->output + ".vec");
|
|
|
+}
|
|
|
+
|
|
|
+void FastText::saveOutput(const std::string& filename) {
|
|
|
+ std::ofstream ofs(filename);
|
|
|
if (!ofs.is_open()) {
|
|
|
throw std::invalid_argument(
|
|
|
- args_->output + ".output" + " cannot be opened for saving vectors!");
|
|
|
+ filename + " cannot be opened for saving vectors!");
|
|
|
}
|
|
|
if (quant_) {
|
|
|
throw std::invalid_argument(
|
|
|
@@ -121,6 +129,10 @@ void FastText::saveOutput() {
|
|
|
ofs.close();
|
|
|
}
|
|
|
|
|
|
+void FastText::saveOutput() {
|
|
|
+ saveOutput(args_->output + ".output");
|
|
|
+}
|
|
|
+
|
|
|
bool FastText::checkModel(std::istream& in) {
|
|
|
int32_t magic;
|
|
|
in.read((char*)&(magic), sizeof(int32_t));
|
|
|
@@ -151,10 +163,10 @@ void FastText::saveModel() {
|
|
|
saveModel(fn);
|
|
|
}
|
|
|
|
|
|
-void FastText::saveModel(const std::string path) {
|
|
|
- std::ofstream ofs(path, std::ofstream::binary);
|
|
|
+void FastText::saveModel(const std::string& filename) {
|
|
|
+ std::ofstream ofs(filename, std::ofstream::binary);
|
|
|
if (!ofs.is_open()) {
|
|
|
- throw std::invalid_argument(path + " cannot be opened for saving!");
|
|
|
+ throw std::invalid_argument(filename + " cannot be opened for saving!");
|
|
|
}
|
|
|
signModel(ofs);
|
|
|
args_->save(ofs);
|
|
|
@@ -410,39 +422,26 @@ void FastText::predict(
|
|
|
model_->predict(words, k, threshold, predictions, hidden, output);
|
|
|
}
|
|
|
|
|
|
-void FastText::predict(
|
|
|
+bool FastText::predictLine(
|
|
|
std::istream& in,
|
|
|
+ std::vector<std::pair<real, std::string>>& predictions,
|
|
|
int32_t k,
|
|
|
- bool print_prob,
|
|
|
- real threshold) {
|
|
|
- std::vector<std::pair<real, int32_t>> predictions;
|
|
|
- while (in.peek() != EOF) {
|
|
|
- std::vector<int32_t> words, labels;
|
|
|
- dict_->getLine(in, words, labels);
|
|
|
- predictions.clear();
|
|
|
- predict(k, words, predictions, threshold);
|
|
|
- if (predictions.empty()) {
|
|
|
- std::cout << std::endl;
|
|
|
- continue;
|
|
|
- }
|
|
|
- for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
|
|
|
- if (it != predictions.cbegin()) {
|
|
|
- std::cout << " ";
|
|
|
- }
|
|
|
- std::cout << dict_->getLabel(it->second);
|
|
|
- if (print_prob) {
|
|
|
- std::cout << " " << std::exp(it->first);
|
|
|
- }
|
|
|
- }
|
|
|
- std::cout << std::endl;
|
|
|
+ real threshold) const {
|
|
|
+ predictions.clear();
|
|
|
+ if (in.peek() == EOF) {
|
|
|
+ return false;
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-void FastText::printLabelStats(std::istream& in, int32_t k, real threshold)
|
|
|
- const {
|
|
|
- Meter meter;
|
|
|
- test(in, k, threshold, meter);
|
|
|
- writePerLabelMetrics(std::cout, meter);
|
|
|
+ std::vector<int32_t> words, labels;
|
|
|
+ dict_->getLine(in, words, labels);
|
|
|
+ std::vector<std::pair<real, int32_t>> linePredictions;
|
|
|
+ predict(k, words, linePredictions, threshold);
|
|
|
+ for (const auto& p : linePredictions) {
|
|
|
+ predictions.push_back(
|
|
|
+ std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));
|
|
|
+ }
|
|
|
+
|
|
|
+ return true;
|
|
|
}
|
|
|
|
|
|
void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
|
|
|
@@ -478,13 +477,15 @@ void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void FastText::ngramVectors(std::string word) {
|
|
|
+std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
|
|
|
+ const std::string& word) const {
|
|
|
+ std::vector<std::pair<std::string, Vector>> result;
|
|
|
std::vector<int32_t> ngrams;
|
|
|
std::vector<std::string> substrings;
|
|
|
- Vector vec(args_->dim);
|
|
|
dict_->getSubwords(word, ngrams, substrings);
|
|
|
+ assert(ngrams.size() <= substrings.size());
|
|
|
for (int32_t i = 0; i < ngrams.size(); i++) {
|
|
|
- vec.zero();
|
|
|
+ Vector vec(args_->dim);
|
|
|
if (ngrams[i] >= 0) {
|
|
|
if (quant_) {
|
|
|
vec.addRow(*qinput_, ngrams[i]);
|
|
|
@@ -492,7 +493,18 @@ void FastText::ngramVectors(std::string word) {
|
|
|
vec.addRow(*input_, ngrams[i]);
|
|
|
}
|
|
|
}
|
|
|
- std::cout << substrings[i] << " " << vec << std::endl;
|
|
|
+ result.push_back(std::make_pair(substrings[i], std::move(vec)));
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+// deprecated. use getNgramVectors instead
|
|
|
+void FastText::ngramVectors(std::string word) {
|
|
|
+ std::vector<std::pair<std::string, Vector>> ngramVectors =
|
|
|
+ getNgramVectors(word);
|
|
|
+
|
|
|
+ for (const auto& ngramVector : ngramVectors) {
|
|
|
+ std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -509,65 +521,107 @@ void FastText::precomputeWordVectors(Matrix& wordVectors) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void FastText::findNN(
|
|
|
+void FastText::lazyComputeWordVectors() {
|
|
|
+ if (!wordVectors_) {
|
|
|
+ wordVectors_ =
|
|
|
+ std::unique_ptr<Matrix>(new Matrix(dict_->nwords(), args_->dim));
|
|
|
+ precomputeWordVectors(*wordVectors_);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+std::vector<std::pair<real, std::string>> FastText::getNN(
|
|
|
+ const std::string& word,
|
|
|
+ int32_t k) {
|
|
|
+ Vector query(args_->dim);
|
|
|
+
|
|
|
+ getWordVector(query, word);
|
|
|
+
|
|
|
+ lazyComputeWordVectors();
|
|
|
+ assert(wordVectors_);
|
|
|
+ return getNN(*wordVectors_, query, k, {word});
|
|
|
+}
|
|
|
+
|
|
|
+std::vector<std::pair<real, std::string>> FastText::getNN(
|
|
|
const Matrix& wordVectors,
|
|
|
- const Vector& queryVec,
|
|
|
+ const Vector& query,
|
|
|
int32_t k,
|
|
|
- const std::set<std::string>& banSet,
|
|
|
- std::vector<std::pair<real, std::string>>& results) {
|
|
|
- results.clear();
|
|
|
- std::priority_queue<std::pair<real, std::string>> heap;
|
|
|
- real queryNorm = queryVec.norm();
|
|
|
+ const std::set<std::string>& banSet) {
|
|
|
+ std::vector<std::pair<real, std::string>> heap;
|
|
|
+
|
|
|
+ real queryNorm = query.norm();
|
|
|
if (std::abs(queryNorm) < 1e-8) {
|
|
|
queryNorm = 1;
|
|
|
}
|
|
|
- Vector vec(args_->dim);
|
|
|
+
|
|
|
for (int32_t i = 0; i < dict_->nwords(); i++) {
|
|
|
std::string word = dict_->getWord(i);
|
|
|
- real dp = wordVectors.dotRow(queryVec, i);
|
|
|
- heap.push(std::make_pair(dp / queryNorm, word));
|
|
|
- }
|
|
|
- int32_t i = 0;
|
|
|
- while (i < k && heap.size() > 0) {
|
|
|
- auto it = banSet.find(heap.top().second);
|
|
|
- if (it == banSet.end()) {
|
|
|
- results.push_back(
|
|
|
- std::pair<real, std::string>(heap.top().first, heap.top().second));
|
|
|
- i++;
|
|
|
+ if (banSet.find(word) == banSet.end()) {
|
|
|
+ real dp = wordVectors.dotRow(query, i);
|
|
|
+ real similarity = dp / queryNorm;
|
|
|
+
|
|
|
+ if (heap.size() == k && similarity < heap.front().first) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ heap.push_back(std::make_pair(similarity, word));
|
|
|
+ std::push_heap(heap.begin(), heap.end(), comparePairs);
|
|
|
+ if (heap.size() > k) {
|
|
|
+ std::pop_heap(heap.begin(), heap.end(), comparePairs);
|
|
|
+ heap.pop_back();
|
|
|
+ }
|
|
|
}
|
|
|
- heap.pop();
|
|
|
}
|
|
|
+ std::sort_heap(heap.begin(), heap.end(), comparePairs);
|
|
|
+
|
|
|
+ return heap;
|
|
|
+}
|
|
|
+
|
|
|
+// depracted. use getNN instead
|
|
|
+void FastText::findNN(
|
|
|
+ const Matrix& wordVectors,
|
|
|
+ const Vector& query,
|
|
|
+ int32_t k,
|
|
|
+ const std::set<std::string>& banSet,
|
|
|
+ std::vector<std::pair<real, std::string>>& results) {
|
|
|
+ results.clear();
|
|
|
+ results = getNN(wordVectors, query, k, banSet);
|
|
|
}
|
|
|
|
|
|
+std::vector<std::pair<real, std::string>> FastText::getAnalogies(
|
|
|
+ int32_t k,
|
|
|
+ const std::string& wordA,
|
|
|
+ const std::string& wordB,
|
|
|
+ const std::string& wordC) {
|
|
|
+ Vector query = Vector(args_->dim);
|
|
|
+ query.zero();
|
|
|
+
|
|
|
+ Vector buffer(args_->dim);
|
|
|
+ getWordVector(buffer, wordA);
|
|
|
+ query.addVector(buffer, 1.0);
|
|
|
+ getWordVector(buffer, wordB);
|
|
|
+ query.addVector(buffer, -1.0);
|
|
|
+ getWordVector(buffer, wordC);
|
|
|
+ query.addVector(buffer, 1.0);
|
|
|
+
|
|
|
+ lazyComputeWordVectors();
|
|
|
+ assert(wordVectors_);
|
|
|
+ return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
|
|
|
+}
|
|
|
+
|
|
|
+// depreacted, use getAnalogies instead
|
|
|
void FastText::analogies(int32_t k) {
|
|
|
- std::string word;
|
|
|
- Vector buffer(args_->dim), query(args_->dim);
|
|
|
- Matrix wordVectors(dict_->nwords(), args_->dim);
|
|
|
- precomputeWordVectors(wordVectors);
|
|
|
- std::set<std::string> banSet;
|
|
|
- std::cout << "Query triplet (A - B + C)? ";
|
|
|
- std::vector<std::pair<real, std::string>> results;
|
|
|
+ std::string prompt("Query triplet (A - B + C)? ");
|
|
|
+ std::string wordA, wordB, wordC;
|
|
|
+ std::cout << prompt;
|
|
|
while (true) {
|
|
|
- banSet.clear();
|
|
|
- query.zero();
|
|
|
- std::cin >> word;
|
|
|
- banSet.insert(word);
|
|
|
- getWordVector(buffer, word);
|
|
|
- query.addVector(buffer, 1.0);
|
|
|
- std::cin >> word;
|
|
|
- banSet.insert(word);
|
|
|
- getWordVector(buffer, word);
|
|
|
- query.addVector(buffer, -1.0);
|
|
|
- std::cin >> word;
|
|
|
- banSet.insert(word);
|
|
|
- getWordVector(buffer, word);
|
|
|
- query.addVector(buffer, 1.0);
|
|
|
-
|
|
|
- findNN(wordVectors, query, k, banSet, results);
|
|
|
+ std::cin >> wordA;
|
|
|
+ std::cin >> wordB;
|
|
|
+ std::cin >> wordC;
|
|
|
+ auto results = getAnalogies(k, wordA, wordB, wordC);
|
|
|
+
|
|
|
for (auto& pair : results) {
|
|
|
std::cout << pair.second << " " << pair.first << std::endl;
|
|
|
}
|
|
|
- std::cout << "Query triplet (A - B + C)? ";
|
|
|
+ std::cout << prompt;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -727,26 +781,10 @@ bool FastText::isQuant() const {
|
|
|
return quant_;
|
|
|
}
|
|
|
|
|
|
-void FastText::writePerLabelMetrics(std::ostream& out, Meter& meter) const {
|
|
|
- out << std::fixed;
|
|
|
- out << std::setprecision(6);
|
|
|
-
|
|
|
- auto writeMetric = [&](const std::string& name, double value) {
|
|
|
- out << name << " : ";
|
|
|
- if (std::isfinite(value)) {
|
|
|
- out << value;
|
|
|
- } else {
|
|
|
- out << "--------";
|
|
|
- }
|
|
|
- out << " ";
|
|
|
- };
|
|
|
-
|
|
|
- for (int32_t i = 0; i < dict_->nlabels(); i++) {
|
|
|
- writeMetric("F1-Score", meter.f1Score(i));
|
|
|
- writeMetric("Precision", meter.precision(i));
|
|
|
- writeMetric("Recall", meter.recall(i));
|
|
|
- out << " " << dict_->getLabel(i) << std::endl;
|
|
|
- }
|
|
|
+bool comparePairs(
|
|
|
+ const std::pair<real, std::string>& l,
|
|
|
+ const std::pair<real, std::string>& r) {
|
|
|
+ return l.first > r.first;
|
|
|
}
|
|
|
|
|
|
} // namespace fasttext
|