Browse Source

adding subwords for supervised models

Summary: Quick hack to add subwords for supervised models. Small bug: word ids are added twice for in-vocabulary words.

Reviewed By: piotr-bojanowski

Differential Revision: D5444991

fbshipit-source-id: 73c7f0bd44405292e5bb7e34225c9ecbab709931
Edouard Grave, 8 years ago — commit ebbd3bfee5
4 changed files with 89 additions and 48 deletions
  1. src/dictionary.cc — 71 additions, 41 deletions
  2. src/dictionary.h — 5 additions, 2 deletions
  3. src/fasttext.cc — 11 additions, 4 deletions
  4. src/fasttext.h — 2 additions, 1 deletion

+ 71 - 41
src/dictionary.cc

@@ -173,7 +173,7 @@ void Dictionary::computeSubwords(const std::string& word,
       }
       if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
         int32_t h = hash(ngram) % args_->bucket;
-        ngrams.push_back(nwords_ + h);
+        pushHash(ngrams, h);
       }
     }
   }
@@ -182,8 +182,11 @@ void Dictionary::computeSubwords(const std::string& word,
 void Dictionary::initNgrams() {
   for (size_t i = 0; i < size_; i++) {
     std::string word = BOW + words_[i].word + EOW;
+    words_[i].subwords.clear();
     words_[i].subwords.push_back(i);
-    computeSubwords(word, words_[i].subwords);
+    if (words_[i].word != EOS) {
+      computeSubwords(word, words_[i].subwords);
+    }
   }
 }
 
@@ -281,77 +284,103 @@ std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
 }
 
 void Dictionary::addWordNgrams(std::vector<int32_t>& line,
-                           const std::vector<int32_t>& hashes,
-                           int32_t n) const {
-  if (pruneidx_size_ == 0) return;
+                               const std::vector<int32_t>& hashes,
+                               int32_t n) const {
   for (int32_t i = 0; i < hashes.size(); i++) {
     uint64_t h = hashes[i];
     for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
       h = h * 116049371 + hashes[j];
-      int64_t id = h % args_->bucket;
-      if (pruneidx_size_ > 0) {
-        if (pruneidx_.count(id)) {
-          id = pruneidx_.at(id);
-        } else {continue;}
-      }
-      line.push_back(nwords_ + id);
+      pushHash(line, h % args_->bucket);
     }
   }
 }
 
