Ver código fonte

print averaged sentence level vectors

Summary: [fastText] print averaged sentence level vectors

Reviewed By: EdouardGrave

Differential Revision: D4798156

fbshipit-source-id: ec561c997e0cda5705d6554dc9943fca118bb224
Christian Puhrsch 8 anos atrás
pai
commit
a7c2479342
7 arquivos alterados com 87 adições e 22 exclusões
  1. 3 3
      README.md
  2. 1 1
      src/dictionary.cc
  3. 27 2
      src/fasttext.cc
  4. 3 1
      src/fasttext.h
  5. 36 15
      src/main.cc
  6. 15 0
      src/vector.cc
  7. 2 0
      src/vector.h

+ 3 - 3
README.md

@@ -55,14 +55,14 @@ The previously trained model can be used to compute word vectors for out-of-voca
 Provided you have a text file `queries.txt` containing words for which you want to compute vectors, use the following command:
 
 ```
-$ ./fasttext print-vectors model.bin < queries.txt
+$ ./fasttext print-word-vectors model.bin < queries.txt
 ```
 
 This will output word vectors to the standard output, one vector per line.
 This can also be used with pipes:
 
 ```
-$ cat queries.txt | ./fasttext print-vectors model.bin
+$ cat queries.txt | ./fasttext print-word-vectors model.bin
 ```
 
 See the provided scripts for an example. For instance, running:
@@ -108,7 +108,7 @@ In order to reproduce results from the paper [2](#bag-of-tricks-for-efficient-te
 If you want to compute vector representations of sentences or paragraphs, please use:
 
 ```
-$ ./fasttext print-vectors model.bin < text.txt
+$ ./fasttext print-sentence-vectors model.bin < text.txt
 ```
 
 This assumes that the `text.txt` file contains the paragraphs that you want to get vectors for.

+ 1 - 1
src/dictionary.cc

@@ -257,7 +257,7 @@ void Dictionary::initTableDiscard() {
   pdiscard_.resize(size_);
   for (size_t i = 0; i < size_; i++) {
     real f = real(words_[i].count) / real(ntokens_);
-    pdiscard_[i] = sqrt(args_->t / f) + args_->t / f;
+    pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f;
   }
 }
 

+ 27 - 2
src/fasttext.cc

@@ -12,6 +12,7 @@
 #include <math.h>
 
 #include <iostream>
+#include <sstream>
 #include <iomanip>
 #include <thread>
 #include <string>
@@ -360,6 +361,26 @@ void FastText::wordVectors() {
   }
 }
 
+// Reads sentences from standard input, one per line, and prints each
+// sentence followed by the average of its unit-normalized word vectors.
+// Words are the whitespace-separated tokens of the line; each one is looked
+// up through getVector().
+void FastText::sentenceVectors() {
+  Vector vec(args_->dim);
+  std::string sentence;
+  Vector svec(args_->dim);
+  std::string word;
+  while (std::getline(std::cin, sentence)) {
+    std::istringstream iss(sentence);
+    svec.zero();
+    int32_t count = 0;
+    while (iss >> word) {
+      getVector(vec, word);
+      const real norm = vec.norm();
+      // A zero-norm vector cannot be normalized: scaling by 1/0 would put
+      // inf/NaN into the average, so such words are left out of it.
+      if (norm > 0) {
+        vec.mul(1.0 / norm);
+        svec.addVector(vec);
+        count++;
+      }
+    }
+    // A blank line (or one with no usable word) keeps the zero vector
+    // instead of being divided by zero.
+    if (count > 0) {
+      svec.mul(1.0 / count);
+    }
+    std::cout << sentence << " " << svec << std::endl;
+  }
+}
+
 void FastText::ngramVectors(std::string word) {
   std::vector<int32_t> ngrams;
   std::vector<std::string> substrings;
@@ -390,11 +411,15 @@ void FastText::textVectors() {
   }
 }
 
