Pārlūkot izejas kodu

added a text-vectors option that outputs vectors for lines in a test file.

Summary:
Added an option that allows the user to compute the vector for a line.
Works a bit like predict: takes a bin file and a text file and outputs one
vector per line.

Reviewed By: EdouardGrave

Differential Revision: D3827983

fbshipit-source-id: 3339d60d73f31a77c79466cc5847bee650fded12
Piotr Bojanowski 9 gadi atpakaļ
vecāks
revīzija
d652288bad
3 mainītis faili ar 53 papildinājumiem un 17 dzēšanām
  1. 9 0
      README.md
  2. 41 16
      src/fasttext.cc
  3. 3 1
      src/fasttext.h

+ 9 - 0
README.md

@@ -109,6 +109,15 @@ The argument `k` is optional, and equal to `1` by default.
 See `classification-example.sh` for an example use case.
 In order to reproduce results from the paper [2](#bag-of-tricks-for-efficient-text-classification), run `classification-results.sh`, this will download all the datasets and reproduce the results from Table 1.
 
+If you want to compute vector representations of sentences or paragraphs, please use:
+
+```
+$ ./fasttext print-vectors model.bin < text.txt
+```
+
+This assumes that the `text.txt` file contains the paragraphs that you want to get vectors for.
+The program will output one vector representation per line in the file.
+
 ## Full documentation
 
 Invoke a command without arguments to list available arguments and their default values:

+ 41 - 16
src/fasttext.cc

@@ -46,15 +46,6 @@ void FastText::saveVectors() {
   ofs.close();
 }
 
-void FastText::printVectors() {
-  std::string word;
-  Vector vec(args_->dim);
-  while (std::cin >> word) {
-    getVector(vec, word);
-    std::cout << word << " " << vec << std::endl;
-  }
-}
-
 void FastText::saveModel() {
   std::ofstream ofs(args_->output + ".bin");
   if (!ofs.is_open()) {
@@ -208,6 +199,40 @@ void FastText::predict(const std::string& filename, int32_t k, bool print_prob)
   ifs.close();
 }
 
+void FastText::wordVectors() {
+  std::string word;
+  Vector vec(args_->dim);
+  while (std::cin >> word) {
+    getVector(vec, word);
+    std::cout << word << " " << vec << std::endl;
+  }
+}
+
+void FastText::textVectors() {
+  std::vector<int32_t> line, labels;
+  Vector vec(args_->dim);
+  while (std::cin.peek() != EOF) {
+    dict_->getLine(std::cin, line, labels, model_->rng);
+    dict_->addNgrams(line, args_->wordNgrams);
+    vec.zero();
+    for (auto it = line.cbegin(); it != line.cend(); ++it) {
+      vec.addRow(*input_, *it);
+    }
+    if (!line.empty()) {
+      vec.mul(1.0 / line.size());
+    }
+    std::cout << vec << std::endl;
+  }
+}
+
+void FastText::printVectors() {
+  if (args_->model == model_name::sup) {
+    textVectors();
+  } else {
+    wordVectors();
+  }
+}
+
 void FastText::trainThread(int32_t threadId) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
@@ -290,13 +315,13 @@ void printUsage() {
   std::cout
     << "usage: fasttext <command> <args>\n\n"
     << "The commands supported by fasttext are:\n\n"
-    << "  supervised       train a supervised classifier\n"
-    << "  test             evaluate a supervised classifier\n"
-    << "  predict          predict most likely labels\n"
-    << "  predict-prob     predict most likely labels with probabilities\n"
-    << "  skipgram         train a skipgram model\n"
-    << "  cbow             train a cbow model\n"
-    << "  print-vectors    print vectors given a trained model\n"
+    << "  supervised          train a supervised classifier\n"
+    << "  test                evaluate a supervised classifier\n"
+    << "  predict             predict most likely labels\n"
+    << "  predict-prob        predict most likely labels with probabilities\n"
+    << "  skipgram            train a skipgram model\n"
+    << "  cbow                train a cbow model\n"
+    << "  print-vectors       print vectors given a trained model\n"
     << std::endl;
 }
 

+ 3 - 1
src/fasttext.h

@@ -36,7 +36,6 @@ class FastText {
   public:
     void getVector(Vector&, const std::string&);
     void saveVectors();
-    void printVectors();
     void saveModel();
     void loadModel(const std::string&);
     void printInfo(real, real);
@@ -47,6 +46,9 @@ class FastText {
     void skipgram(Model&, real, const std::vector<int32_t>&);
     void test(const std::string&, int32_t);
     void predict(const std::string&, int32_t, bool);
+    void wordVectors();
+    void textVectors();
+    void printVectors();
     void trainThread(int32_t);
     void train(std::shared_ptr<Args>);
 };