-int32_t Dictionary::getLine(std::istream& in,
-                            std::vector<int32_t>& words,
-                            std::vector<int32_t>& word_hashes,
-                            std::vector<int32_t>& labels,
-                            std::minstd_rand& rng) const {
-  std::uniform_real_distribution<> uniform(0, 1);
+void Dictionary::addSubwords(std::vector<int32_t>& line,
+                             const std::string& token,
+                             int32_t wid) const {
+  if (wid < 0) { // out of vocab
+    computeSubwords(BOW + token + EOW, line);
+  } else {
+    if (args_->maxn <= 0) { // in vocab w/o subwords
+      line.push_back(wid);
+    } else { // in vocab w/ subwords
+      const std::vector<int32_t>& ngrams = getSubwords(wid);
+      line.insert(line.end(), ngrams.cbegin(), ngrams.cend());
+    }
+  }
+}
 
+void Dictionary::reset(std::istream& in) const {
   if (in.eof()) {
     in.clear();
     in.seekg(std::streampos(0));
   }
+}
 
-  words.clear();
-  labels.clear();
-  word_hashes.clear();
-  int32_t ntokens = 0;
+int32_t Dictionary::getLine(std::istream& in,
+                            std::vector<int32_t>& words,
+                            std::minstd_rand& rng) const {
+  std::uniform_real_distribution<> uniform(0, 1);
   std::string token;
+  int32_t ntokens = 0;
+
+  reset(in);
+  words.clear();
   while (readWord(in, token)) {
-    uint32_t h = hash(token);
-    int32_t wid = getId(token, h);
-    if (wid < 0) {
-      entry_type type = getType(token);
-      if (type == entry_type::word) word_hashes.push_back(h);
-      continue;
-    }
-    entry_type type = getType(wid);
+    int32_t h = find(token);
+    int32_t wid = word2int_[h];
+    if (wid < 0) continue;
+
     ntokens++;
-    if (type == entry_type::word && !discard(wid, uniform(rng))) {
+    if (getType(wid) == entry_type::word && !discard(wid, uniform(rng))) {
       words.push_back(wid);
-      word_hashes.push_back(hash(token));
-    }
-    if (type == entry_type::label) {
-      labels.push_back(wid - nwords_);
     }
-    if (token == EOS) break;
-    if (ntokens > MAX_LINE_SIZE && args_->model != model_name::sup) break;
+    if (ntokens > MAX_LINE_SIZE || token == EOS) break;
   }
   return ntokens;
 }
 
-
 int32_t Dictionary::getLine(std::istream& in,
                             std::vector<int32_t>& words,
                             std::vector<int32_t>& labels,
                             std::minstd_rand& rng) const {
   std::vector<int32_t> word_hashes;
-  int32_t ntokens = getLine(in, words, word_hashes, labels, rng);
-  if (args_->model == model_name::sup ) {
-    addWordNgrams(words, word_hashes, args_->wordNgrams);
+  std::string token;
+  int32_t ntokens = 0;
+
+  reset(in);
+  words.clear();
+  labels.clear();
+  while (readWord(in, token)) {
+    uint32_t h = hash(token);
+    int32_t wid = getId(token, h);
+    entry_type type = wid < 0 ? getType(token) : getType(wid);
+
+    ntokens++;
+    if (type == entry_type::word) {
+      addSubwords(words, token, wid);
+      word_hashes.push_back(h);
+    } else if (type == entry_type::label && wid >= 0) {
+      labels.push_back(wid - nwords_);
+    }
+    if (token == EOS) break;
   }
+  addWordNgrams(words, word_hashes, args_->wordNgrams);
   return ntokens;
 }
 
+void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
+  if (pruneidx_size_ == 0 || id < 0) return;
+  if (pruneidx_size_ > 0) {
+    if (pruneidx_.count(id)) {
+      id = pruneidx_.at(id);
+    } else {
+      return;
+    }
+  }
+  hashes.push_back(nwords_ + id);
+}
+
 std::string Dictionary::getLabel(int32_t lid) const {
   assert(lid >= 0);
   assert(lid < nlabels_);
@@ -440,6 +469,7 @@ void Dictionary::prune(std::vector<int32_t>& idx) {
   nwords_ = words.size();
   size_ = nwords_ +  nlabels_;
   words_.erase(words_.begin() + size_, words_.end());
+  initNgrams();
 }
 
 }

+ 5 - 2
src/dictionary.h

@@ -42,6 +42,9 @@ class Dictionary {
     int32_t find(const std::string&, uint32_t h) const;
     void initTableDiscard();
     void initNgrams();
+    void reset(std::istream&) const;
+    void pushHash(std::vector<int32_t>&, int32_t) const;
+    void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;
 
     std::shared_ptr<Args> args_;
     std::vector<int32_t> word2int_;
@@ -95,10 +98,10 @@ class Dictionary {
     void save(std::ostream&) const;
     void load(std::istream&);
     std::vector<int64_t> getCounts(entry_type) const;
-    int32_t getLine(std::istream&, std::vector<int32_t>&, std::vector<int32_t>&,
-                    std::vector<int32_t>&, std::minstd_rand&) const;
     int32_t getLine(std::istream&, std::vector<int32_t>&,
                     std::vector<int32_t>&, std::minstd_rand&) const;
+    int32_t getLine(std::istream&, std::vector<int32_t>&,
+                    std::minstd_rand&) const;
     void threshold(int64_t, int64_t);
     void prune(std::vector<int32_t>&);
 };

+ 11 - 4
src/fasttext.cc

@@ -83,13 +83,12 @@ void FastText::saveOutput() {
 
 bool FastText::checkModel(std::istream& in) {
   int32_t magic;
-  int32_t version;
   in.read((char*)&(magic), sizeof(int32_t));
   if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) {
     return false;
   }
   in.read((char*)&(version), sizeof(int32_t));
-  if (version != FASTTEXT_VERSION) {
+  if (version > FASTTEXT_VERSION) {
     return false;
   }
   return true;
@@ -157,7 +156,10 @@ void FastText::loadModel(std::istream& in) {
   qinput_ = std::make_shared<QMatrix>();
   qoutput_ = std::make_shared<QMatrix>();
   args_->load(in);
-
+  if (version == 11 && args_->model == model_name::sup) {
+    // backward compatibility: old supervised models do not use char ngrams.
+    args_->maxn = 0;
+  }
   dict_->load(in);
 
   bool quant_input;
@@ -246,6 +248,7 @@ void FastText::quantize(std::shared_ptr<Args> qargs) {
       args_->verbose = qargs->verbose;
       start = clock();
       tokenCount = 0;
+      start = clock();
       std::vector<std::thread> threads;
       for (int32_t i = 0; i < args_->thread; i++) {
         threads.push_back(std::thread([=]() { trainThread(i); }));
@@ -337,6 +340,7 @@ void FastText::predict(std::istream& in, int32_t k,
   std::vector<int32_t> words, labels;
   predictions.clear();
   dict_->getLine(in, words, labels, model_->rng);
+  predictions.clear();
   if (words.empty()) return;
   Vector hidden(args_->dim);
   Vector output(dict_->nlabels());
@@ -350,6 +354,7 @@ void FastText::predict(std::istream& in, int32_t k,
 void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
   std::vector<std::pair<real,std::string>> predictions;
   while (in.peek() != EOF) {
+    predictions.clear();
     predict(in, k, predictions);
     if (predictions.empty()) {
       std::cout << std::endl;
@@ -556,12 +561,14 @@ void FastText::trainThread(int32_t threadId) {
   while (tokenCount < args_->epoch * ntokens) {
     real progress = real(tokenCount) / (args_->epoch * ntokens);
     real lr = args_->lr * (1.0 - progress);
-    localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
     if (args_->model == model_name::sup) {
+      localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
       supervised(model, lr, line, labels);
     } else if (args_->model == model_name::cbow) {
+      localTokenCount += dict_->getLine(ifs, line, model.rng);
       cbow(model, lr, line);
     } else if (args_->model == model_name::sg) {
+      localTokenCount += dict_->getLine(ifs, line, model.rng);
       skipgram(model, lr, line);
     }
     if (localTokenCount > args_->lrUpdateRate) {

+ 2 - 1
src/fasttext.h

@@ -10,7 +10,7 @@
 #ifndef FASTTEXT_FASTTEXT_H
 #define FASTTEXT_FASTTEXT_H
 
-#define FASTTEXT_VERSION 11 /* Version 1a */
+#define FASTTEXT_VERSION 12 /* Version 1b */
 #define FASTTEXT_FILEFORMAT_MAGIC_INT32 793712314
 
 #include <time.h>
@@ -49,6 +49,7 @@ class FastText {
     bool checkModel(std::istream&);
 
     bool quant_;
+    int32_t version;
 
   public:
     FastText();