Преглед изворни кода

Separating loss and gradient computation from model

Summary:
This commit splits the computation of the loss and the subsequent gradient into different classes. Each Loss class implements its own logic and contains the underlying data needed for the computation.

There is a behavioural change:
- `NegativeSampling` now uses the sigmoid output instead of the softmax output for prediction as well. Before this commit, it used sigmoid for training and softmax for prediction.

We are passing a lot of information to the `Loss` classes. There are two things we should think about next:
- `State` class
- `Model` classes instead of `Loss` classes

Reviewed By: EdouardGrave

Differential Revision: D13359871

fbshipit-source-id: 2f53eaafb800a9a2742817aa113af5f6bd7e282e
Onur Çelebi пре 7 година
родитељ
комит
9ddcabd04f
13 измењених фајлова са 627 додато и 407 уклоњено
  1. 2 0
      CMakeLists.txt
  2. 4 1
      Makefile
  3. 32 11
      src/fasttext.cc
  4. 2 1
      src/fasttext.h
  5. 361 0
      src/loss.cc
  6. 176 0
      src/loss.h
  7. 1 1
      src/meter.cc
  8. 2 3
      src/meter.h
  9. 24 311
      src/model.cc
  10. 15 68
      src/model.h
  11. 4 0
      src/utils.h
  12. 0 7
      src/vector.cc
  13. 4 4
      src/vector.h

+ 2 - 0
CMakeLists.txt

