Просмотр исходного кода

Added nearest neighbor queries and analogies.

Summary:
Added two functions to fasttext:
./fasttext nn <model>
./fasttext analogies <model>
Both start an interactive prompt that asks for a query.

Reviewed By: EdouardGrave

Differential Revision: D4357150

fbshipit-source-id: 294d9e7ee4c0c9b777f227ba3cc589af86855ea7
Piotr Bojanowski 8 лет назад
Родитель
Commit
a1b23749e9
6 измененных файлов с 154 добавлено и 2 удалено
  1. 1 0
      src/dictionary.cc
  2. 82 0
      src/fasttext.cc
  3. 6 0
      src/fasttext.h
  4. 54 0
      src/main.cc
  5. 9 1
      src/vector.cc
  6. 2 1
      src/vector.h

+ 1 - 0
src/dictionary.cc

@@ -15,6 +15,7 @@
 #include <fstream>
 #include <algorithm>
 #include <iterator>
+#include <cmath>
 
 namespace fasttext {
 

+ 82 - 0
src/fasttext.cc

@@ -17,6 +17,7 @@
 #include <thread>
 #include <string>
 #include <vector>
+#include <queue>
 #include <algorithm>
 
 
@@ -425,6 +426,87 @@ void FastText::printSentenceVectors() {
   }
 }
 
+void FastText::precomputeWordVectors(Matrix& wordVectors) {
+  Vector vec(args_->dim);
+  wordVectors.zero();
+  std::cout << "Pre-computing word vectors...";
+  for (int32_t i = 0; i < dict_->nwords(); i++) {
+    std::string word = dict_->getWord(i);
+    getVector(vec, word);
+    real norm = vec.norm();
+    wordVectors.addRow(vec, i, 1.0 / norm);
+  }
+  std::cout << " done." << std::endl;
+}
+
+void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
+                      int32_t k, const std::set<std::string>& banSet) {
+  real queryNorm = queryVec.norm();
+  if (std::abs(queryNorm) < 1e-8) {
+    queryNorm = 1;
+  }
+  std::priority_queue<std::pair<real, std::string>> heap;
+  Vector vec(args_->dim);
+  for (int32_t i = 0; i < dict_->nwords(); i++) {
+    std::string word = dict_->getWord(i);
+    real dp = wordVectors.dotRow(queryVec, i);
+    heap.push(std::make_pair(dp / queryNorm, word));
+  }
+  int32_t i = 0;
+  while (i < k && heap.size() > 0) {
+    auto it = banSet.find(heap.top().second);
+    if (it == banSet.end()) {
+      std::cout << heap.top().second << " " << heap.top().first << std::endl;
+      i++;
+    }
+    heap.pop();
+  }
+}
+
+void FastText::nn(int32_t k) {
+  std::string queryWord;
+  Vector queryVec(args_->dim);
+  Matrix wordVectors(dict_->nwords(), args_->dim);
+  precomputeWordVectors(wordVectors);
+  std::set<std::string> banSet;
+  std::cout << "Query word? ";
+  while (std::cin >> queryWord) {
+    banSet.clear();
+    banSet.insert(queryWord);
+    getVector(queryVec, queryWord);
+    findNN(wordVectors, queryVec, k, banSet);
+    std::cout << "Query word? ";
+  }
+}
+
+void FastText::analogies(int32_t k) {
+  std::string word;
+  Vector buffer(args_->dim), query(args_->dim);
+  Matrix wordVectors(dict_->nwords(), args_->dim);
+  precomputeWordVectors(wordVectors);
+  std::set<std::string> banSet;
+  std::cout << "Query triplet (A - B + C)? ";
+  while (true) {
+    banSet.clear();
+    query.zero();
+    std::cin >> word;
+    banSet.insert(word);
+    getVector(buffer, word);
+    query.addVector(buffer, 1.0);
+    std::cin >> word;
+    banSet.insert(word);
+    getVector(buffer, word);
+    query.addVector(buffer, -1.0);
+    std::cin >> word;
+    banSet.insert(word);
+    getVector(buffer, word);
+    query.addVector(buffer, 1.0);
+
+    findNN(wordVectors, query, k, banSet);
+    std::cout << "Query triplet (A - B + C)? ";
+  }
+}
+
 void FastText::trainThread(int32_t threadId) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

+ 6 - 0
src/fasttext.h

@@ -17,6 +17,7 @@
 
 #include <atomic>
 #include <memory>
+#include <set>
 
 #include "args.h"
 #include "dictionary.h"
