1
0
Эх сурвалжийг харах

remove printing functions from fasttext class

Summary:
This diff removes the printing capabilities from the fasttext class and defines a new API.
- `predictLine` extracts predictions from exactly one line of the input stream.
- the deprecated `printLabelStats` is removed as [JS bindings don't use it](https://www.facebook.com/groups/1174547215919768/?multi_permalinks=2328051983902613&comment_id=2360179150689896)
- `ngramVectors` is now deprecated in favor of the new `getNgramVectors`. The `Vector` class remains copy-free, but move semantics have been added.
- `analogies` is now deprecated in favor of `getAnalogies`. When called, the FastText class lazily precomputes the word vectors.
- `findNN` is now deprecated in favor of `getNN`. When called, the FastText class lazily precomputes the word vectors.
- `trainThread` and `printInfo` functions are now private.
- `supervised`, `cbow`, `skipgram`, `selectEmbeddings`, `precomputeWordVectors` are now deprecated and will be private in the future.
- `saveVectors`, `saveOutput` and `saveModel` without arguments are now deprecated in favor of their equivalents that take the filename as a string argument.

Reviewed By: EdouardGrave

Differential Revision: D13083799

fbshipit-source-id: f557ed7c141a90a6171045fe118ac16c195c824f
Onur Çelebi 7 жил өмнө
parent
commit
256032b875

+ 8 - 4
python/fastText/FastText.py

@@ -130,12 +130,16 @@ class _FastText():
 
         if type(text) == list:
             text = [check(entry) for entry in text]
-            all_probs, all_labels = self.f.multilinePredict(text, k, threshold)
-            return all_labels, np.array(all_probs, copy=False)
+            predictions = self.f.multilinePredict(text, k, threshold)
+            dt = np.dtype([('probability', 'float64'), ('label', '<U32')])
+            result_as_pair = np.array(predictions, dtype=dt)
+
+            return result_as_pair['label'].tolist(), result_as_pair['probability']
         else:
             text = check(text)
-            pairs = self.f.predict(text, k, threshold)
-            probs, labels = zip(*pairs)
+            predictions = self.f.predict(text, k, threshold)
+            probs, labels = zip(*predictions)
+
             return labels, np.array(probs, copy=False)
 
     def get_input_matrix(self):

+ 23 - 48
python/fastText/pybind/fasttext_pybind.cc

@@ -262,23 +262,30 @@ PYBIND11_MODULE(fasttext_pybind, m) {
              const std::string text,
              int32_t k,
              fasttext::real threshold) {
-            std::vector<std::pair<fasttext::real, int32_t>> predictions;
-            std::vector<std::pair<fasttext::real, std::string>> all_predictions;
             std::stringstream ioss(text);
-            std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
-            std::vector<int32_t> words, labels;
-            d->getLine(ioss, words, labels);
-            m.predict(k, words, predictions, threshold);
-            std::transform(
-                predictions.begin(),
-                predictions.end(),
-                std::back_inserter(all_predictions),
-                [&d](const std::pair<fasttext::real, int32_t>& prediction) {
-                  return std::pair<fasttext::real, std::string>(
-                      std::exp(prediction.first),
-                      d->getLabel(prediction.second));
-                });
-            return all_predictions;
+            std::vector<std::pair<fasttext::real, std::string>> predictions;
+            m.predictLine(ioss, predictions, k, threshold);
+
+            return predictions;
+          })
+      .def(
+          "multilinePredict",
+          // NOTE: text needs to end in a newline
+          // to exactly mimic the behavior of the cli
+          [](fasttext::FastText& m,
+             const std::vector<std::string>& lines,
+             int32_t k,
+             fasttext::real threshold) {
+            std::vector<std::vector<std::pair<fasttext::real, std::string>>>
+                allPredictions;
+            std::vector<std::pair<fasttext::real, std::string>> predictions;
+
+            for (const std::string& text : lines) {
+              std::stringstream ioss(text);
+              m.predictLine(ioss, predictions, k, threshold);
+              allPredictions.push_back(predictions);
+            }
+            return allPredictions;
           })
       .def(
           "testLabel",
@@ -303,38 +310,6 @@ PYBIND11_MODULE(fasttext_pybind, m) {
 
             return returnedValue;
           })
