|
|
@@ -17,6 +17,7 @@
|
|
|
#include <thread>
|
|
|
#include <string>
|
|
|
#include <vector>
|
|
|
+#include <queue>
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
@@ -425,6 +426,87 @@ void FastText::printSentenceVectors() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+void FastText::precomputeWordVectors(Matrix& wordVectors) {
|
|
|
+ Vector vec(args_->dim);
|
|
|
+ wordVectors.zero();
|
|
|
+ std::cout << "Pre-computing word vectors...";
|
|
|
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
|
|
|
+ std::string word = dict_->getWord(i);
|
|
|
+ getVector(vec, word);
|
|
|
+ real norm = vec.norm();
|
|
|
+ wordVectors.addRow(vec, i, 1.0 / norm);
|
|
|
+ }
|
|
|
+ std::cout << " done." << std::endl;
|
|
|
+}
|
|
|
+
|
|
|
+void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
|
|
|
+ int32_t k, const std::set<std::string>& banSet) {
|
|
|
+ real queryNorm = queryVec.norm();
|
|
|
+ if (std::abs(queryNorm) < 1e-8) {
|
|
|
+ queryNorm = 1;
|
|
|
+ }
|
|
|
+ std::priority_queue<std::pair<real, std::string>> heap;
|
|
|
+ Vector vec(args_->dim);
|
|
|
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
|
|
|
+ std::string word = dict_->getWord(i);
|
|
|
+ real dp = wordVectors.dotRow(queryVec, i);
|
|
|
+ heap.push(std::make_pair(dp / queryNorm, word));
|
|
|
+ }
|
|
|
+ int32_t i = 0;
|
|
|
+ while (i < k && heap.size() > 0) {
|
|
|
+ auto it = banSet.find(heap.top().second);
|
|
|
+ if (it == banSet.end()) {
|
|
|
+ std::cout << heap.top().second << " " << heap.top().first << std::endl;
|
|
|
+ i++;
|
|
|
+ }
|
|
|
+ heap.pop();
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void FastText::nn(int32_t k) {
|
|
|
+ std::string queryWord;
|
|
|
+ Vector queryVec(args_->dim);
|
|
|
+ Matrix wordVectors(dict_->nwords(), args_->dim);
|
|
|
+ precomputeWordVectors(wordVectors);
|
|
|
+ std::set<std::string> banSet;
|
|
|
+ std::cout << "Query word? ";
|
|
|
+ while (std::cin >> queryWord) {
|
|
|
+ banSet.clear();
|
|
|
+ banSet.insert(queryWord);
|
|
|
+ getVector(queryVec, queryWord);
|
|
|
+ findNN(wordVectors, queryVec, k, banSet);
|
|
|
+ std::cout << "Query word? ";
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void FastText::analogies(int32_t k) {
|
|
|
+ std::string word;
|
|
|
+ Vector buffer(args_->dim), query(args_->dim);
|
|
|
+ Matrix wordVectors(dict_->nwords(), args_->dim);
|
|
|
+ precomputeWordVectors(wordVectors);
|
|
|
+ std::set<std::string> banSet;
|
|
|
+ std::cout << "Query triplet (A - B + C)? ";
|
|
|
+ while (true) {
|
|
|
+ banSet.clear();
|
|
|
+ query.zero();
|
|
|
+ std::cin >> word;
|
|
|
+ banSet.insert(word);
|
|
|
+ getVector(buffer, word);
|
|
|
+ query.addVector(buffer, 1.0);
|
|
|
+ std::cin >> word;
|
|
|
+ banSet.insert(word);
|
|
|
+ getVector(buffer, word);
|
|
|
+ query.addVector(buffer, -1.0);
|
|
|
+ std::cin >> word;
|
|
|
+ banSet.insert(word);
|
|
|
+ getVector(buffer, word);
|
|
|
+ query.addVector(buffer, 1.0);
|
|
|
+
|
|
|
+ findNN(wordVectors, query, k, banSet);
|
|
|
+ std::cout << "Query triplet (A - B + C)? ";
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
void FastText::trainThread(int32_t threadId) {
|
|
|
std::ifstream ifs(args_->input);
|
|
|
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
|