@@ -22,6 +22,7 @@ set(HEADER_FILES
     src/densematrix.h
     src/dictionary.h
     src/fasttext.h
+    src/loss.h
     src/matrix.h
     src/meter.h
     src/model.h
@@ -36,6 +37,7 @@ set(SOURCE_FILES
     src/densematrix.cc
     src/dictionary.cc
     src/fasttext.cc
+    src/loss.cc
     src/main.cc
     src/matrix.cc
     src/meter.cc

+ 4 - 1
Makefile

@@ -8,7 +8,7 @@
 
 CXX = c++
 CXXFLAGS = -pthread -std=c++0x -march=native
-OBJS = args.o matrix.o dictionary.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
+OBJS = args.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
 INCLUDES = -I.
 
 opt: CXXFLAGS += -O3 -funroll-loops
@@ -29,6 +29,9 @@ matrix.o: src/matrix.cc src/matrix.h
 dictionary.o: src/dictionary.cc src/dictionary.h src/args.h
 	$(CXX) $(CXXFLAGS) -c src/dictionary.cc
 
+loss.o: src/loss.cc src/loss.h src/basematrix.h src/real.h
+	$(CXX) $(CXXFLAGS) -c src/loss.cc
+
 productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h
 	$(CXX) $(CXXFLAGS) -c src/productquantizer.cc
 

+ 32 - 11
src/fasttext.cc

@@ -7,6 +7,7 @@
  */
 
 #include "fasttext.h"
+#include "loss.h"
 #include "quantmatrix.h"
 
 #include <algorithm>
@@ -28,6 +29,24 @@ bool comparePairs(
     const std::pair<real, std::string>& l,
     const std::pair<real, std::string>& r);
 
+std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
+  loss_name lossName = args_->loss;
+  switch (lossName) {
+    case loss_name::hs:
+      return std::make_shared<HierarchicalSoftmaxLoss>(
+          output, getTargetCounts());
+    case loss_name::ns:
+      return std::make_shared<NegativeSamplingLoss>(
+          output, args_->neg, getTargetCounts());
+    case loss_name::softmax:
+      return std::make_shared<SoftmaxLoss>(output);
+    case loss_name::ova:
+      return std::make_shared<OneVsAllLoss>(output);
+    default:
+      throw std::runtime_error("Unknown loss");
+  }
+}
+
 FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
 
 void FastText::addInputVector(Vector& vec, int32_t ind) const {
@@ -237,8 +256,8 @@ void FastText::loadModel(std::istream& in) {
   }
   output_->load(in);
 
-  model_ =
-      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
+  auto loss = createLoss(output_);
+  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
 }
 
 void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
@@ -297,7 +316,6 @@ void FastText::quantize(const Args& qargs) {
       std::dynamic_pointer_cast<DenseMatrix>(input_);
   std::shared_ptr<DenseMatrix> output =
       std::dynamic_pointer_cast<DenseMatrix>(output_);
-
   if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
     auto idx = selectEmbeddings(qargs.cutoff);
     dict_->prune(idx);
@@ -314,6 +332,8 @@ void FastText::quantize(const Args& qargs) {
       args_->lr = qargs.lr;
       args_->thread = qargs.thread;
       args_->verbose = qargs.verbose;
+      auto loss = createLoss(output_);
+      model_ = std::make_shared<Model>(input, output, args_, loss, 0);
       startThreads();
     }
   }
@@ -327,8 +347,8 @@ void FastText::quantize(const Args& qargs) {
   }
 
   quant_ = true;
-  model_ =
-      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
+  auto loss = createLoss(output_);
+  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
 }
 
 void FastText::supervised(
@@ -393,7 +413,7 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
     const {
   std::vector<int32_t> line;
   std::vector<int32_t> labels;
-  std::vector<std::pair<real, int32_t>> predictions;
+  Predictions predictions;
 
   while (in.peek() != EOF) {
     line.clear();
@@ -411,7 +431,7 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
 void FastText::predict(
     int32_t k,
     const std::vector<int32_t>& words,
-    std::vector<std::pair<real, int32_t>>& predictions,
+    Predictions& predictions,
     real threshold) const {
   if (words.empty()) {
     return;
@@ -433,7 +453,7 @@ bool FastText::predictLine(
 
   std::vector<int32_t> words, labels;
   dict_->getLine(in, words, labels);
-  std::vector<std::pair<real, int32_t>> linePredictions;
+  Predictions linePredictions;
   predict(k, words, linePredictions, threshold);
   for (const auto& p : linePredictions) {
     predictions.push_back(
@@ -624,7 +644,8 @@ void FastText::trainThread(int32_t threadId) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
 
-  Model model(input_, output_, args_, getTargetCounts(), threadId);
+  assert(model_);
+  Model model(*model_, threadId);
 
   const int64_t ntokens = dict_->ntokens();
   int64_t localTokenCount = 0;
@@ -742,9 +763,9 @@ void FastText::train(const Args& args) {
     input_ = createRandomMatrix();
   }
   output_ = createTrainOutputMatrix();
+  auto loss = createLoss(output_);
+  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
   startThreads();
-  model_ =
-      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
 }
 
 void FastText::startThreads() {

+ 2 - 1
src/fasttext.h

@@ -60,6 +60,7 @@ class FastText {
   std::shared_ptr<Matrix> createRandomMatrix() const;
   std::shared_ptr<Matrix> createTrainOutputMatrix() const;
   std::vector<int64_t> getTargetCounts() const;
+  std::shared_ptr<Loss> createLoss(std::shared_ptr<Matrix>& output);
 
   bool quant_;
   int32_t version;
@@ -111,7 +112,7 @@ class FastText {
   void predict(
       int32_t k,
       const std::vector<int32_t>& words,
-      std::vector<std::pair<real, int32_t>>& predictions,
+      Predictions& predictions,
       real threshold = 0.0) const;
 
   bool predictLine(

+ 361 - 0
src/loss.cc

@@ -0,0 +1,361 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "loss.h"
+#include "utils.h"
+
+#include <cmath>
+
+namespace fasttext {
+
+constexpr int64_t SIGMOID_TABLE_SIZE = 512;
+constexpr int64_t MAX_SIGMOID = 8;
+constexpr int64_t LOG_TABLE_SIZE = 512;
+
+bool comparePairs(
+    const std::pair<real, int32_t>& l,
+    const std::pair<real, int32_t>& r) {
+  return l.first > r.first;
+}
+
+real std_log(real x) {
+  return std::log(x + 1e-5);
+}
+
+Loss::Loss(std::shared_ptr<Matrix>& wo) : wo_(wo) {
+  t_sigmoid_.reserve(SIGMOID_TABLE_SIZE + 1);
+  for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
+    real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
+    t_sigmoid_.push_back(1.0 / (1.0 + std::exp(-x)));
+  }
+
+  t_log_.reserve(LOG_TABLE_SIZE + 1);
+  for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
+    real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
+    t_log_.push_back(std::log(x));
+  }
+}
+
+real Loss::log(real x) const {
+  if (x > 1.0) {
+    return 0.0;
+  }
+  int64_t i = int64_t(x * LOG_TABLE_SIZE);
+  return t_log_[i];
+}
+
+real Loss::sigmoid(real x) const {
+  if (x < -MAX_SIGMOID) {
+    return 0.0;
+  } else if (x > MAX_SIGMOID) {
+    return 1.0;
+  } else {
+    int64_t i =
+        int64_t((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
+    return t_sigmoid_[i];
+  }
+}
+
+void Loss::predict(
+    int32_t k,
+    real threshold,
+    Predictions& heap,
+    const Vector& hidden,
+    Vector& output) const {
+  computeOutput(hidden, output);
+  findKBest(k, threshold, heap, output);
+  std::sort_heap(heap.begin(), heap.end(), comparePairs);
+}
+
+void Loss::findKBest(
+    int32_t k,
+    real threshold,
+    Predictions& heap,
+    Vector& output) const {
+  for (int32_t i = 0; i < output.size(); i++) {
+    if (output[i] < threshold) {
+      continue;
+    }
+    if (heap.size() == k && std_log(output[i]) < heap.front().first) {
+      continue;
+    }
+    heap.push_back(std::make_pair(std_log(output[i]), i));
+    std::push_heap(heap.begin(), heap.end(), comparePairs);
+    if (heap.size() > k) {
+      std::pop_heap(heap.begin(), heap.end(), comparePairs);
+      heap.pop_back();
+    }
+  }
+}
+
+BinaryLogisticLoss::BinaryLogisticLoss(std::shared_ptr<Matrix>& wo)
+    : Loss(wo) {}
+
+real BinaryLogisticLoss::binaryLogistic(
+    int32_t target,
+    const Vector& hidden,
+    Vector& grad,
+    bool labelIsPositive,
+    real lr,
+    bool backprop) const {
+  real score = sigmoid(wo_->dotRow(hidden, target));
+  if (backprop) {
+    real alpha = lr * (real(labelIsPositive) - score);
+    grad.addRow(*wo_, target, alpha);
+    wo_->addVectorToRow(hidden, target, alpha);
+  }
+  if (labelIsPositive) {
+    return -log(score);
+  } else {
+    return -log(1.0 - score);
+  }
+}
+
+void BinaryLogisticLoss::computeOutput(const Vector& hidden, Vector& output)
+    const {
+  output.mul(*wo_, hidden);
+  int32_t osz = output.size();
+  for (int32_t i = 0; i < osz; i++) {
+    output[i] = sigmoid(output[i]);
+  }
+}
+
+OneVsAllLoss::OneVsAllLoss(std::shared_ptr<Matrix>& wo)
+    : BinaryLogisticLoss(wo) {}
+
+real OneVsAllLoss::forward(
+    const std::vector<int32_t>& targets,
+    int32_t /* we take all targets here */,
+    const Vector& hidden,
+    Vector& output,
+    Vector& grad,
+    real lr,
+    std::minstd_rand& /*rng*/,
+    bool backprop) {
+  real loss = 0.0;
+  int32_t osz = output.size();
+  for (int32_t i = 0; i < osz; i++) {
+    bool isMatch = utils::contains(targets, i);
+    loss += binaryLogistic(i, hidden, grad, isMatch, lr, backprop);
+  }
+
+  return loss;
+}
+
+NegativeSamplingLoss::NegativeSamplingLoss(
+    std::shared_ptr<Matrix>& wo,
+    int neg,
+    const std::vector<int64_t>& targetCounts)
+    : BinaryLogisticLoss(wo), neg_(neg), negatives_(), uniform_() {
+  real z = 0.0;
+  for (size_t i = 0; i < targetCounts.size(); i++) {
+    z += pow(targetCounts[i], 0.5);
+  }
+  for (size_t i = 0; i < targetCounts.size(); i++) {
+    real c = pow(targetCounts[i], 0.5);
+    for (size_t j = 0; j < c * NegativeSamplingLoss::NEGATIVE_TABLE_SIZE / z;
+         j++) {
+      negatives_.push_back(i);
+    }
+  }
+  uniform_ = std::uniform_int_distribution<size_t>(0, negatives_.size() - 1);
+}
+
+real NegativeSamplingLoss::forward(
+    const std::vector<int32_t>& targets,
+    int32_t targetIndex,
+    const Vector& hidden,
+    Vector& /* output */,
+    Vector& grad,
+    real lr,
+    std::minstd_rand& rng,
+    bool backprop) {
+  assert(targetIndex >= 0);
+  assert(targetIndex < targets.size());
+  int32_t target = targets[targetIndex];
+  real loss = binaryLogistic(target, hidden, grad, true, lr, backprop);
+
+  for (int32_t n = 0; n < neg_; n++) {
+    auto negativeTarget = getNegative(target, rng);
+    loss += binaryLogistic(negativeTarget, hidden, grad, false, lr, backprop);
+  }
+  return loss;
+}
+
+int32_t NegativeSamplingLoss::getNegative(
+    int32_t target,
+    std::minstd_rand& rng) {
+  int32_t negative;
+  do {
+    negative = negatives_[uniform_(rng)];
+  } while (target == negative);
+  return negative;
+}
+
+HierarchicalSoftmaxLoss::HierarchicalSoftmaxLoss(
+    std::shared_ptr<Matrix>& wo,
+    const std::vector<int64_t>& targetCounts)
+    : BinaryLogisticLoss(wo),
+      paths_(),
+      codes_(),
+      tree_(),
+      osz_(targetCounts.size()) {
+  buildTree(targetCounts);
+}
+
+void HierarchicalSoftmaxLoss::buildTree(const std::vector<int64_t>& counts) {
+  tree_.resize(2 * osz_ - 1);
+  for (int32_t i = 0; i < 2 * osz_ - 1; i++) {
+    tree_[i].parent = -1;
+    tree_[i].left = -1;
+    tree_[i].right = -1;
+    tree_[i].count = 1e15;
+    tree_[i].binary = false;
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    tree_[i].count = counts[i];
+  }
+  int32_t leaf = osz_ - 1;
+  int32_t node = osz_;
+  for (int32_t i = osz_; i < 2 * osz_ - 1; i++) {
+    int32_t mini[2] = {0};
+    for (int32_t j = 0; j < 2; j++) {
+      if (leaf >= 0 && tree_[leaf].count < tree_[node].count) {
+        mini[j] = leaf--;
+      } else {
+        mini[j] = node++;
+      }
+    }
+    tree_[i].left = mini[0];
+    tree_[i].right = mini[1];
+    tree_[i].count = tree_[mini[0]].count + tree_[mini[1]].count;
+    tree_[mini[0]].parent = i;
+    tree_[mini[1]].parent = i;
+    tree_[mini[1]].binary = true;
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    std::vector<int32_t> path;
+    std::vector<bool> code;
+    int32_t j = i;
+    while (tree_[j].parent != -1) {
+      path.push_back(tree_[j].parent - osz_);
+      code.push_back(tree_[j].binary);
+      j = tree_[j].parent;
+    }
+    paths_.push_back(path);
+    codes_.push_back(code);
+  }
+}
+
+real HierarchicalSoftmaxLoss::forward(
+    const std::vector<int32_t>& targets,
+    int32_t targetIndex,
+    const Vector& hidden,
+    Vector& /* the output is not an explicit Vector here */,
+    Vector& grad,
+    real lr,
+    std::minstd_rand& /*rng*/,
+    bool backprop) {
+  real loss = 0.0;
+  int32_t target = targets[targetIndex];
+  const std::vector<bool>& binaryCode = codes_[target];
+  const std::vector<int32_t>& pathToRoot = paths_[target];
+  for (int32_t i = 0; i < pathToRoot.size(); i++) {
+    loss += binaryLogistic(
+        pathToRoot[i], hidden, grad, binaryCode[i], lr, backprop);
+  }
+  return loss;
+}
+
+void HierarchicalSoftmaxLoss::predict(
+    int32_t k,
+    real threshold,
+    Predictions& heap,
+    const Vector& hidden,
+    Vector& /*output*/) const {
+  dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, hidden);
+  std::sort_heap(heap.begin(), heap.end(), comparePairs);
+}
+
+void HierarchicalSoftmaxLoss::dfs(
+    int32_t k,
+    real threshold,
+    int32_t node,
+    real score,
+    Predictions& heap,
+    const Vector& hidden) const {
+  if (score < std_log(threshold)) {
+    return;
+  }
+  if (heap.size() == k && score < heap.front().first) {
+    return;
+  }
+
+  if (tree_[node].left == -1 && tree_[node].right == -1) {
+    heap.push_back(std::make_pair(score, node));
+    std::push_heap(heap.begin(), heap.end(), comparePairs);
+    if (heap.size() > k) {
+      std::pop_heap(heap.begin(), heap.end(), comparePairs);
+      heap.pop_back();
+    }
+    return;
+  }
+
+  real f = wo_->dotRow(hidden, node - osz_);
+  f = 1. / (1 + std::exp(-f));
+
+  dfs(k, threshold, tree_[node].left, score + std_log(1.0 - f), heap, hidden);
+  dfs(k, threshold, tree_[node].right, score + std_log(f), heap, hidden);
+}
+
+SoftmaxLoss::SoftmaxLoss(std::shared_ptr<Matrix>& wo) : Loss(wo) {}
+
+void SoftmaxLoss::computeOutput(const Vector& hidden, Vector& output) const {
+  output.mul(*wo_, hidden);
+  real max = output[0], z = 0.0;
+  int32_t osz = output.size();
+  for (int32_t i = 0; i < osz; i++) {
+    max = std::max(output[i], max);
+  }
+  for (int32_t i = 0; i < osz; i++) {
+    output[i] = exp(output[i] - max);
+    z += output[i];
+  }
+  for (int32_t i = 0; i < osz; i++) {
+    output[i] /= z;
+  }
+}
+
+real SoftmaxLoss::forward(
+    const std::vector<int32_t>& targets,
+    int32_t targetIndex,
+    const Vector& hidden,
+    Vector& output,
+    Vector& grad,
+    real lr,
+    std::minstd_rand& /*rng*/,
+    bool backprop) {
+  computeOutput(hidden, output);
+
+  assert(targetIndex >= 0);
+  assert(targetIndex < targets.size());
+  int32_t target = targets[targetIndex];
+
+  if (backprop) {
+    int32_t osz = wo_->size(0);
+    for (int32_t i = 0; i < osz; i++) {
+      real label = (i == target) ? 1.0 : 0.0;
+      real alpha = lr * (label - output[i]);
+      grad.addRow(*wo_, i, alpha);
+      wo_->addVectorToRow(hidden, i, alpha);
+    }
+  }
+  return -log(output[target]);
+};
+
+} // namespace fasttext

+ 176 - 0
src/loss.h

@@ -0,0 +1,176 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "matrix.h"
+#include "real.h"
+#include "utils.h"
+#include "vector.h"
+
+namespace fasttext {
+
+class Loss {
+ private:
+  void findKBest(int32_t, real, Predictions&, Vector&) const;
+
+ protected:
+  std::vector<real> t_sigmoid_;
+  std::vector<real> t_log_;
+  std::shared_ptr<Matrix>& wo_;
+
+  real log(real x) const;
+  real sigmoid(real x) const;
+
+ public:
+  explicit Loss(std::shared_ptr<Matrix>& wo);
+  virtual ~Loss() = default;
+
+  virtual real forward(
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      const Vector& hidden,
+      Vector& output,
+      Vector& grad,
+      real lr,
+      std::minstd_rand& rng,
+      bool backprop) = 0;
+  virtual void computeOutput(const Vector& hidden, Vector& output) const = 0;
+
+  virtual void predict(
+      int32_t k,
+      real threshold,
+      Predictions& heap,
+      const Vector& hidden,
+      Vector& output) const;
+};
+
+class BinaryLogisticLoss : public Loss {
+ protected:
+  real binaryLogistic(
+      int32_t target,
+      const Vector& hidden,
+      Vector& grad,
+      bool labelIsPositive,
+      real lr,
+      bool backprop) const;
+  void computeOutput(const Vector& hidden, Vector& output) const override;
+
+ public:
+  explicit BinaryLogisticLoss(std::shared_ptr<Matrix>& wo);
+  virtual ~BinaryLogisticLoss() override = default;
+};
+
+class OneVsAllLoss : public BinaryLogisticLoss {
+ public:
+  explicit OneVsAllLoss(std::shared_ptr<Matrix>& wo);
+  ~OneVsAllLoss() = default;
+  real forward(
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      const Vector& hidden,
+      Vector& output,
+      Vector& grad,
+      real lr,
+      std::minstd_rand& rng,
+      bool backprop) override;
+};
+
+class NegativeSamplingLoss : public BinaryLogisticLoss {
+ protected:
+  static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
+
+  int neg_;
+  std::vector<int32_t> negatives_;
+  std::uniform_int_distribution<size_t> uniform_;
+  int32_t getNegative(int32_t target, std::minstd_rand& rng);
+
+ public:
+  explicit NegativeSamplingLoss(
+      std::shared_ptr<Matrix>& wo,
+      int neg,
+      const std::vector<int64_t>& targetCounts);
+  ~NegativeSamplingLoss() override = default;
+
+  real forward(
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      const Vector& hidden,
+      Vector& output,
+      Vector& grad,
+      real lr,
+      std::minstd_rand& rng,
+      bool backprop) override;
+};
+
+class HierarchicalSoftmaxLoss : public BinaryLogisticLoss {
+ protected:
+  struct Node {
+    int32_t parent;
+    int32_t left;
+    int32_t right;
+    int64_t count;
+    bool binary;
+  };
+
+  std::vector<std::vector<int32_t>> paths_;
+  std::vector<std::vector<bool>> codes_;
+  std::vector<Node> tree_;
+  int32_t osz_;
+  void buildTree(const std::vector<int64_t>& counts);
+  void dfs(
+      int32_t k,
+      real threshold,
+      int32_t node,
+      real score,
+      Predictions& heap,
+      const Vector& hidden) const;
+
+ public:
+  explicit HierarchicalSoftmaxLoss(
+      std::shared_ptr<Matrix>& wo,
+      const std::vector<int64_t>& counts);
+  ~HierarchicalSoftmaxLoss() override = default;
+  real forward(
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      const Vector& hidden,
+      Vector& output,
+      Vector& grad,
+      real lr,
+      std::minstd_rand& rng,
+      bool backprop) override;
+  void predict(
+      int32_t k,
+      real threshold,
+      Predictions& heap,
+      const Vector& hidden,
+      Vector& output) const override;
+};
+
+class SoftmaxLoss : public Loss {
+ public:
+  explicit SoftmaxLoss(std::shared_ptr<Matrix>& wo);
+  ~SoftmaxLoss() override = default;
+  real forward(
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      const Vector& hidden,
+      Vector& output,
+      Vector& grad,
+      real lr,
+      std::minstd_rand& rng,
+      bool backprop) override;
+  void computeOutput(const Vector& hidden, Vector& output) const override;
+};
+
+} // namespace fasttext

+ 1 - 1
src/meter.cc

@@ -18,7 +18,7 @@ namespace fasttext {
 
 void Meter::log(
     const std::vector<int32_t>& labels,
-    const std::vector<std::pair<real, int32_t>>& predictions) {
+    const Predictions& predictions) {
   nexamples_++;
   metrics_.gold += labels.size();
   metrics_.predicted += predictions.size();

+ 2 - 3
src/meter.h

@@ -13,6 +13,7 @@
 
 #include "dictionary.h"
 #include "real.h"
+#include "utils.h"
 
 namespace fasttext {
 
@@ -38,9 +39,7 @@ class Meter {
  public:
   Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
 
-  void log(
-      const std::vector<int32_t>& labels,
-      const std::vector<std::pair<real, int32_t>>& predictions);
+  void log(const std::vector<int32_t>& labels, const Predictions& predictions);
 
   double precision(int32_t);
   double recall(int32_t);

+ 24 - 311
src/model.cc

@@ -15,119 +15,36 @@
 
 namespace fasttext {
 
-constexpr int64_t SIGMOID_TABLE_SIZE = 512;
-constexpr int64_t MAX_SIGMOID = 8;
-constexpr int64_t LOG_TABLE_SIZE = 512;
-
 Model::Model(
     std::shared_ptr<Matrix> wi,
     std::shared_ptr<Matrix> wo,
     std::shared_ptr<Args> args,
-    const std::vector<int64_t>& targetCounts,
+    std::shared_ptr<Loss> loss,
     int32_t seed)
     : hidden_(args->dim), output_(wo->size(0)), grad_(args->dim), rng(seed) {
   wi_ = wi;
   wo_ = wo;
   args_ = args;
+  loss_ = loss;
   osz_ = wo->size(0);
   hsz_ = args->dim;
-  negpos = 0;
-  loss_ = 0.0;
+  lossValue_ = 0.0;
   nexamples_ = 1;
-  t_sigmoid_.reserve(SIGMOID_TABLE_SIZE + 1);
-  t_log_.reserve(LOG_TABLE_SIZE + 1);
-  initSigmoid();
-  initLog();
-  setTargetCounts(targetCounts);
-}
-
-real Model::binaryLogistic(int32_t target, bool label, real lr) {
-  real score = sigmoid(wo_->dotRow(hidden_, target));
-  real alpha = lr * (real(label) - score);
-  grad_.addRow(*wo_, target, alpha);
-  wo_->addVectorToRow(hidden_, target, alpha);
-  if (label) {
-    return -log(score);
-  } else {
-    return -log(1.0 - score);
-  }
-}
-
-real Model::negativeSampling(int32_t target, real lr) {
-  real loss = 0.0;
-  grad_.zero();
-  for (int32_t n = 0; n <= args_->neg; n++) {
-    if (n == 0) {
-      loss += binaryLogistic(target, true, lr);
-    } else {
-      loss += binaryLogistic(getNegative(target), false, lr);
-    }
-  }
-  return loss;
-}
-
-real Model::hierarchicalSoftmax(int32_t target, real lr) {
-  real loss = 0.0;
-  grad_.zero();
-  const std::vector<bool>& binaryCode = codes[target];
-  const std::vector<int32_t>& pathToRoot = paths[target];
-  for (int32_t i = 0; i < pathToRoot.size(); i++) {
-    loss += binaryLogistic(pathToRoot[i], binaryCode[i], lr);
-  }
-  return loss;
 }
 
-void Model::computeOutput(Vector& hidden, Vector& output) const {
-  output.mul(*wo_, hidden);
-}
-
-void Model::computeOutputSigmoid(Vector& hidden, Vector& output) const {
-  computeOutput(hidden, output);
-  for (int32_t i = 0; i < osz_; i++) {
-    output[i] = sigmoid(output[i]);
-  }
-}
-
-void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
-  computeOutput(hidden, output);
-  real max = output[0], z = 0.0;
-  for (int32_t i = 0; i < osz_; i++) {
-    max = std::max(output[i], max);
-  }
-  for (int32_t i = 0; i < osz_; i++) {
-    output[i] = exp(output[i] - max);
-    z += output[i];
-  }
-  for (int32_t i = 0; i < osz_; i++) {
-    output[i] /= z;
-  }
-}
-
-void Model::computeOutputSoftmax() {
-  computeOutputSoftmax(hidden_, output_);
-}
-
-real Model::softmax(int32_t target, real lr) {
-  grad_.zero();
-  computeOutputSoftmax();
-  for (int32_t i = 0; i < osz_; i++) {
-    real label = (i == target) ? 1.0 : 0.0;
-    real alpha = lr * (label - output_[i]);
-    grad_.addRow(*wo_, i, alpha);
-    wo_->addVectorToRow(hidden_, i, alpha);
-  }
-  return -log(output_[target]);
-}
-
-real Model::oneVsAll(const std::vector<int32_t>& targets, real lr) {
-  real loss = 0.0;
-  for (int32_t i = 0; i < osz_; i++) {
-    bool isMatch = utils::contains(targets, i);
-    loss += binaryLogistic(i, isMatch, lr);
-  }
-
-  return loss;
-}
+Model::Model(const Model& other, int32_t seed)
+    : wi_(other.wi_),
+      wo_(other.wo_),
+      args_(other.args_),
+      loss_(other.loss_),
+      hidden_(other.hidden_),
+      output_(other.output_),
+      grad_(other.grad_),
+      hsz_(other.hsz_),
+      osz_(other.osz_),
+      lossValue_(other.lossValue_),
+      nexamples_(other.nexamples_),
+      rng(seed) {}
 
 void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden)
     const {
@@ -139,17 +56,11 @@ void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden)
   hidden.mul(1.0 / input.size());
 }
 
-bool Model::comparePairs(
-    const std::pair<real, int32_t>& l,
-    const std::pair<real, int32_t>& r) {
-  return l.first > r.first;
-}
-
 void Model::predict(
     const std::vector<int32_t>& input,
     int32_t k,
     real threshold,
-    std::vector<std::pair<real, int32_t>>& heap,
+    Predictions& heap,
     Vector& hidden,
     Vector& output) const {
   if (k == Model::kUnlimitedPredictions) {
@@ -162,101 +73,18 @@ void Model::predict(
   }
   heap.reserve(k + 1);
   computeHidden(input, hidden);
-  if (args_->loss == loss_name::hs) {
-    dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, hidden);
-  } else {
-    findKBest(k, threshold, heap, hidden, output);
-  }
-  std::sort_heap(heap.begin(), heap.end(), comparePairs);
+
+  loss_->predict(k, threshold, heap, hidden, output);
 }
 
 void Model::predict(
     const std::vector<int32_t>& input,
     int32_t k,
     real threshold,
-    std::vector<std::pair<real, int32_t>>& heap) {
+    Predictions& heap) {
   predict(input, k, threshold, heap, hidden_, output_);
 }
 
-void Model::findKBest(
-    int32_t k,
-    real threshold,
-    std::vector<std::pair<real, int32_t>>& heap,
-    Vector& hidden,
-    Vector& output) const {
-  if (args_->loss == loss_name::ova) {
-    computeOutputSigmoid(hidden, output);
-  } else {
-    computeOutputSoftmax(hidden, output);
-  }
-  for (int32_t i = 0; i < osz_; i++) {
-    if (output[i] < threshold) {
-      continue;
-    }
-    if (heap.size() == k && std_log(output[i]) < heap.front().first) {
-      continue;
-    }
-    heap.push_back(std::make_pair(std_log(output[i]), i));
-    std::push_heap(heap.begin(), heap.end(), comparePairs);
-    if (heap.size() > k) {
-      std::pop_heap(heap.begin(), heap.end(), comparePairs);
-      heap.pop_back();
-    }
-  }
-}
-
-void Model::dfs(
-    int32_t k,
-    real threshold,
-    int32_t node,
-    real score,
-    std::vector<std::pair<real, int32_t>>& heap,
-    Vector& hidden) const {
-  if (score < std_log(threshold)) {
-    return;
-  }
-  if (heap.size() == k && score < heap.front().first) {
-    return;
-  }
-
-  if (tree[node].left == -1 && tree[node].right == -1) {
-    heap.push_back(std::make_pair(score, node));
-    std::push_heap(heap.begin(), heap.end(), comparePairs);
-    if (heap.size() > k) {
-      std::pop_heap(heap.begin(), heap.end(), comparePairs);
-      heap.pop_back();
-    }
-    return;
-  }
-
-  real f = wo_->dotRow(hidden, node - osz_);
-  f = 1. / (1 + std::exp(-f));
-
-  dfs(k, threshold, tree[node].left, score + std_log(1.0 - f), heap, hidden);
-  dfs(k, threshold, tree[node].right, score + std_log(f), heap, hidden);
-}
-
-real Model::computeLoss(
-    const std::vector<int32_t>& targets,
-    int32_t targetIndex,
-    real lr) {
-  real loss = 0.0;
-
-  if (args_->loss == loss_name::ns) {
-    loss = negativeSampling(targets[targetIndex], lr);
-  } else if (args_->loss == loss_name::hs) {
-    loss = hierarchicalSoftmax(targets[targetIndex], lr);
-  } else if (args_->loss == loss_name::softmax) {
-    loss = softmax(targets[targetIndex], lr);
-  } else if (args_->loss == loss_name::ova) {
-    loss = oneVsAll(targets, lr);
-  } else {
-    throw std::invalid_argument("Unhandled loss function for this model.");
-  }
-
-  return loss;
-}
-
 void Model::update(
     const std::vector<int32_t>& input,
     const std::vector<int32_t>& targets,
@@ -267,13 +95,9 @@ void Model::update(
   }
   computeHidden(input, hidden_);
 
-  if (targetIndex == kAllLabelsAsTarget) {
-    loss_ += computeLoss(targets, -1, lr);
-  } else {
-    assert(targetIndex >= 0);
-    assert(targetIndex < targets.size());
-    loss_ += computeLoss(targets, targetIndex, lr);
-  }
+  grad_.zero();
+  lossValue_ += loss_->forward(
+      targets, targetIndex, hidden_, output_, grad_, lr, rng, true);
 
   nexamples_ += 1;
 
@@ -285,123 +109,12 @@ void Model::update(
   }
 }
 
-void Model::setTargetCounts(const std::vector<int64_t>& counts) {
-  assert(counts.size() == osz_);
-  if (args_->loss == loss_name::ns) {
-    initTableNegatives(counts);
-  }
-  if (args_->loss == loss_name::hs) {
-    buildTree(counts);
-  }
-}
-
-void Model::initTableNegatives(const std::vector<int64_t>& counts) {
-  real z = 0.0;
-  for (size_t i = 0; i < counts.size(); i++) {
-    z += pow(counts[i], 0.5);
-  }
-  for (size_t i = 0; i < counts.size(); i++) {
-    real c = pow(counts[i], 0.5);
-    for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {
-      negatives_.push_back(i);
-    }
-  }
-  std::shuffle(negatives_.begin(), negatives_.end(), rng);
-}
-
-int32_t Model::getNegative(int32_t target) {
-  int32_t negative;
-  do {
-    negative = negatives_[negpos];
-    negpos = (negpos + 1) % negatives_.size();
-  } while (target == negative);
-  return negative;
-}
-
-void Model::buildTree(const std::vector<int64_t>& counts) {
-  tree.resize(2 * osz_ - 1);
-  for (int32_t i = 0; i < 2 * osz_ - 1; i++) {
-    tree[i].parent = -1;
-    tree[i].left = -1;
-    tree[i].right = -1;
-    tree[i].count = 1e15;
-    tree[i].binary = false;
-  }
-  for (int32_t i = 0; i < osz_; i++) {
-    tree[i].count = counts[i];
-  }
-  int32_t leaf = osz_ - 1;
-  int32_t node = osz_;
-  for (int32_t i = osz_; i < 2 * osz_ - 1; i++) {
-    int32_t mini[2];
-    for (int32_t j = 0; j < 2; j++) {
-      if (leaf >= 0 && tree[leaf].count < tree[node].count) {
-        mini[j] = leaf--;
-      } else {
-        mini[j] = node++;
-      }
-    }
-    tree[i].left = mini[0];
-    tree[i].right = mini[1];
-    tree[i].count = tree[mini[0]].count + tree[mini[1]].count;
-    tree[mini[0]].parent = i;
-    tree[mini[1]].parent = i;
-    tree[mini[1]].binary = true;
-  }
-  for (int32_t i = 0; i < osz_; i++) {
-    std::vector<int32_t> path;
-    std::vector<bool> code;
-    int32_t j = i;
-    while (tree[j].parent != -1) {
-      path.push_back(tree[j].parent - osz_);
-      code.push_back(tree[j].binary);
-      j = tree[j].parent;
-    }
-    paths.push_back(path);
-    codes.push_back(code);
-  }
-}
-
 real Model::getLoss() const {
-  return loss_ / nexamples_;
-}
-
-void Model::initSigmoid() {
-  for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
-    real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
-    t_sigmoid_.push_back(1.0 / (1.0 + std::exp(-x)));
-  }
-}
-
-void Model::initLog() {
-  for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
-    real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
-    t_log_.push_back(std::log(x));
-  }
-}
-
-real Model::log(real x) const {
-  if (x > 1.0) {
-    return 0.0;
-  }
-  int64_t i = int64_t(x * LOG_TABLE_SIZE);
-  return t_log_[i];
+  return lossValue_ / nexamples_;
 }
 
 real Model::std_log(real x) const {
   return std::log(x + 1e-5);
 }
 
-real Model::sigmoid(real x) const {
-  if (x < -MAX_SIGMOID) {
-    return 0.0;
-  } else if (x > MAX_SIGMOID) {
-    return 1.0;
-  } else {
-    int64_t i =
-        int64_t((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
-    return t_sigmoid_[i];
-  }
-}
-
 } // namespace fasttext

+ 15 - 68
src/model.h

@@ -14,109 +14,56 @@
 #include <vector>
 
 #include "args.h"
+#include "loss.h"
 #include "matrix.h"
 #include "real.h"
 #include "vector.h"
 
 namespace fasttext {
 
-struct Node {
-  int32_t parent;
-  int32_t left;
-  int32_t right;
-  int64_t count;
-  bool binary;
-};
-
 class Model {
  protected:
   std::shared_ptr<Matrix> wi_;
   std::shared_ptr<Matrix> wo_;
   std::shared_ptr<Args> args_;
+  std::shared_ptr<Loss> loss_;
   Vector hidden_;
   Vector output_;
   Vector grad_;
   int32_t hsz_;
   int32_t osz_;
-  real loss_;
+  real lossValue_;
   int64_t nexamples_;
-  std::vector<real> t_sigmoid_;
-  std::vector<real> t_log_;
-  // used for negative sampling:
-  std::vector<int32_t> negatives_;
-  size_t negpos;
-  // used for hierarchical softmax:
-  std::vector<std::vector<int32_t>> paths;
-  std::vector<std::vector<bool>> codes;
-  std::vector<Node> tree;
-
-  static bool comparePairs(
-      const std::pair<real, int32_t>&,
-      const std::pair<real, int32_t>&);
-
-  int32_t getNegative(int32_t target);
-  void initSigmoid();
-  void initLog();
-  void computeOutput(Vector&, Vector&) const;
-  void setTargetCounts(const std::vector<int64_t>&);
-
-  static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
 
  public:
   Model(
-      std::shared_ptr<Matrix>,
-      std::shared_ptr<Matrix>,
-      std::shared_ptr<Args>,
-      const std::vector<int64_t>&,
-      int32_t);
-
-  real binaryLogistic(int32_t, bool, real);
-  real negativeSampling(int32_t, real);
-  real hierarchicalSoftmax(int32_t, real);
-  real softmax(int32_t, real);
-  real oneVsAll(const std::vector<int32_t>&, real);
+      std::shared_ptr<Matrix> wi,
+      std::shared_ptr<Matrix> wo,
+      std::shared_ptr<Args> args,
+      std::shared_ptr<Loss> loss,
+      int32_t seed);
+  Model(const Model& model, int32_t seed);
+  Model(const Model& model) = delete;
+  Model(Model&& model) = delete;
+  Model& operator=(const Model& other) = delete;
+  Model& operator=(Model&& other) = delete;
 
   void predict(
       const std::vector<int32_t>&,
       int32_t,
       real,
-      std::vector<std::pair<real, int32_t>>&,
-      Vector&,
-      Vector&) const;
-  void predict(
-      const std::vector<int32_t>&,
-      int32_t,
-      real,
-      std::vector<std::pair<real, int32_t>>&);
-  void dfs(
-      int32_t,
-      real,
-      int32_t,
-      real,
-      std::vector<std::pair<real, int32_t>>&,
-      Vector&) const;
-  void findKBest(
-      int32_t,
-      real,
-      std::vector<std::pair<real, int32_t>>&,
+      Predictions&,
       Vector&,
       Vector&) const;
+  void predict(const std::vector<int32_t>&, int32_t, real, Predictions&);
   void update(
       const std::vector<int32_t>&,
       const std::vector<int32_t>&,
       int32_t,
       real);
-  real computeLoss(const std::vector<int32_t>&, int32_t, real);
   void computeHidden(const std::vector<int32_t>&, Vector&) const;
-  void computeOutputSigmoid(Vector&, Vector&) const;
-  void computeOutputSoftmax(Vector&, Vector&) const;
-  void computeOutputSoftmax();
 
-  void initTableNegatives(const std::vector<int64_t>&);
-  void buildTree(const std::vector<int64_t>&);
   real getLoss() const;
-  real sigmoid(real) const;
-  real log(real) const;
   real std_log(real) const;
 
   std::minstd_rand rng;

+ 4 - 0
src/utils.h

@@ -8,6 +8,8 @@
 
 #pragma once
 
+#include "real.h"
+
 #include <algorithm>
 #include <fstream>
 #include <vector>
@@ -22,6 +24,8 @@
 
 namespace fasttext {
 
+using Predictions = std::vector<std::pair<real, int32_t>>;
+
 namespace utils {
 
 int64_t size(std::ifstream&);

+ 0 - 7
src/vector.cc

@@ -20,13 +20,6 @@ namespace fasttext {
 
 Vector::Vector(int64_t m) : data_(m) {}
 
-Vector::Vector(Vector&& other) noexcept : data_(std::move(other.data_)) {}
-
-Vector& Vector::operator=(Vector&& other) {
-  data_ = std::move(other.data_);
-  return *this;
-}
-
 void Vector::zero() {
   std::fill(data_.begin(), data_.end(), 0.0);
 }

+ 4 - 4
src/vector.h

@@ -24,10 +24,10 @@ class Vector {
 
  public:
   explicit Vector(int64_t);
-  Vector(const Vector&) = delete;
-  Vector(Vector&&) noexcept;
-  Vector& operator=(const Vector&) = delete;
-  Vector& operator=(Vector&&);
+  Vector(const Vector&) = default;
+  Vector(Vector&&) noexcept = default;
+  Vector& operator=(const Vector&) = default;
+  Vector& operator=(Vector&&) = default;
 
   inline real* data() {
     return data_.data();