-      .def(
-          "multilinePredict",
-          // NOTE: text needs to end in a newline
-          // to exactly mimic the behavior of the cli
-          [](fasttext::FastText& m,
-             const std::vector<std::string>& lines,
-             int32_t k,
-             fasttext::real threshold) {
-            std::pair<
-                std::vector<std::vector<fasttext::real>>,
-                std::vector<std::vector<std::string>>>
-                all_predictions;
-            std::vector<std::pair<fasttext::real, int32_t>> predictions;
-            std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
-            std::vector<int32_t> words, labels;
-            for (const std::string& text : lines) {
-              std::stringstream ioss(text);
-              predictions.clear();
-              d->getLine(ioss, words, labels);
-              m.predict(k, words, predictions, threshold);
-              all_predictions.first.push_back(std::vector<fasttext::real>());
-              all_predictions.second.push_back(std::vector<std::string>());
-              for (auto& pair : predictions) {
-                pair.first = std::exp(pair.first);
-                all_predictions.first[all_predictions.first.size() - 1]
-                    .push_back(pair.first);
-                all_predictions.second[all_predictions.second.size() - 1]
-                    .push_back(d->getLabel(pair.second));
-              }
-            }
-            return all_predictions;
-          })
       .def(
           "getWordId",
           [](fasttext::FastText& m, const std::string word) {

+ 144 - 106
src/fasttext.cc

@@ -24,7 +24,11 @@ namespace fasttext {
 constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */
 constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314;
 
-FastText::FastText() : quant_(false) {}
+bool comparePairs(
+    const std::pair<real, std::string>& l,
+    const std::pair<real, std::string>& r);
+
+FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
 
 void FastText::addInputVector(Vector& vec, int32_t ind) const {
   if (quant_) {
@@ -81,11 +85,11 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
   addInputVector(vec, h);
 }
 
-void FastText::saveVectors() {
-  std::ofstream ofs(args_->output + ".vec");
+void FastText::saveVectors(const std::string& filename) {
+  std::ofstream ofs(filename);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
-        args_->output + ".vec" + " cannot be opened for saving vectors!");
+        filename + " cannot be opened for saving vectors!");
   }
   ofs << dict_->nwords() << " " << args_->dim << std::endl;
   Vector vec(args_->dim);
@@ -97,11 +101,15 @@ void FastText::saveVectors() {
   ofs.close();
 }
 
-void FastText::saveOutput() {
-  std::ofstream ofs(args_->output + ".output");
+void FastText::saveVectors() {
+  saveVectors(args_->output + ".vec");
+}
+
+void FastText::saveOutput(const std::string& filename) {
+  std::ofstream ofs(filename);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
-        args_->output + ".output" + " cannot be opened for saving vectors!");
+        filename + " cannot be opened for saving vectors!");
   }
   if (quant_) {
     throw std::invalid_argument(
@@ -121,6 +129,10 @@ void FastText::saveOutput() {
   ofs.close();
 }
 
+void FastText::saveOutput() {
+  saveOutput(args_->output + ".output");
+}
+
 bool FastText::checkModel(std::istream& in) {
   int32_t magic;
   in.read((char*)&(magic), sizeof(int32_t));
@@ -151,10 +163,10 @@ void FastText::saveModel() {
   saveModel(fn);
 }
 
-void FastText::saveModel(const std::string path) {
-  std::ofstream ofs(path, std::ofstream::binary);
+void FastText::saveModel(const std::string& filename) {
+  std::ofstream ofs(filename, std::ofstream::binary);
   if (!ofs.is_open()) {
-    throw std::invalid_argument(path + " cannot be opened for saving!");
+    throw std::invalid_argument(filename + " cannot be opened for saving!");
   }
   signModel(ofs);
   args_->save(ofs);
@@ -410,39 +422,26 @@ void FastText::predict(
   model_->predict(words, k, threshold, predictions, hidden, output);
 }
 
-void FastText::predict(
+bool FastText::predictLine(
     std::istream& in,
+    std::vector<std::pair<real, std::string>>& predictions,
     int32_t k,
-    bool print_prob,
-    real threshold) {
-  std::vector<std::pair<real, int32_t>> predictions;
-  while (in.peek() != EOF) {
-    std::vector<int32_t> words, labels;
-    dict_->getLine(in, words, labels);
-    predictions.clear();
-    predict(k, words, predictions, threshold);
-    if (predictions.empty()) {
-      std::cout << std::endl;
-      continue;
-    }
-    for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
-      if (it != predictions.cbegin()) {
-        std::cout << " ";
-      }
-      std::cout << dict_->getLabel(it->second);
-      if (print_prob) {
-        std::cout << " " << std::exp(it->first);
-      }
-    }
-    std::cout << std::endl;
+    real threshold) const {
+  predictions.clear();
+  if (in.peek() == EOF) {
+    return false;
   }
-}
 
-void FastText::printLabelStats(std::istream& in, int32_t k, real threshold)
-    const {
-  Meter meter;
-  test(in, k, threshold, meter);
-  writePerLabelMetrics(std::cout, meter);
+  std::vector<int32_t> words, labels;
+  dict_->getLine(in, words, labels);
+  std::vector<std::pair<real, int32_t>> linePredictions;
+  predict(k, words, linePredictions, threshold);
+  for (const auto& p : linePredictions) {
+    predictions.push_back(
+        std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));
+  }
+
+  return true;
 }
 
 void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
@@ -478,13 +477,15 @@ void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
   }
 }
 
-void FastText::ngramVectors(std::string word) {
+std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
+    const std::string& word) const {
+  std::vector<std::pair<std::string, Vector>> result;
   std::vector<int32_t> ngrams;
   std::vector<std::string> substrings;
-  Vector vec(args_->dim);
   dict_->getSubwords(word, ngrams, substrings);
+  assert(ngrams.size() <= substrings.size());
   for (int32_t i = 0; i < ngrams.size(); i++) {
-    vec.zero();
+    Vector vec(args_->dim);
     if (ngrams[i] >= 0) {
       if (quant_) {
         vec.addRow(*qinput_, ngrams[i]);
@@ -492,7 +493,18 @@ void FastText::ngramVectors(std::string word) {
         vec.addRow(*input_, ngrams[i]);
       }
     }
-    std::cout << substrings[i] << " " << vec << std::endl;
+    result.push_back(std::make_pair(substrings[i], std::move(vec)));
+  }
+  return result;
+}
+
+// deprecated. use getNgramVectors instead
+void FastText::ngramVectors(std::string word) {
+  std::vector<std::pair<std::string, Vector>> ngramVectors =
+      getNgramVectors(word);
+
+  for (const auto& ngramVector : ngramVectors) {
+    std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
   }
 }
 
@@ -509,65 +521,107 @@ void FastText::precomputeWordVectors(Matrix& wordVectors) {
   }
 }
 
-void FastText::findNN(
+void FastText::lazyComputeWordVectors() {
+  if (!wordVectors_) {
+    wordVectors_ =
+        std::unique_ptr<Matrix>(new Matrix(dict_->nwords(), args_->dim));
+    precomputeWordVectors(*wordVectors_);
+  }
+}
+
+std::vector<std::pair<real, std::string>> FastText::getNN(
+    const std::string& word,
+    int32_t k) {
+  Vector query(args_->dim);
+
+  getWordVector(query, word);
+
+  lazyComputeWordVectors();
+  assert(wordVectors_);
+  return getNN(*wordVectors_, query, k, {word});
+}
+
+std::vector<std::pair<real, std::string>> FastText::getNN(
     const Matrix& wordVectors,
-    const Vector& queryVec,
+    const Vector& query,
     int32_t k,
-    const std::set<std::string>& banSet,
-    std::vector<std::pair<real, std::string>>& results) {
-  results.clear();
-  std::priority_queue<std::pair<real, std::string>> heap;
-  real queryNorm = queryVec.norm();
+    const std::set<std::string>& banSet) {
+  std::vector<std::pair<real, std::string>> heap;
+
+  real queryNorm = query.norm();
   if (std::abs(queryNorm) < 1e-8) {
     queryNorm = 1;
   }
-  Vector vec(args_->dim);
+
   for (int32_t i = 0; i < dict_->nwords(); i++) {
     std::string word = dict_->getWord(i);
-    real dp = wordVectors.dotRow(queryVec, i);
-    heap.push(std::make_pair(dp / queryNorm, word));
-  }
-  int32_t i = 0;
-  while (i < k && heap.size() > 0) {
-    auto it = banSet.find(heap.top().second);
-    if (it == banSet.end()) {
-      results.push_back(
-          std::pair<real, std::string>(heap.top().first, heap.top().second));
-      i++;
+    if (banSet.find(word) == banSet.end()) {
+      real dp = wordVectors.dotRow(query, i);
+      real similarity = dp / queryNorm;
+
+      if (heap.size() == k && similarity < heap.front().first) {
+        continue;
+      }
+      heap.push_back(std::make_pair(similarity, word));
+      std::push_heap(heap.begin(), heap.end(), comparePairs);
+      if (heap.size() > k) {
+        std::pop_heap(heap.begin(), heap.end(), comparePairs);
+        heap.pop_back();
+      }
     }
-    heap.pop();
   }
+  std::sort_heap(heap.begin(), heap.end(), comparePairs);
+
+  return heap;
+}
+
+// deprecated. use getNN instead
+void FastText::findNN(
+    const Matrix& wordVectors,
+    const Vector& query,
+    int32_t k,
+    const std::set<std::string>& banSet,
+    std::vector<std::pair<real, std::string>>& results) {
+  results.clear();
+  results = getNN(wordVectors, query, k, banSet);
 }
 
+std::vector<std::pair<real, std::string>> FastText::getAnalogies(
+    int32_t k,
+    const std::string& wordA,
+    const std::string& wordB,
+    const std::string& wordC) {
+  Vector query = Vector(args_->dim);
+  query.zero();
+
+  Vector buffer(args_->dim);
+  getWordVector(buffer, wordA);
+  query.addVector(buffer, 1.0);
+  getWordVector(buffer, wordB);
+  query.addVector(buffer, -1.0);
+  getWordVector(buffer, wordC);
+  query.addVector(buffer, 1.0);
+
+  lazyComputeWordVectors();
+  assert(wordVectors_);
+  return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
+}
+
+// deprecated, use getAnalogies instead
 void FastText::analogies(int32_t k) {
-  std::string word;
-  Vector buffer(args_->dim), query(args_->dim);
-  Matrix wordVectors(dict_->nwords(), args_->dim);
-  precomputeWordVectors(wordVectors);
-  std::set<std::string> banSet;
-  std::cout << "Query triplet (A - B + C)? ";
-  std::vector<std::pair<real, std::string>> results;
+  std::string prompt("Query triplet (A - B + C)? ");
+  std::string wordA, wordB, wordC;
+  std::cout << prompt;
   while (true) {
-    banSet.clear();
-    query.zero();
-    std::cin >> word;
-    banSet.insert(word);
-    getWordVector(buffer, word);
-    query.addVector(buffer, 1.0);
-    std::cin >> word;
-    banSet.insert(word);
-    getWordVector(buffer, word);
-    query.addVector(buffer, -1.0);
-    std::cin >> word;
-    banSet.insert(word);
-    getWordVector(buffer, word);
-    query.addVector(buffer, 1.0);
-
-    findNN(wordVectors, query, k, banSet, results);
+    std::cin >> wordA;
+    std::cin >> wordB;
+    std::cin >> wordC;
+    auto results = getAnalogies(k, wordA, wordB, wordC);
+
     for (auto& pair : results) {
       std::cout << pair.second << " " << pair.first << std::endl;
     }
-    std::cout << "Query triplet (A - B + C)? ";
+    std::cout << prompt;
   }
 }
 
@@ -727,26 +781,10 @@ bool FastText::isQuant() const {
   return quant_;
 }
 
-void FastText::writePerLabelMetrics(std::ostream& out, Meter& meter) const {
-  out << std::fixed;
-  out << std::setprecision(6);
-
-  auto writeMetric = [&](const std::string& name, double value) {
-    out << name << " : ";
-    if (std::isfinite(value)) {
-      out << value;
-    } else {
-      out << "--------";
-    }
-    out << "  ";
-  };
-
-  for (int32_t i = 0; i < dict_->nlabels(); i++) {
-    writeMetric("F1-Score", meter.f1Score(i));
-    writeMetric("Precision", meter.precision(i));
-    writeMetric("Recall", meter.recall(i));
-    out << " " << dict_->getLabel(i) << std::endl;
-  }
+bool comparePairs(
+    const std::pair<real, std::string>& l,
+    const std::pair<real, std::string>& r) {
+  return l.first > r.first;
 }
 
 } // namespace fasttext

+ 97 - 33
src/fasttext.h

@@ -50,79 +50,143 @@ class FastText {
   std::chrono::steady_clock::time_point start_;
   void signModel(std::ostream&);
   bool checkModel(std::istream&);
+  void startThreads();
+  void addInputVector(Vector&, int32_t) const;
+  void trainThread(int32_t);
+  std::vector<std::pair<real, std::string>> getNN(
+      const Matrix& wordVectors,
+      const Vector& queryVec,
+      int32_t k,
+      const std::set<std::string>& banSet);
+  void lazyComputeWordVectors();
+  void printInfo(real, real, std::ostream&);
 
   bool quant_;
   int32_t version;
-
-  void startThreads();
+  std::unique_ptr<Matrix> wordVectors_;
 
  public:
   FastText();
 
   int32_t getWordId(const std::string&) const;
+
   int32_t getSubwordId(const std::string&) const;
-  FASTTEXT_DEPRECATED(
-      "getVector is being deprecated and replaced by getWordVector.")
-  void getVector(Vector&, const std::string&) const;
+
   void getWordVector(Vector&, const std::string&) const;
+
   void getSubwordVector(Vector&, const std::string&) const;
-  void addInputVector(Vector&, int32_t) const;
+
   inline void getInputVector(Vector& vec, int32_t ind) {
     vec.zero();
     addInputVector(vec, ind);
   }
 
   const Args getArgs() const;
+
   std::shared_ptr<const Dictionary> getDictionary() const;
+
   std::shared_ptr<const Matrix> getInputMatrix() const;
+
   std::shared_ptr<const Matrix> getOutputMatrix() const;
-  void saveVectors();
-  void saveModel(const std::string);
-  void saveOutput();
-  void saveModel();
+
+  void saveVectors(const std::string&);
+
+  void saveModel(const std::string&);
+
+  void saveOutput(const std::string&);
+
   void loadModel(std::istream&);
+
   void loadModel(const std::string&);
-  void printInfo(real, real, std::ostream&);
 
-  void supervised(
-      Model&,
-      real,
-      const std::vector<int32_t>&,
-      const std::vector<int32_t>&);
-  void cbow(Model&, real, const std::vector<int32_t>&);
-  void skipgram(Model&, real, const std::vector<int32_t>&);
-  std::vector<int32_t> selectEmbeddings(int32_t) const;
   void getSentenceVector(std::istream&, Vector&);
+
   void quantize(const Args);
+
   std::tuple<int64_t, double, double> test(std::istream&, int32_t, real = 0.0);
+
   void test(std::istream&, int32_t, real, Meter&) const;
+
   void predict(
       int32_t,
       const std::vector<int32_t>&,
       std::vector<std::pair<real, int32_t>>&,
       real = 0.0) const;
-  void predict(std::istream&, int32_t, bool, real = 0.0);
-  void ngramVectors(std::string);
-  void precomputeWordVectors(Matrix&);
-  void findNN(
-      const Matrix&,
-      const Vector&,
+
+  bool predictLine(
+      std::istream&,
+      std::vector<std::pair<real, std::string>>&,
       int32_t,
-      const std::set<std::string>&,
-      std::vector<std::pair<real, std::string>>& results);
-  void analogies(int32_t);
-  void trainThread(int32_t);
+      real) const;
+
+  std::vector<std::pair<std::string, Vector>> getNgramVectors(
+      const std::string& word) const;
+
+  std::vector<std::pair<real, std::string>> getNN(const std::string&, int32_t);
+
+  std::vector<std::pair<real, std::string>> getAnalogies(
+      int32_t,
+      const std::string&,
+      const std::string&,
+      const std::string&);
+
   void train(const Args);
 
   void loadVectors(std::string);
+
   int getDimension() const;
+
   bool isQuant() const;
 
   FASTTEXT_DEPRECATED(
-      "This function is deprecated, please use `test` function.")
-  void printLabelStats(std::istream&, int32_t, real = 0.0) const;
+      "getVector is being deprecated and replaced by getWordVector.")
+  void getVector(Vector&, const std::string&) const;
+
+  FASTTEXT_DEPRECATED(
+      "ngramVectors is being deprecated and replaced by getNgramVectors.")
+  void ngramVectors(std::string);
+
+  FASTTEXT_DEPRECATED(
+      "analogies is being deprecated and replaced by getAnalogies.")
+  void analogies(int32_t);
+
+  FASTTEXT_DEPRECATED("supervised is being deprecated.")
+  void supervised(
+      Model&,
+      real,
+      const std::vector<int32_t>&,
+      const std::vector<int32_t>&);
+
+  FASTTEXT_DEPRECATED("cbow is being deprecated.")
+  void cbow(Model&, real, const std::vector<int32_t>&);
+
+  FASTTEXT_DEPRECATED("skipgram is being deprecated.")
+  void skipgram(Model&, real, const std::vector<int32_t>&);
+
+  FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
+  std::vector<int32_t> selectEmbeddings(int32_t) const;
+
+  FASTTEXT_DEPRECATED(
+      "saveVectors is being deprecated, please use the other signature.")
+  void saveVectors();
+
   FASTTEXT_DEPRECATED(
-      "This function is deprecated and will be removed along with `printLabelStats`.")
-  void writePerLabelMetrics(std::ostream&, Meter&) const;
+      "saveOutput is being deprecated, please use the other signature.")
+  void saveOutput();
+
+  FASTTEXT_DEPRECATED(
+      "saveModel is being deprecated, please use the other signature.")
+  void saveModel();
+
+  FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
+  void precomputeWordVectors(Matrix&);
+
+  FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
+  void findNN(
+      const Matrix&,
+      const Vector&,
+      int32_t,
+      const std::set<std::string>&,
+      std::vector<std::pair<real, std::string>>& results);
 };
 } // namespace fasttext

+ 95 - 34
src/main.cc

@@ -102,7 +102,7 @@ void quantize(const std::vector<std::string>& args) {
   // parseArgs checks if a->output is given.
   fasttext.loadModel(a.output + ".bin");
   fasttext.quantize(a);
-  fasttext.saveModel();
+  fasttext.saveModel(a.output + ".ftz");
   exit(0);
 }
 
@@ -156,13 +156,53 @@ void test(const std::vector<std::string>& args) {
   }
 
   if (perLabel) {
-    fasttext.writePerLabelMetrics(std::cout, meter);
+    std::cout << std::fixed << std::setprecision(6);
+    auto writeMetric = [](const std::string& name, double value) {
+      std::cout << name << " : ";
+      if (std::isfinite(value)) {
+        std::cout << value;
+      } else {
+        std::cout << "--------";
+      }
+      std::cout << "  ";
+    };
+
+    std::shared_ptr<const Dictionary> dict = fasttext.getDictionary();
+    for (int32_t labelId = 0; labelId < dict->nlabels(); labelId++) {
+      writeMetric("F1-Score", meter.f1Score(labelId));
+      writeMetric("Precision", meter.precision(labelId));
+      writeMetric("Recall", meter.recall(labelId));
+      std::cout << " " << dict->getLabel(labelId) << std::endl;
+    }
   }
   meter.writeGeneralMetrics(std::cout, k);
 
   exit(0);
 }
 
+void printPredictions(
+    const std::vector<std::pair<real, std::string>>& predictions,
+    bool printProb,
+    bool multiline) {
+  bool first = true;
+  for (const auto& prediction : predictions) {
+    if (!first && !multiline) {
+      std::cout << " ";
+    }
+    first = false;
+    std::cout << prediction.second;
+    if (printProb) {
+      std::cout << " " << prediction.first;
+    }
+    if (multiline) {
+      std::cout << std::endl;
+    }
+  }
+  if (!multiline) {
+    std::cout << std::endl;
+  }
+}
+
 void predict(const std::vector<std::string>& args) {
   if (args.size() < 4 || args.size() > 6) {
     printPredictUsage();
@@ -177,20 +217,26 @@ void predict(const std::vector<std::string>& args) {
     }
   }
 
-  bool print_prob = args[1] == "predict-prob";
+  bool printProb = args[1] == "predict-prob";
   FastText fasttext;
   fasttext.loadModel(std::string(args[2]));
 
+  std::ifstream ifs;
   std::string infile(args[3]);
-  if (infile == "-") {
-    fasttext.predict(std::cin, k, print_prob, threshold);
-  } else {
-    std::ifstream ifs(infile);
-    if (!ifs.is_open()) {
+  bool inputIsStdIn = infile == "-";
+  if (!inputIsStdIn){
+    ifs.open(infile);
+    if (!inputIsStdIn && !ifs.is_open()) {
       std::cerr << "Input file cannot be opened!" << std::endl;
       exit(EXIT_FAILURE);
     }
-    fasttext.predict(ifs, k, print_prob, threshold);
+  }
+  std::istream& in = inputIsStdIn ? std::cin : ifs;
+  std::vector<std::pair<real, std::string>> predictions;
+  while (fasttext.predictLine(in, predictions, k, threshold)) {
+    printPredictions(predictions, printProb, false);
+  }
+  if (ifs.is_open()){
     ifs.close();
   }
 
@@ -236,7 +282,15 @@ void printNgrams(const std::vector<std::string> args) {
   }
   FastText fasttext;
   fasttext.loadModel(std::string(args[2]));
-  fasttext.ngramVectors(std::string(args[3]));
+
+  std::string word(args[3]);
+  std::vector<std::pair<std::string, Vector>> ngramVectors =
+      fasttext.getNgramVectors(word);
+
+  for (const auto& ngramVector : ngramVectors) {
+    std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
+  }
+
   exit(0);
 }
 
@@ -252,25 +306,13 @@ void nn(const std::vector<std::string> args) {
   }
   FastText fasttext;
   fasttext.loadModel(std::string(args[2]));
+  std::string prompt("Query word? ");
+  std::cout << prompt;
+
   std::string queryWord;
-  std::shared_ptr<const Dictionary> dict = fasttext.getDictionary();
-  Vector queryVec(fasttext.getDimension());
-  Matrix wordVectors(dict->nwords(), fasttext.getDimension());
-  std::cerr << "Pre-computing word vectors...";
-  fasttext.precomputeWordVectors(wordVectors);
-  std::cerr << " done." << std::endl;
-  std::set<std::string> banSet;
-  std::cout << "Query word? ";
-  std::vector<std::pair<real, std::string>> results;
   while (std::cin >> queryWord) {
-    banSet.clear();
-    banSet.insert(queryWord);
-    fasttext.getWordVector(queryVec, queryWord);
-    fasttext.findNN(wordVectors, queryVec, k, banSet, results);
-    for (auto& pair : results) {
-      std::cout << pair.second << " " << pair.first << std::endl;
-    }
-    std::cout << "Query word? ";
+    printPredictions(fasttext.getNN(queryWord, k), true, true);
+    std::cout << prompt;
   }
   exit(0);
 }
@@ -285,9 +327,26 @@ void analogies(const std::vector<std::string> args) {
     printAnalogiesUsage();
     exit(EXIT_FAILURE);
   }
+  if (k <= 0) {
+    throw std::invalid_argument("k needs to be 1 or higher!");
+  }
   FastText fasttext;
-  fasttext.loadModel(std::string(args[2]));
-  fasttext.analogies(k);
+  std::string model(args[2]);
+  std::cout << "Loading model " << model << std::endl;
+  fasttext.loadModel(model);
+
+  std::string prompt("Query triplet (A - B + C)? ");
+  std::string wordA, wordB, wordC;
+  std::cout << prompt;
+  while (true) {
+    std::cin >> wordA;
+    std::cin >> wordB;
+    std::cin >> wordC;
+    printPredictions(
+        fasttext.getAnalogies(k, wordA, wordB, wordC), true, true);
+
+    std::cout << prompt;
+  }
   exit(0);
 }
 
@@ -295,16 +354,18 @@ void train(const std::vector<std::string> args) {
   Args a = Args();
   a.parseArgs(args);
   FastText fasttext;
-  std::ofstream ofs(a.output + ".bin");
+  std::string outputFileName(a.output + ".bin");
+  std::ofstream ofs(outputFileName);
   if (!ofs.is_open()) {
-    throw std::invalid_argument(a.output + ".bin cannot be opened for saving.");
+    throw std::invalid_argument(
+        outputFileName + " cannot be opened for saving.");
   }
   ofs.close();
   fasttext.train(a);
-  fasttext.saveModel();
-  fasttext.saveVectors();
+  fasttext.saveModel(outputFileName);
+  fasttext.saveVectors(a.output + ".vec");
   if (a.saveOutput) {
-    fasttext.saveOutput();
+    fasttext.saveOutput(a.output + ".output");
   }
 }
 

+ 9 - 0
src/meter.cc

@@ -1,3 +1,12 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
 #include "meter.h"
 
 #include <algorithm>

+ 9 - 0
src/meter.h

@@ -1,3 +1,12 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
 #pragma once
 
 #include <unordered_map>

+ 8 - 0
src/vector.cc

@@ -13,6 +13,7 @@
 
 #include <cmath>
 #include <iomanip>
+#include <utility>
 
 #include "matrix.h"
 #include "qmatrix.h"
@@ -21,6 +22,13 @@ namespace fasttext {
 
 Vector::Vector(int64_t m) : data_(m) {}
 
+Vector::Vector(Vector&& other) noexcept : data_(std::move(other.data_)) {}
+
+Vector& Vector::operator=(Vector&& other) {
+  data_ = std::move(other.data_);
+  return *this;
+}
+
 void Vector::zero() {
   std::fill(data_.begin(), data_.end(), 0.0);
 }

+ 2 - 0
src/vector.h

@@ -27,7 +27,9 @@ class Vector {
  public:
   explicit Vector(int64_t);
   Vector(const Vector&) = delete;
+  Vector(Vector&&) noexcept;
   Vector& operator=(const Vector&) = delete;
+  Vector& operator=(Vector&&);
 
   inline real* data() {
     return data_.data();