-void FastText::printVectors() {
+// Entry point for printing word vectors; simply forwards to wordVectors().
+void FastText::printWordVectors() {
+  this->wordVectors();
+}
+
+void FastText::printSentenceVectors() {
   if (args_->model == model_name::sup) {
     textVectors();
   } else {
-    wordVectors();
+    sentenceVectors();
   }
 }
 

+ 3 - 1
src/fasttext.h

@@ -75,9 +75,11 @@ class FastText {
         int32_t,
         std::vector<std::pair<real, std::string>>&) const;
     void wordVectors();
+    void sentenceVectors();
     void ngramVectors(std::string);
     void textVectors();
-    void printVectors();
+    void printWordVectors();
+    void printSentenceVectors();
     void trainThread(int32_t);
     void train(std::shared_ptr<Args>);
 

+ 36 - 15
src/main.cc

@@ -18,14 +18,15 @@ void printUsage() {
   std::cerr
     << "usage: fasttext <command> <args>\n\n"
     << "The commands supported by fasttext are:\n\n"
-    << "  supervised          train a supervised classifier\n"
-    << "  quantize            quantize a model to reduce the memory usage\n"
-    << "  test                evaluate a supervised classifier\n"
-    << "  predict             predict most likely labels\n"
-    << "  predict-prob        predict most likely labels with probabilities\n"
-    << "  skipgram            train a skipgram model\n"
-    << "  cbow                train a cbow model\n"
-    << "  print-vectors       print vectors given a trained model\n"
+    << "  supervised              train a supervised classifier\n"
+    << "  quantize                quantize a model to reduce the memory usage\n"
+    << "  test                    evaluate a supervised classifier\n"
+    << "  predict                 predict most likely labels\n"
+    << "  predict-prob            predict most likely labels with probabilities\n"
+    << "  skipgram                train a skipgram model\n"
+    << "  cbow                    train a cbow model\n"
+    << "  print-word-vectors      print word vectors given a trained model\n"
+    << "  print-sentence-vectors  print sentence vectors given a trained model\n"
     << std::endl;
 }
 
@@ -55,9 +56,16 @@ void printPredictUsage() {
     << std::endl;
 }
 
-void printPrintVectorsUsage() {
+void printPrintWordVectorsUsage() {
   std::cerr
-    << "usage: fasttext print-vectors <model>\n\n"
+    << "usage: fasttext print-word-vectors <model>\n\n"
+    << "  <model>      model filename\n"
+    << std::endl;
+}
+
+void printPrintSentenceVectorsUsage() {
+  std::cerr
+    << "usage: fasttext print-sentence-vectors <model>\n\n"
     << "  <model>      model filename\n"
     << std::endl;
 }
@@ -151,14 +159,25 @@ void predict(int argc, char** argv) {
   exit(0);
 }
 
-void printVectors(int argc, char** argv) {
+// Handles `fasttext print-word-vectors <model>`: requires exactly three
+// arguments, loads the model named by argv[2], prints its word vectors and
+// then terminates the process.
+void printWordVectors(int argc, char** argv) {
+  const bool validArgs = (argc == 3);
+  if (!validArgs) {
+    printPrintWordVectorsUsage();
+    exit(EXIT_FAILURE);
+  }
+  FastText fasttext;
+  const std::string modelPath(argv[2]);
+  fasttext.loadModel(modelPath);
+  fasttext.printWordVectors();
+  exit(0);
+}
+
+void printSentenceVectors(int argc, char** argv) {
   if (argc != 3) {
-    printPrintVectorsUsage();
+    printPrintSentenceVectorsUsage();
     exit(EXIT_FAILURE);
   }
   FastText fasttext;
   fasttext.loadModel(std::string(argv[2]));
-  fasttext.printVectors();
+  fasttext.printSentenceVectors();
   exit(0);
 }
 
@@ -192,8 +211,10 @@ int main(int argc, char** argv) {
     test(argc, argv);
   } else if (command == "quantize") {
     quantize(argc, argv);
-  } else if (command == "print-vectors") {
-    printVectors(argc, argv);
+  } else if (command == "print-word-vectors") {
+    printWordVectors(argc, argv);
+  } else if (command == "print-sentence-vectors") {
+    printSentenceVectors(argc, argv);
   } else if (command == "print-ngrams") {
     printNgrams(argc, argv);
   } else if (command == "predict" || command == "predict-prob" ) {

+ 15 - 0
src/vector.cc

@@ -37,12 +37,27 @@ void Vector::zero() {
   }
 }
 
+// Returns the Euclidean (L2) norm: the square root of the sum of the
+// squared entries data_[0] .. data_[m_ - 1].
+real Vector::norm() {
+  real squares = 0;
+  for (int64_t idx = 0; idx < m_; ++idx) {
+    const real value = data_[idx];
+    squares += value * value;
+  }
+  return std::sqrt(squares);
+}
+
 void Vector::mul(real a) {
   for (int64_t i = 0; i < m_; i++) {
     data_[i] *= a;
   }
 }
 
+// Adds `source` to this vector element by element. Both vectors must have
+// the same size, which is checked with an assert.
+void Vector::addVector(const Vector& source) {
+  assert(m_ == source.m_);
+  int64_t idx = 0;
+  while (idx < m_) {
+    data_[idx] += source.data_[idx];
+    ++idx;
+  }
+}
+
 void Vector::addRow(const Matrix& A, int64_t i) {
   assert(i >= 0);
   assert(i < A.m_);

+ 2 - 0
src/vector.h

@@ -35,6 +35,8 @@ class Vector {
     int64_t size() const;
     void zero();
     void mul(real);
+    real norm();
+    void addVector(const Vector& source);
     void addRow(const Matrix&, int64_t);
     void addRow(const QMatrix&, int64_t);
     void addRow(const Matrix&, int64_t, real);