Sfoglia il codice sorgente

No more `args_` in `Model`

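Summary: Remove the `args_` member from `Model`. The constructor now takes the hidden size and a `normalizeGradient` flag directly, and the "model needs to be supervised" check for prediction moves from `Model::predict` to `FastText::predict`.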

Reviewed By: EdouardGrave

Differential Revision: D13487730

fbshipit-source-id: ef78aa00637b9b5a34935f5b41f3516a5d37866a
Onur Çelebi 7 anni fa
parent
commit
b9f79a7dd6
3 file modificati con 49 aggiunte e 34 eliminazioni
  1. 15 4
      src/fasttext.cc
  2. 16 16
      src/model.cc
  3. 18 14
      src/model.h

+ 15 - 4
src/fasttext.cc

@@ -257,7 +257,9 @@ void FastText::loadModel(std::istream& in) {
   output_->load(in);
 
   auto loss = createLoss(output_);
-  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
+  bool normalizeGradient = (args_->model == model_name::sup);
+  model_ = std::make_shared<Model>(
+      input_, output_, loss, args_->dim, normalizeGradient, 0);
 }
 
 void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
@@ -316,6 +318,8 @@ void FastText::quantize(const Args& qargs) {
       std::dynamic_pointer_cast<DenseMatrix>(input_);
   std::shared_ptr<DenseMatrix> output =
       std::dynamic_pointer_cast<DenseMatrix>(output_);
+  bool normalizeGradient = (args_->model == model_name::sup);
+
   if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
     auto idx = selectEmbeddings(qargs.cutoff);
     dict_->prune(idx);
@@ -333,7 +337,8 @@ void FastText::quantize(const Args& qargs) {
       args_->thread = qargs.thread;
       args_->verbose = qargs.verbose;
       auto loss = createLoss(output_);
-      model_ = std::make_shared<Model>(input, output, args_, loss, 0);
+      model_ = std::make_shared<Model>(
+          input, output, loss, args_->dim, normalizeGradient, 0);
       startThreads();
     }
   }
@@ -348,7 +353,8 @@ void FastText::quantize(const Args& qargs) {
 
   quant_ = true;
   auto loss = createLoss(output_);
-  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
+  model_ = std::make_shared<Model>(
+      input_, output_, loss, args_->dim, normalizeGradient, 0);
 }
 
 void FastText::supervised(
@@ -438,6 +444,9 @@ void FastText::predict(
   }
   Vector hidden(args_->dim);
   Vector output(dict_->nlabels());
+  if (args_->model != model_name::sup) {
+    throw std::invalid_argument("Model needs to be supervised for prediction!");
+  }
   model_->predict(words, k, threshold, predictions, hidden, output);
 }
 
@@ -764,7 +773,9 @@ void FastText::train(const Args& args) {
   }
   output_ = createTrainOutputMatrix();
   auto loss = createLoss(output_);
-  model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
+  bool normalizeGradient = (args_->model == model_name::sup);
+  model_ = std::make_shared<Model>(
+      input_, output_, loss, args_->dim, normalizeGradient, 0);
   startThreads();
 }
 

+ 16 - 16
src/model.cc

@@ -18,24 +18,26 @@ namespace fasttext {
 Model::Model(
     std::shared_ptr<Matrix> wi,
     std::shared_ptr<Matrix> wo,
-    std::shared_ptr<Args> args,
     std::shared_ptr<Loss> loss,
+    int32_t hiddenSize,
+    bool normalizeGradient,
     int32_t seed)
-    : hidden_(args->dim), output_(wo->size(0)), grad_(args->dim), rng(seed) {
-  wi_ = wi;
-  wo_ = wo;
-  args_ = args;
-  loss_ = loss;
-  osz_ = wo->size(0);
-  hsz_ = args->dim;
-  lossValue_ = 0.0;
-  nexamples_ = 1;
-}
+    : wi_(wi),
+      wo_(wo),
+      loss_(loss),
+      hidden_(hiddenSize),
+      output_(wo->size(0)),
+      grad_(hiddenSize),
+      hsz_(hiddenSize),
+      osz_(wo->size(0)),
+      lossValue_(0.0),
+      nexamples_(1),
+      normalizeGradient_(normalizeGradient),
+      rng(seed) {}
 
 Model::Model(const Model& other, int32_t seed)
     : wi_(other.wi_),
       wo_(other.wo_),
-      args_(other.args_),
       loss_(other.loss_),
       hidden_(other.hidden_),
       output_(other.output_),
@@ -44,6 +46,7 @@ Model::Model(const Model& other, int32_t seed)
       osz_(other.osz_),
       lossValue_(other.lossValue_),
       nexamples_(other.nexamples_),
+      normalizeGradient_(other.normalizeGradient_),
       rng(seed) {}
 
 void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden)
@@ -68,9 +71,6 @@ void Model::predict(
   } else if (k <= 0) {
     throw std::invalid_argument("k needs to be 1 or higher!");
   }
-  if (args_->model != model_name::sup) {
-    throw std::invalid_argument("Model needs to be supervised for prediction!");
-  }
   heap.reserve(k + 1);
   computeHidden(input, hidden);
 
@@ -101,7 +101,7 @@ void Model::update(
 
   nexamples_ += 1;
 
-  if (args_->model == model_name::sup) {
+  if (normalizeGradient_) {
     grad_.mul(1.0 / input.size());
   }
   for (auto it = input.cbegin(); it != input.cend(); ++it) {

+ 18 - 14
src/model.h

@@ -13,7 +13,6 @@
 #include <utility>
 #include <vector>
 
-#include "args.h"
 #include "loss.h"
 #include "matrix.h"
 #include "real.h"
@@ -25,7 +24,6 @@ class Model {
  protected:
   std::shared_ptr<Matrix> wi_;
   std::shared_ptr<Matrix> wo_;
-  std::shared_ptr<Args> args_;
   std::shared_ptr<Loss> loss_;
   Vector hidden_;
   Vector output_;
@@ -34,13 +32,15 @@ class Model {
   int32_t osz_;
   real lossValue_;
   int64_t nexamples_;
+  bool normalizeGradient_;
 
  public:
   Model(
       std::shared_ptr<Matrix> wi,
       std::shared_ptr<Matrix> wo,
-      std::shared_ptr<Args> args,
       std::shared_ptr<Loss> loss,
+      int32_t hiddenSize,
+      bool normalizeGradient,
       int32_t seed);
   Model(const Model& model, int32_t seed);
   Model(const Model& model) = delete;
@@ -49,18 +49,22 @@ class Model {
   Model& operator=(Model&& other) = delete;
 
   void predict(
-      const std::vector<int32_t>&,
-      int32_t,
-      real,
-      Predictions&,
-      Vector&,
-      Vector&) const;
-  void predict(const std::vector<int32_t>&, int32_t, real, Predictions&);
+      const std::vector<int32_t>& input,
+      int32_t k,
+      real threshold,
+      Predictions& heap,
+      Vector& hidden,
+      Vector& output) const;
+  void predict(
+      const std::vector<int32_t>& input,
+      int32_t k,
+      real threshold,
+      Predictions& heap);
   void update(
-      const std::vector<int32_t>&,
-      const std::vector<int32_t>&,
-      int32_t,
-      real);
+      const std::vector<int32_t>& input,
+      const std::vector<int32_t>& targets,
+      int32_t targetIndex,
+      real lr);
   void computeHidden(const std::vector<int32_t>&, Vector&) const;
 
   real getLoss() const;