@@ -78,6 +79,11 @@ class FastText {
     void textVectors();
     void printWordVectors();
     void printSentenceVectors();
+    void precomputeWordVectors(Matrix&);
+    void findNN(const Matrix&, const Vector&, int32_t,
+                const std::set<std::string>&);
+    void nn(int32_t);
+    void analogies(int32_t);
     void trainThread(int32_t);
     void train(std::shared_ptr<Args>);
 

+ 54 - 0
src/main.cc

@@ -27,6 +27,8 @@ void printUsage() {
     << "  cbow                    train a cbow model\n"
     << "  print-word-vectors      print word vectors given a trained model\n"
     << "  print-sentence-vectors  print sentence vectors given a trained model\n"
+    << "  nn                      query for nearest neighbors\n"
+    << "  analogies               query for analogies\n"
     << std::endl;
 }
 
@@ -89,6 +91,22 @@ void quantize(int argc, char** argv) {
   exit(0);
 }
 
// Prints command-line help for the `nn` subcommand.
void printNNUsage() {
  std::cout
    << "usage: fasttext nn <model> <k>\n\n"
    << "  <model>      model filename\n"
    // Fixed: the old text said "predict top k labels" (copy-pasted from
    // the predict usage); nn shows nearest neighbors, not labels.
    << "  <k>          (optional; 10 by default) number of nearest neighbors to show\n"
    << std::endl;
}
+
// Prints command-line help for the `analogies` subcommand.
void printAnalogiesUsage() {
  std::cout
    << "usage: fasttext analogies <model> <k>\n\n"
    << "  <model>      model filename\n"
    // Fixed: the old text said "predict top k labels" (copy-pasted from
    // the predict usage); analogies returns nearest words, not labels.
    << "  <k>          (optional; 10 by default) number of analogy results to show\n"
    << std::endl;
}
+
 void test(int argc, char** argv) {
   if (argc < 4 || argc > 5) {
     printTestUsage();
@@ -180,6 +198,38 @@ void printNgrams(int argc, char** argv) {
   exit(0);
 }
 
+void nn(int argc, char** argv) {
+  int32_t k;
+  if (argc == 3) {
+    k = 10;
+  } else if (argc == 4) {
+    k = atoi(argv[3]);
+  } else {
+    printNNUsage();
+    exit(EXIT_FAILURE);
+  }
+  FastText fasttext;
+  fasttext.loadModel(std::string(argv[2]));
+  fasttext.nn(k);
+  exit(0);
+}
+
+void analogies(int argc, char** argv) {
+  int32_t k;
+  if (argc == 3) {
+    k = 10;
+  } else if (argc == 4) {
+    k = atoi(argv[3]);
+  } else {
+    printAnalogiesUsage();
+    exit(EXIT_FAILURE);
+  }
+  FastText fasttext;
+  fasttext.loadModel(std::string(argv[2]));
+  fasttext.analogies(k);
+  exit(0);
+}
+
 void train(int argc, char** argv) {
   std::shared_ptr<Args> a = std::make_shared<Args>();
   a->parseArgs(argc, argv);
@@ -205,6 +255,10 @@ int main(int argc, char** argv) {
     printSentenceVectors(argc, argv);
   } else if (command == "print-ngrams") {
     printNgrams(argc, argv);
+  } else if (command == "nn") {
+    nn(argc, argv);
+  } else if (command == "analogies") {
+    analogies(argc, argv);
   } else if (command == "predict" || command == "predict-prob" ) {
     predict(argc, argv);
   } else {

+ 9 - 1
src/vector.cc

@@ -12,6 +12,7 @@
 #include <assert.h>
 
 #include <iomanip>
+#include <cmath>
 
 #include "matrix.h"
 #include "qmatrix.h"
@@ -37,7 +38,7 @@ void Vector::zero() {
   }
 }
 
-real Vector::norm() {
+real Vector::norm() const {
   real sum = 0;
   for (int64_t i = 0; i < m_; i++) {
     sum += data_[i] * data_[i];
@@ -58,6 +59,13 @@ void Vector::addVector(const Vector& source) {
   }
 }
 
+void Vector::addVector(const Vector& source, real s) {
+  assert(m_ == source.m_);
+  for (int64_t i = 0; i < m_; i++) {
+    data_[i] += s * source.data_[i];
+  }
+}
+
 void Vector::addRow(const Matrix& A, int64_t i) {
   assert(i >= 0);
   assert(i < A.m_);

+ 2 - 1
src/vector.h

@@ -35,8 +35,9 @@ class Vector {
     int64_t size() const;
     void zero();
     void mul(real);
-    real norm();
+    real norm() const;
     void addVector(const Vector& source);
+    void addVector(const Vector&, real);
     void addRow(const Matrix&, int64_t);
     void addRow(const QMatrix&, int64_t);
     void addRow(const Matrix&, int64_t, real);