Ver Fonte

Quantization

Summary: Add the quantization described in the paper "FastText.zip: Compressing text classification models"

Reviewed By: mdouze

Differential Revision: D4445021

fbshipit-source-id: 12b9d6b6c358dc8232c03988273312dc32e907ff
Armand Joulin há 8 anos atrás
pai commit fbc4214689
19 ficheiros alterados com 999 adições e 89 exclusões
  1. 7 1
      Makefile
  2. 40 0
      quantization-example.sh
  3. 26 5
      src/args.cc
  4. 6 0
      src/args.h
  5. 168 36
      src/dictionary.cc
  6. 13 1
      src/dictionary.h
  7. 111 9
      src/fasttext.cc
  8. 15 0
      src/fasttext.h
  9. 50 15
      src/main.cc
  10. 50 8
      src/matrix.cc
  11. 11 1
      src/matrix.h
  12. 29 6
      src/model.cc
  13. 5 1
      src/model.h
  14. 211 0
      src/productquantizer.cc
  15. 67 0
      src/productquantizer.h
  16. 110 0
      src/qmatrix.cc
  17. 60 0
      src/qmatrix.h
  18. 17 6
      src/vector.cc
  19. 3 0
      src/vector.h

+ 7 - 1
Makefile

@@ -9,7 +9,7 @@
 
 CXX = c++
 CXXFLAGS = -pthread -std=c++0x
-OBJS = args.o dictionary.o matrix.o vector.o model.o utils.o fasttext.o
+OBJS = args.o dictionary.o productquantizer.o matrix.o qmatrix.o vector.o model.o utils.o fasttext.o
 INCLUDES = -I.
 
 opt: CXXFLAGS += -O3 -funroll-loops
@@ -24,9 +24,15 @@ args.o: src/args.cc src/args.h
 dictionary.o: src/dictionary.cc src/dictionary.h src/args.h
 	$(CXX) $(CXXFLAGS) -c src/dictionary.cc
 
+productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h
+	$(CXX) $(CXXFLAGS) -c src/productquantizer.cc
+
 matrix.o: src/matrix.cc src/matrix.h src/utils.h
 	$(CXX) $(CXXFLAGS) -c src/matrix.cc
 
+qmatrix.o: src/qmatrix.cc src/qmatrix.h src/utils.h
+	$(CXX) $(CXXFLAGS) -c src/qmatrix.cc
+
 vector.o: src/vector.cc src/vector.h src/utils.h
 	$(CXX) $(CXXFLAGS) -c src/vector.cc
 

+ 40 - 0
quantization-example.sh

@@ -0,0 +1,40 @@
+myshuf() {
+  perl -MList::Util=shuffle -e 'print shuffle(<>);' "$@";
+}
+
+normalize_text() {
+  tr '[:upper:]' '[:lower:]' | sed -e 's/^/__label__/g' | \
+    sed -e "s/'/ ' /g" -e 's/"//g' -e 's/\./ \. /g' -e 's/<br \/>/ /g' \
+        -e 's/,/ , /g' -e 's/(/ ( /g' -e 's/)/ ) /g' -e 's/\!/ \! /g' \
+        -e 's/\?/ \? /g' -e 's/\;/ /g' -e 's/\:/ /g' | tr -s " " | myshuf
+}
+
+RESULTDIR=result
+DATADIR=data
+
+mkdir -p "${RESULTDIR}"
+mkdir -p "${DATADIR}"
+
+if [ ! -f "${DATADIR}/dbpedia.train" ]
+then
+  wget -c "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz" -O "${DATADIR}/dbpedia_csv.tar.gz"
+  tar -xzvf "${DATADIR}/dbpedia_csv.tar.gz" -C "${DATADIR}"
+  cat "${DATADIR}/dbpedia_csv/train.csv" | normalize_text > "${DATADIR}/dbpedia.train"
+  cat "${DATADIR}/dbpedia_csv/test.csv" | normalize_text > "${DATADIR}/dbpedia.test"
+fi
+
+make
+
+echo "Training..."
+./fasttext supervised -input "${DATADIR}/dbpedia.train" -output "${RESULTDIR}/dbpedia" -dim 10 -lr 0.1 -wordNgrams 2 -minCount 1 -bucket 10000000 -epoch 5 -thread 4
+
+echo "Quantizing..."
+./fasttext quantize -output "${RESULTDIR}/dbpedia" -input "${DATADIR}/dbpedia.train" -qnorm -retrain -epoch 1 -cutoff 100000
+
+echo "Testing original model..."
+./fasttext test "${RESULTDIR}/dbpedia.bin" "${DATADIR}/dbpedia.test"
+echo "Testing quantized model..."
+./fasttext test "${RESULTDIR}/dbpedia.ftz" "${DATADIR}/dbpedia.test" 1 1
+
+ls -lrh "${RESULTDIR}/dbpedia.bin" | awk  '{print "Size of the original model:\t",$5;}'
+ls -lrh "${RESULTDIR}/dbpedia.ftz" | awk  '{print "Size of the quantized model:\t",$5;}'

+ 26 - 5
src/args.cc

@@ -37,6 +37,12 @@ Args::Args() {
   verbose = 2;
   pretrainedVectors = "";
   saveOutput = 0;
+
+  qout = false;
+  retrain = false;
+  qnorm = false;
+  cutoff = 0;
+  dsub = 2;
 }
 
 void Args::parseArgs(int argc, char** argv) {
@@ -116,6 +122,16 @@ void Args::parseArgs(int argc, char** argv) {
       pretrainedVectors = std::string(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-saveOutput") == 0) {
       saveOutput = atoi(argv[ai + 1]);
+    } else if (strcmp(argv[ai], "-qnorm") == 0) {
+      qnorm = true; ai--;
+    } else if (strcmp(argv[ai], "-retrain") == 0) {
+      retrain = true; ai--;
+    } else if (strcmp(argv[ai], "-qout") == 0) {
+      qout = true; ai--;
+    } else if (strcmp(argv[ai], "-cutoff") == 0) {
+    cutoff = atoi(argv[ai + 1]);
+    } else if (strcmp(argv[ai], "-dsub") == 0) {
+      dsub = atoi(argv[ai + 1]);
     } else {
       std::cout << "Unknown argument: " << argv[ai] << std::endl;
       printHelp();
@@ -138,11 +154,10 @@ void Args::printHelp() {
   if (loss == loss_name::hs) lname = "hs";
   if (loss == loss_name::softmax) lname = "softmax";
   std::cout
-    << "\n"
-    << "The following arguments are mandatory:\n"
+    << "\nThe following arguments are mandatory:\n"
     << "  -input              training file path\n"
-    << "  -output             output file path\n\n"
-    << "The following arguments are optional:\n"
+    << "  -output             output file path\n"
+    << "\nThe following arguments are optional:\n"
     << "  -lr                 learning rate [" << lr << "]\n"
     << "  -lrUpdateRate       change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
     << "  -dim                size of word vectors [" << dim << "]\n"
@@ -160,8 +175,14 @@ void Args::printHelp() {
     << "  -t                  sampling threshold [" << t << "]\n"
     << "  -label              labels prefix [" << label << "]\n"
     << "  -verbose            verbosity level [" << verbose << "]\n"
-    << "  -pretrainedVectors  pretrained word vectors for supervised learning []"
+    << "  -pretrainedVectors  pretrained word vectors for supervised learning ["<< pretrainedVectors <<"]\n"
     << "  -saveOutput         whether output params should be saved [" << saveOutput << "]\n"
+    << "\nThe following arguments for quantization are optional:\n"
+    << "  -cutoff             number of words and ngrams to retain [" << cutoff << "]\n"
+    << "  -retrain            finetune embeddings if a cutoff is applied [" << retrain << "]\n"
+    << "  -qnorm              quantizing the norm separately [" << qnorm << "]\n"
+    << "  -qout               quantizing the classifier [" << qout << "]\n"
+    << "  -dsub               size of each sub-vector [" << dsub << "]\n"
     << std::endl;
 }
 

+ 6 - 0
src/args.h

@@ -46,6 +46,12 @@ class Args {
     std::string pretrainedVectors;
     int saveOutput;
 
+    bool qout;
+    bool retrain;
+    bool qnorm;
+    size_t cutoff;
+    size_t dsub;
+
     void parseArgs(int, char**);
     void printHelp();
     void save(std::ostream&);

+ 168 - 36
src/dictionary.cc

@@ -12,9 +12,9 @@
 #include <assert.h>
 
 #include <iostream>
+#include <fstream>
 #include <algorithm>
 #include <iterator>
-#include <unordered_map>
 
 namespace fasttext {
 
@@ -22,17 +22,9 @@ const std::string Dictionary::EOS = "</s>";
 const std::string Dictionary::BOW = "<";
 const std::string Dictionary::EOW = ">";
 
-Dictionary::Dictionary(std::shared_ptr<Args> args) {
-  args_ = args;
-  size_ = 0;
-  nwords_ = 0;
-  nlabels_ = 0;
-  ntokens_ = 0;
-  word2int_.resize(MAX_VOCAB_SIZE);
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
-}
+Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
+  word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
+  ntokens_(0), quant_(false) {}
 
 int32_t Dictionary::find(const std::string& w) const {
   int32_t h = hash(w) % MAX_VOCAB_SIZE;
@@ -246,9 +238,7 @@ void Dictionary::threshold(int64_t t, int64_t tl) {
   size_ = 0;
   nwords_ = 0;
   nlabels_ = 0;
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
+  std::fill(word2int_.begin(), word2int_.end(), -1);
   for (auto it = words_.begin(); it != words_.end(); ++it) {
     int32_t h = find(it->word);
     word2int_[h] = size_++;
@@ -273,43 +263,88 @@ std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
   return counts;
 }
 
-void Dictionary::addNgrams(std::vector<int32_t>& line, int32_t n) const {
-  int32_t line_size = line.size();
-  for (int32_t i = 0; i < line_size; i++) {
-    uint64_t h = line[i];
-    for (int32_t j = i + 1; j < line_size && j < i + n; j++) {
-      h = h * 116049371 + line[j];
-      line.push_back(nwords_ + (h % args_->bucket));
+void Dictionary::addNgrams(std::vector<int32_t>& line,
+                           const std::vector<int32_t>& hashes,
+                           int32_t n) const {
+  for (int32_t i = 0; i < hashes.size(); i++) {
+    uint64_t h = hashes[i];
+    for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
+      h = h * 116049371 + hashes[j];
+      int64_t id = h % args_->bucket;
+      if (quantidx_.size() != 0) {
+        if (quantidx_.find(id) != quantidx_.end()) {
+          id = quantidx_.at(id);
+        } else {continue;}
+      }
+      line.push_back(nwords_ + id);
     }
   }
 }
 
+int32_t Dictionary::getLine(std::istream& in,
+                            std::vector<std::string>& tokens) const {
+  if (in.eof()) {
+    in.clear();
+    in.seekg(std::streampos(0));
+  }
+  tokens.clear();
+  std::string token;
+  while (readWord(in, token)) {
+    tokens.push_back(token);
+    if (token == EOS) break;
+    if (tokens.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break;
+  }
+  return tokens.size();
+}
+
 int32_t Dictionary::getLine(std::istream& in,
                             std::vector<int32_t>& words,
+                            std::vector<int32_t>& word_hashes,
                             std::vector<int32_t>& labels,
                             std::minstd_rand& rng) const {
   std::uniform_real_distribution<> uniform(0, 1);
-  std::string token;
-  int32_t ntokens = 0;
+  std::vector<std::string> tokens;
+  getLine(in, tokens);
   words.clear();
   labels.clear();
-  if (in.eof()) {
-    in.clear();
-    in.seekg(std::streampos(0));
-  }
-  while (readWord(in, token)) {
-    int32_t wid = getId(token);
-    if (wid < 0) continue;
+  word_hashes.clear();
+  int32_t ntokens = 0;
+  for(auto it = tokens.cbegin(); it != tokens.cend(); ++it) {
+    int32_t h = find(*it);
+    int32_t wid = word2int_[h];
+    if (wid < 0) {
+      word_hashes.push_back(hash(*it));
+      continue;
+    }
     entry_type type = getType(wid);
     ntokens++;
     if (type == entry_type::word && !discard(wid, uniform(rng))) {
       words.push_back(wid);
+      word_hashes.push_back(hash(*it));
     }
     if (type == entry_type::label) {
       labels.push_back(wid - nwords_);
     }
-    if (words.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break;
-    if (token == EOS) break;
+  }
+  return ntokens;
+}
+
+
+int32_t Dictionary::getLine(std::istream& in,
+                            std::vector<int32_t>& words,
+                            std::vector<int32_t>& labels,
+                            std::minstd_rand& rng) const {
+  std::vector<int32_t> word_hashes;
+  int32_t ntokens = getLine(in, words, word_hashes, labels, rng);
+  if (args_->model == model_name::sup ) {
+    if (quant_) {
+      addNgrams(words, word_hashes, args_->wordNgrams);
+    }
+    else {
+      std::vector<int32_t> ngrams;
+      addNgrams(ngrams, words, args_->wordNgrams);
+      words.insert(words.end(), ngrams.begin(), ngrams.end());
+    }
   }
   return ntokens;
 }
@@ -332,13 +367,19 @@ void Dictionary::save(std::ostream& out) const {
     out.write((char*) &(e.count), sizeof(int64_t));
     out.write((char*) &(e.type), sizeof(entry_type));
   }
+  if (quant_) {
+    auto ss = quantidx_.size();
+    out.write((char*) &(ss), sizeof(ss));
+    for (auto it = quantidx_.begin(); it != quantidx_.end(); it++) {
+      out.write((char*)&(it->first), sizeof(int32_t));
+      out.write((char*)&(it->second), sizeof(int32_t));
+    }
+  }
 }
 
 void Dictionary::load(std::istream& in) {
   words_.clear();
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
+  std::fill(word2int_.begin(), word2int_.end(), -1);
   in.read((char*) &size_, sizeof(int32_t));
   in.read((char*) &nwords_, sizeof(int32_t));
   in.read((char*) &nlabels_, sizeof(int32_t));
@@ -354,8 +395,99 @@ void Dictionary::load(std::istream& in) {
     words_.push_back(e);
     word2int_[find(e.word)] = i;
   }
+  if (quant_) {
+    std::size_t size;
+    in.read((char*) &size, sizeof(std::size_t));
+    for (auto i = 0; i < size; i++) {
+      int32_t k, v;
+      in.read((char*)&k, sizeof(int32_t));
+      in.read((char*)&v, sizeof(int32_t));
+      quantidx_[k] = v;
+    }
+  }
   initTableDiscard();
   initNgrams();
 }
 
+void Dictionary::prune(std::vector<int32_t>& idx) {
+  std::vector<int32_t> words, ngrams;
+  for (auto it = idx.cbegin(); it != idx.cend(); ++it) {
+    if (*it < nwords_) {words.push_back(*it);}
+    else {ngrams.push_back(*it);}
+  }
+  std::sort(words.begin(), words.end());
+  idx = words;
+
+  if (ngrams.size() != 0) {
+    convertNgrams(ngrams);
+    idx.insert(idx.end(), ngrams.begin(), ngrams.end());
+  }
+
+  std::fill(word2int_.begin(), word2int_.end(), -1);
+
+  int32_t j = 0;
+  for (int32_t i = 0; i < words_.size(); i++) {
+    if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) {
+      words_[j] = words_[i];
+      word2int_[find(words_[j].word)] = j;
+      j++;
+    }
+  }
+  nwords_ = words.size();
+  size_ = nwords_ +  nlabels_;
+  words_.erase(words_.begin() + size_, words_.end());
+}
+
+void Dictionary::convertNgrams(std::vector<int32_t>& ngramidx) {
+
+  std::ifstream in(args_->input);
+  if (!in.is_open()) {
+    std::cerr << "Input file cannot be opened!" << std::endl;
+    exit(EXIT_FAILURE);
+  }
+
+  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> convertMap;
+  for (auto it = ngramidx.cbegin(); it != ngramidx.cend(); ++it) {
+   convertMap[*it] = std::unordered_map<int32_t, int32_t>();
+  }
+  std::vector<std::string> tokens;
+  std::vector<int32_t> word_hashes, words, labels, oldhashes, newhashes;
+  std::minstd_rand rng;
+  while (in.peek() != EOF) {
+    getLine(in, words, word_hashes, labels, rng);
+    if (words.empty()) {continue;}
+    oldhashes.clear(); newhashes.clear();
+    addNgrams(oldhashes, words, args_->wordNgrams);
+    addNgrams(newhashes, word_hashes, args_->wordNgrams);
+    assert(newhashes.size() == oldhashes.size());
+    for (int32_t i = 0; i < oldhashes.size(); i++) {
+      auto oh = oldhashes[i];
+      if (convertMap.find(oh) == convertMap.end()) {continue;}
+      convertMap[oh][newhashes[i]]++;
+    }
+  }
+  in.close();
+
+  quantidx_.clear();
+  std::vector<int32_t> remaining_indices;
+  int32_t size = 0;
+  for (auto it = ngramidx.begin(); it != ngramidx.end(); ++it) {
+    auto cm = convertMap[*it];
+    int32_t newhash; int32_t count = -1;
+    for (auto nit = cm.cbegin(); nit != cm.cend(); ++nit) {
+      if (count < nit->second) {
+        newhash = nit->first;
+        count = nit->second;
+      }
+    }
+    newhash -= nwords_;
+    if (quantidx_.find(newhash) == quantidx_.end()) {
+      quantidx_[newhash] = size;
+      size++;
+      remaining_indices.push_back(*it);
+    }
+  }
+  ngramidx = remaining_indices;
+}
+
 }

+ 13 - 1
src/dictionary.h

@@ -16,6 +16,7 @@
 #include <ostream>
 #include <random>
 #include <memory>
+#include <unordered_map>
 
 #include "args.h"
 #include "real.h"
@@ -44,6 +45,9 @@ class Dictionary {
     std::shared_ptr<Args> args_;
     std::vector<int32_t> word2int_;
     std::vector<entry> words_;
+
+    std::unordered_map<int32_t, int32_t> quantidx_;
+
     std::vector<real> pdiscard_;
     int32_t size_;
     int32_t nwords_;
@@ -51,6 +55,8 @@ class Dictionary {
     int64_t ntokens_;
 
   public:
+    bool quant_;
+
     static const std::string EOS;
     static const std::string BOW;
     static const std::string EOW;
@@ -78,10 +84,16 @@ class Dictionary {
     void save(std::ostream&) const;
     void load(std::istream&);
     std::vector<int64_t> getCounts(entry_type) const;
-    void addNgrams(std::vector<int32_t>&, int32_t) const;
+    void addNgrams(std::vector<int32_t>&, const std::vector<int32_t>&,
+                   int32_t) const;
+    int32_t getLine(std::istream&, std::vector<std::string>&) const;
+    int32_t getLine(std::istream&, std::vector<int32_t>&, std::vector<int32_t>&,
+                    std::vector<int32_t>&, std::minstd_rand&) const;
     int32_t getLine(std::istream&, std::vector<int32_t>&,
                     std::vector<int32_t>&, std::minstd_rand&) const;
     void threshold(int64_t, int64_t);
+    void prune(std::vector<int32_t>&);
+    void convertNgrams(std::vector<int32_t>&);
 };
 
 }

+ 111 - 9
src/fasttext.cc

@@ -18,8 +18,11 @@
 #include <vector>
 #include <algorithm>
 
+
 namespace fasttext {
 
+FastText::FastText() : quant_(false) {}
+
 void FastText::getVector(Vector& vec, const std::string& word) {
   const std::vector<int32_t>& ngrams = dict_->getNgrams(word);
   vec.zero();
@@ -65,15 +68,32 @@ void FastText::saveOutput() {
 }
 
 void FastText::saveModel() {
-  std::ofstream ofs(args_->output + ".bin", std::ofstream::binary);
+  std::string fn(args_->output);
+  if (quant_) {
+    fn += ".ftz";
+    dict_->quant_ = true;
+  } else {
+    fn += ".bin";
+  }
+  std::ofstream ofs(fn, std::ofstream::binary);
   if (!ofs.is_open()) {
     std::cerr << "Model file cannot be opened for saving!" << std::endl;
     exit(EXIT_FAILURE);
   }
   args_->save(ofs);
   dict_->save(ofs);
-  input_->save(ofs);
-  output_->save(ofs);
+  if (quant_) {
+    ofs.write((char*) &(args_->qout), sizeof(bool));
+    qinput_->save(ofs);
+  } else {
+    input_->save(ofs);
+  }
+  if (quant_ && args_->qout) {
+    qoutput_->save(ofs);
+  }
+  else {
+    output_->save(ofs);
+  }
   ofs.close();
 }
 
@@ -92,11 +112,30 @@ void FastText::loadModel(std::istream& in) {
   dict_ = std::make_shared<Dictionary>(args_);
   input_ = std::make_shared<Matrix>();
   output_ = std::make_shared<Matrix>();
+  qinput_ = std::make_shared<QMatrix>();
+  qoutput_ = std::make_shared<QMatrix>();
   args_->load(in);
+
+  dict_->quant_ = quant_;
   dict_->load(in);
-  input_->load(in);
-  output_->load(in);
+
+  if (quant_) {
+    in.read((char*) &(args_->qout), sizeof(bool));
+    qinput_->load(in);
+  } else {
+    input_->load(in);
+  }
+
+  if (quant_ && args_->qout) {
+    qoutput_->load(in);
+  } else {
+    output_->load(in);
+  }
+
   model_ = std::make_shared<Model>(input_, output_, args_, 0);
+  model_->quant_ = quant_;
+  model_->setQuantizePointer(qinput_, qoutput_, args_->qout);
+
   if (args_->model == model_name::sup) {
     model_->setTargetCounts(dict_->getCounts(entry_type::label));
   } else {
@@ -120,6 +159,73 @@ void FastText::printInfo(real progress, real loss) {
   std::cout << std::flush;
 }
 
+std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
+  Vector norms(input_->m_);
+  input_->l2NormRow(norms);
+  std::vector<int32_t> idx(input_->m_, 0);
+  std::iota(idx.begin(), idx.end(), 0);
+  auto eosid = dict_->getId(Dictionary::EOS);
+  std::sort(idx.begin(), idx.end(),
+      [&norms, eosid] (size_t i1, size_t i2) {
+      return eosid ==i1 || (eosid != i2 && norms[i1] > norms[i2]);
+      });
+  idx.erase(idx.begin() + cutoff, idx.end());
+  return idx;
+}
+
+void FastText::quantize(std::shared_ptr<Args> qargs) {
+  if (qargs->output.empty()) {
+      std::cout<<"No model provided!"<<std::endl; exit(1);
+  }
+  loadModel(qargs->output + ".bin");
+
+  args_->input = qargs->input;
+  args_->qout = qargs->qout;
+  args_->output = qargs->output;
+
+
+  if (qargs->cutoff > 0 && qargs->cutoff < input_->m_) {
+    auto idx = selectEmbeddings(qargs->cutoff);
+    dict_->prune(idx);
+    dict_->quant_ = true;
+    std::shared_ptr<Matrix> ninput =
+      std::make_shared<Matrix> (idx.size(), args_->dim);
+    for (auto i = 0; i < idx.size(); i++) {
+      for (auto j = 0; j < args_->dim; j++) {
+        ninput->at(i,j) = input_->at(idx[i], j);
+      }
+    }
+    input_ = ninput;
+    if (qargs->retrain) {
+      args_->epoch = qargs->epoch;
+      args_->lr = qargs->lr;
+      args_->thread = qargs->thread;
+      args_->verbose = qargs->verbose;
+      tokenCount = 0;
+      std::vector<std::thread> threads;
+      for (int32_t i = 0; i < args_->thread; i++) {
+        threads.push_back(std::thread([=]() { trainThread(i); }));
+      }
+      for (auto it = threads.begin(); it != threads.end(); ++it) {
+        it->join();
+      }
+    }
+  }
+
+  qinput_ = std::make_shared<QMatrix>(*input_, qargs->dsub, qargs->qnorm);
+
+  if (args_->qout) {
+    qoutput_ = std::make_shared<QMatrix>(*output_, 2, qargs->qnorm);
+  }
+
+  quant_ = true;
+  saveModel();
+}
+
+void FastText::setQuantize(bool quant) {
+  quant_ = quant;
+}
+
 void FastText::supervised(Model& model, real lr,
                           const std::vector<int32_t>& line,
                           const std::vector<int32_t>& labels) {
@@ -167,7 +273,6 @@ void FastText::test(std::istream& in, int32_t k) {
 
   while (in.peek() != EOF) {
     dict_->getLine(in, line, labels, model_->rng);
-    dict_->addNgrams(line, args_->wordNgrams);
     if (labels.size() > 0 && line.size() > 0) {
       std::vector<std::pair<real, int32_t>> modelPredictions;
       model_->predict(line, k, modelPredictions);
@@ -190,7 +295,6 @@ void FastText::predict(std::istream& in, int32_t k,
                        std::vector<std::pair<real,std::string>>& predictions) const {
   std::vector<int32_t> words, labels;
   dict_->getLine(in, words, labels, model_->rng);
-  dict_->addNgrams(words, args_->wordNgrams);
   if (words.empty()) return;
   Vector hidden(args_->dim);
   Vector output(dict_->nlabels());
@@ -251,7 +355,6 @@ void FastText::textVectors() {
   Vector vec(args_->dim);
   while (std::cin.peek() != EOF) {
     dict_->getLine(std::cin, line, labels, model_->rng);
-    dict_->addNgrams(line, args_->wordNgrams);
     vec.zero();
     for (auto it = line.cbegin(); it != line.cend(); ++it) {
       vec.addRow(*input_, *it);
@@ -290,7 +393,6 @@ void FastText::trainThread(int32_t threadId) {
     real lr = args_->lr * (1.0 - progress);
     localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
     if (args_->model == model_name::sup) {
-      dict_->addNgrams(line, args_->wordNgrams);
       supervised(model, lr, line, labels);
     } else if (args_->model == model_name::cbow) {
       cbow(model, lr, line);

+ 15 - 0
src/fasttext.h

@@ -17,6 +17,7 @@
 
 #include "matrix.h"
 #include "vector.h"
+#include "qmatrix.h"
 #include "dictionary.h"
 #include "model.h"
 #include "utils.h"
@@ -29,13 +30,23 @@ class FastText {
   private:
     std::shared_ptr<Args> args_;
     std::shared_ptr<Dictionary> dict_;
+
     std::shared_ptr<Matrix> input_;
     std::shared_ptr<Matrix> output_;
+
+    std::shared_ptr<QMatrix> qinput_;
+    std::shared_ptr<QMatrix> qoutput_;
+
     std::shared_ptr<Model> model_;
+
     std::atomic<int64_t> tokenCount;
     clock_t start;
 
+    bool quant_;
+
   public:
+    FastText();
+
     void getVector(Vector&, const std::string&);
     void saveVectors();
     void saveOutput();
@@ -44,10 +55,14 @@ class FastText {
     void loadModel(std::istream&);
     void printInfo(real, real);
 
+    void setQuantize(bool);
+
     void supervised(Model&, real, const std::vector<int32_t>&,
                     const std::vector<int32_t>&);
     void cbow(Model&, real, const std::vector<int32_t>&);
     void skipgram(Model&, real, const std::vector<int32_t>&);
+    std::vector<int32_t> selectEmbeddings(int32_t) const;
+    void quantize(std::shared_ptr<Args>);
     void test(std::istream&, int32_t);
     void predict(std::istream&, int32_t, bool);
     void predict(std::istream&, int32_t, std::vector<std::pair<real,std::string>>&) const;

+ 50 - 15
src/main.cc

@@ -16,9 +16,10 @@ using namespace fasttext;
 
 void printUsage() {
   std::cout
-    << "usage: fasttext <command> <args>\n\n"
+    << "usage: fasttext <commands> <args>\n\n"
     << "The commands supported by fasttext are:\n\n"
     << "  supervised          train a supervised classifier\n"
+    << "  quantize            quantize a model to reduce the memory usage\n"
     << "  test                evaluate a supervised classifier\n"
     << "  predict             predict most likely labels\n"
     << "  predict-prob        predict most likely labels with probabilities\n"
@@ -28,21 +29,29 @@ void printUsage() {
     << std::endl;
 }
 
+void printQuantizeUsage() {
+  std::cout
+    << "usage: fasttext quantize <args>"
+    << std::endl;
+}
+
 void printTestUsage() {
   std::cout
-    << "usage: fasttext test <model> <test-data> [<k>]\n\n"
+    << "usage: fasttext test <model> <test-data> [<k> <quant>]\n\n"
     << "  <model>      model filename\n"
     << "  <test-data>  test data filename (if -, read from stdin)\n"
     << "  <k>          (optional; 1 by default) predict top k labels\n"
+    << "  <quant>      (optional; 0 by default) used or not quantized model\n"
     << std::endl;
 }
 
 void printPredictUsage() {
   std::cout
-    << "usage: fasttext predict[-prob] <model> <test-data> [<k>]\n\n"
+    << "usage: fasttext predict[-prob] <model> <test-data> [<k> <quant>]\n\n"
     << "  <model>      model filename\n"
     << "  <test-data>  test data filename (if -, read from stdin)\n"
     << "  <k>          (optional; 1 by default) predict top k labels\n"
+    << "  <quant>      (optional; 0 by default) use or not quantized model\n"
     << std::endl;
 }
 
@@ -61,18 +70,37 @@ void printPrintNgramsUsage() {
     << std::endl;
 }
 
+void quantize(int argc, char** argv) {
+  std::shared_ptr<Args> a = std::make_shared<Args>();
+  if (argc < 3) {
+    printQuantizeUsage();
+    a->printHelp();
+    exit(EXIT_FAILURE);
+  }
+  a->parseArgs(argc, argv);
+  FastText fasttext;
+  fasttext.quantize(a);
+  exit(0);
+}
+
 void test(int argc, char** argv) {
-  int32_t k;
-  if (argc == 4) {
-    k = 1;
-  } else if (argc == 5) {
-    k = atoi(argv[4]);
-  } else {
+  if (argc < 4 || argc > 6) {
     printTestUsage();
     exit(EXIT_FAILURE);
   }
+  int32_t k = 1;
+  if (argc >= 5) {
+    k = atoi(argv[4]);
+  }
+  bool quant = false;
+  if (argc >= 6) {
+    quant = atoi(argv[5]);
+  }
+
   FastText fasttext;
+  fasttext.setQuantize(quant);
   fasttext.loadModel(std::string(argv[2]));
+
   std::string infile(argv[3]);
   if (infile == "-") {
     fasttext.test(std::cin, k);
@@ -89,17 +117,22 @@ void test(int argc, char** argv) {
 }
 
 void predict(int argc, char** argv) {
-  int32_t k;
-  if (argc == 4) {
-    k = 1;
-  } else if (argc == 5) {
-    k = atoi(argv[4]);
-  } else {
+  if (argc < 4 || argc > 6) {
     printPredictUsage();
     exit(EXIT_FAILURE);
   }
+  int32_t k = 1;
+  if (argc >= 5) {
+    k = atoi(argv[4]);
+  }
+  bool quant = false;
+  if (argc >= 6) {
+    quant = atoi(argv[5]);
+  }
+
   bool print_prob = std::string(argv[1]) == "predict-prob";
   FastText fasttext;
+  fasttext.setQuantize(quant);
   fasttext.loadModel(std::string(argv[2]));
 
   std::string infile(argv[3]);
@@ -157,6 +190,8 @@ int main(int argc, char** argv) {
     train(argc, argv);
   } else if (command == "test") {
     test(argc, argv);
+  } else if (command == "quantize") {
+    quantize(argc, argv);
   } else if (command == "print-vectors") {
     printVectors(argc, argv);
   } else if (command == "print-ngrams") {

+ 50 - 8
src/matrix.cc

@@ -65,24 +65,66 @@ void Matrix::uniform(real a) {
   }
 }
 
-void Matrix::addRow(const Vector& vec, int64_t i, real a) {
+real Matrix::dotRow(const Vector& vec, int64_t i) const {
   assert(i >= 0);
   assert(i < m_);
-  assert(vec.m_ == n_);
+  assert(vec.size() == n_);
+  real d = 0.0;
   for (int64_t j = 0; j < n_; j++) {
-    data_[i * n_ + j] += a * vec.data_[j];
+    d += at(i, j) * vec.data_[j];
   }
+  return d;
 }
 
-real Matrix::dotRow(const Vector& vec, int64_t i) {
+void Matrix::addRow(const Vector& vec, int64_t i, real a) {
   assert(i >= 0);
   assert(i < m_);
-  assert(vec.m_ == n_);
-  real d = 0.0;
+  assert(vec.size() == n_);
   for (int64_t j = 0; j < n_; j++) {
-    d += data_[i * n_ + j] * vec.data_[j];
+    data_[i * n_ + j] += a * vec.data_[j];
   }
-  return d;
+}
+
+void Matrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
+  if (ie == -1) {ie = m_;}
+  assert(ie <= nums.size());
+  for (auto i = ib; i < ie; i++) {
+    real n = nums[i-ib];
+    if (n != 0) {
+      for (auto j = 0; j < n_; j++) {
+        at(i, j) *= n;
+      }
+    }
+  }
+}
+
+void Matrix::divideRow(const Vector& denoms, int64_t ib, int64_t ie) {
+  if (ie == -1) {ie = m_;}
+  assert(ie <= denoms.size());
+  for (auto i = ib; i < ie; i++) {
+    real n = denoms[i-ib];
+    if (n != 0) {
+      for (auto j = 0; j < n_; j++) {
+        at(i, j) /= n;
+      }
+    }
+  }
+}
+
+real Matrix::l2NormRow(int64_t i) const {
+  auto norm = 0.0;
+  for (auto j = 0; j < n_; j++) {
+    const real v = at(i,j);
+    norm += v * v;
+  }
+  return std::sqrt(norm);
+}
+
+void Matrix::l2NormRow(Vector& norms) const {
+  assert(norms.size() == m_);
+    for (auto i = 0; i < m_; i++) {
+      norms[i] = l2NormRow(i);
+    }
 }
 
 void Matrix::save(std::ostream& out) {

+ 11 - 1
src/matrix.h

@@ -33,11 +33,21 @@ class Matrix {
     Matrix& operator=(const Matrix&);
     ~Matrix();
 
+    inline const real& at(int64_t i, int64_t j) const {return data_[i * n_ + j];};
+    inline real& at(int64_t i, int64_t j) {return data_[i * n_ + j];};
+
+
     void zero();
     void uniform(real);
-    real dotRow(const Vector&, int64_t);
+    real dotRow(const Vector&, int64_t) const;
     void addRow(const Vector&, int64_t, real);
 
+    void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
+    void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
+
+    real l2NormRow(int64_t i) const;
+    void l2NormRow(Vector& norms) const;
+
     void save(std::ostream&);
     void load(std::istream&);
 };

+ 29 - 6
src/model.cc

@@ -9,8 +9,8 @@
 
 #include "model.h"
 
+#include <iostream>
 #include <assert.h>
-
 #include <algorithm>
 
 namespace fasttext {
@@ -19,12 +19,12 @@ Model::Model(std::shared_ptr<Matrix> wi,
              std::shared_ptr<Matrix> wo,
              std::shared_ptr<Args> args,
              int32_t seed)
-  : hidden_(args->dim), output_(wo->m_), grad_(args->dim), rng(seed)
+  : hidden_(args->dim), output_(wo->m_),
+  grad_(args->dim), rng(seed), quant_(false)
 {
   wi_ = wi;
   wo_ = wo;
   args_ = args;
-  isz_ = wi->m_;
   osz_ = wo->m_;
   hsz_ = args->dim;
   negpos = 0;
@@ -39,6 +39,15 @@ Model::~Model() {
   delete[] t_log;
 }
 
+void Model::setQuantizePointer(std::shared_ptr<QMatrix> qwi,
+                               std::shared_ptr<QMatrix> qwo, bool qout) {
+  qwi_ = qwi;
+  qwo_ = qwo;
+  if (qout) {
+    osz_ = qwo_->m_;
+  }
+}
+
 real Model::binaryLogistic(int32_t target, bool label, real lr) {
   real score = sigmoid(wo_->dotRow(hidden_, target));
   real alpha = lr * (real(label) - score);
@@ -76,7 +85,11 @@ real Model::hierarchicalSoftmax(int32_t target, real lr) {
 }
 
 void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
-  output.mul(*wo_, hidden);
+  if (quant_ && args_->qout) {
+    output.mul(*qwo_, hidden);
+  } else {
+    output.mul(*wo_, hidden);
+  }
   real max = output[0], z = 0.0;
   for (int32_t i = 0; i < osz_; i++) {
     max = std::max(output[i], max);
@@ -110,7 +123,11 @@ void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) con
   assert(hidden.size() == hsz_);
   hidden.zero();
   for (auto it = input.cbegin(); it != input.cend(); ++it) {
-    hidden.addRow(*wi_, *it);
+    if(quant_) {
+      hidden.addRow(*qwi_, *it);
+    } else {
+      hidden.addRow(*wi_, *it);
+    }
   }
   hidden.mul(1.0 / input.size());
 }
@@ -172,7 +189,13 @@ void Model::dfs(int32_t k, int32_t node, real score,
     return;
   }
 
-  real f = sigmoid(wo_->dotRow(hidden, node - osz_));
+  real f;
+  if (quant_ && args_->qout) {
+    f= sigmoid(qwo_->dotRow(hidden, node - osz_));
+  } else {
+    f= sigmoid(wo_->dotRow(hidden, node - osz_));
+  }
+
   dfs(k, tree[node].left, score + log(1.0 - f), heap, hidden);
   dfs(k, tree[node].right, score + log(f), heap, hidden);
 }

+ 5 - 1
src/model.h

@@ -18,6 +18,7 @@
 #include "args.h"
 #include "matrix.h"
 #include "vector.h"
+#include "qmatrix.h"
 #include "real.h"
 
 #define SIGMOID_TABLE_SIZE 512
@@ -38,12 +39,13 @@ class Model {
   private:
     std::shared_ptr<Matrix> wi_;
     std::shared_ptr<Matrix> wo_;
+    std::shared_ptr<QMatrix> qwi_;
+    std::shared_ptr<QMatrix> qwo_;
     std::shared_ptr<Args> args_;
     Vector hidden_;
     Vector output_;
     Vector grad_;
     int32_t hsz_;
-    int32_t isz_;
     int32_t osz_;
     real loss_;
     int64_t nexamples_;
@@ -99,6 +101,8 @@ class Model {
     real log(real) const;
 
     std::minstd_rand rng;
+    bool quant_;
+    void setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);
 };
 
 }

+ 211 - 0
src/productquantizer.cc

@@ -0,0 +1,211 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#include "productquantizer.h"
+
+#include <algorithm>
+#include <iostream>
+
+namespace fasttext {
+
+real distL2(const real* x, const real* y, int32_t d) {
+  real dist = 0;
+  for (auto i = 0; i < d; i++) {
+    auto tmp = x[i] - y[i];
+    dist += tmp * tmp;
+  }
+  return dist;
+}
+
+ProductQuantizer::ProductQuantizer(int32_t dim, int32_t dsub): dim_(dim),
+  nsubq_(dim / dsub), dsub_(dsub), centroids_(dim * ksub_), rng(seed_) {
+  lastdsub_ = dim_ % dsub;
+  if (lastdsub_ == 0) {lastdsub_ = dsub_;}
+  else {nsubq_++;}
+}
+
+const real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) const {
+  if (m == nsubq_ - 1) {return &centroids_[m * ksub_ * dsub_ + i * lastdsub_];}
+  return &centroids_[(m * ksub_ + i) * dsub_];
+}
+
+real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) {
+  if (m == nsubq_ - 1) {return &centroids_[m * ksub_ * dsub_ + i * lastdsub_];}
+  return &centroids_[(m * ksub_ + i) * dsub_];
+}
+
+real ProductQuantizer::assign_centroid(const real * x, const real* c0,
+                                       uint8_t* code, int32_t d) const {
+  const real* c = c0;
+  real dis = distL2(x, c, d);
+  code[0] = 0;
+  for (auto j = 1; j < ksub_; j++) {
+    c += d;
+    real disij = distL2(x, c, d);
+    if (disij < dis) {
+      code[0] = (uint8_t) j;
+      dis = disij;
+    }
+  }
+  return dis;
+}
+
+void ProductQuantizer::Estep(const real* x, const real* centroids,
+                             uint8_t* codes, int32_t d,
+                             int32_t n) const {
+  for (auto i = 0; i < n; i++) {
+    assign_centroid(x + i * d, centroids, codes + i, d);
+  }
+}
+
+void ProductQuantizer::MStep(const real* x0, real* centroids,
+                             const uint8_t* codes,
+                             int32_t d, int32_t n) {
+  std::vector<int32_t> nelts(ksub_, 0);
+  memset(centroids, 0, sizeof(real) * d * ksub_);
+  const real* x = x0;
+  for (auto i = 0; i < n; i++) {
+    auto k = codes[i];
+    real* c = centroids + k * d;
+    for (auto j = 0; j < d; j++) {
+      c[j] += x[j];
+    }
+    nelts[k]++;
+    x += d;
+  }
+
+  real* c = centroids;
+  for (auto k = 0; k < ksub_; k++) {
+    real z = (real) nelts[k];
+    if (z != 0) {
+      for (auto j = 0; j < d; j++) {
+        c[j] /= z;
+      }
+    }
+    c += d;
+  }
+
+  std::uniform_real_distribution<> runiform(0,1);
+  for (auto k = 0; k < ksub_; k++) {
+    if (nelts[k] == 0) {
+      int32_t m = 0;
+      while (runiform(rng) * (n - ksub_) >= nelts[m] - 1) {
+        m = (m + 1) % ksub_;
+      }
+      memcpy(centroids + k * d, centroids + m * d, sizeof(real) * d);
+      for (auto j = 0; j < d; j++) {
+        int32_t sign = (j % 2) * 2 - 1;
+        centroids[k * d + j] += sign * eps_;
+        centroids[m * d + j] -= sign * eps_;
+      }
+      nelts[k] = nelts[m] / 2;
+      nelts[m] -= nelts[k];
+    }
+  }
+}
+
+void ProductQuantizer::kmeans(const real *x, real* c, int32_t n, int32_t d) {
+  std::vector<int32_t> perm(n,0);
+  std::iota(perm.begin(), perm.end(), 0);
+  std::shuffle(perm.begin(), perm.end(), rng);
+  for (auto i = 0; i < ksub_; i++) {
+    memcpy (&c[i * d], x + perm[i] * d, d * sizeof(real));
+  }
+  uint8_t* codes = new uint8_t[n];
+  for (auto i = 0; i < niter_; i++) {
+    Estep(x, c, codes, d, n);
+    MStep(x, c, codes, d, n);
+  }
+  delete [] codes;
+}
+
+void ProductQuantizer::train(int32_t n, const real * x) {
+  if (n < ksub_) {
+    std::cerr << "Matrix too small for quantization, must have at least " << ksub_ << " rows" << std::endl;
+    exit(1);
+  }
+  std::vector<int32_t> perm(n, 0);
+  std::iota(perm.begin(), perm.end(), 0);
+  auto d = dsub_;
+  auto np = std::min(n, max_points_);
+  real* xslice = new real[np * dsub_];
+  for (auto m = 0; m < nsubq_; m++) {
+    if (m == nsubq_-1) {d = lastdsub_;}
+    if (np != n) {std::shuffle(perm.begin(), perm.end(), rng);}
+    for (auto j = 0; j < np; j++) {
+      memcpy (xslice + j * d, x + perm[j] * dim_ + m * dsub_, d * sizeof(real));
+    }
+    kmeans(xslice, get_centroids(m, 0), np, d);
+  }
+  delete [] xslice;
+}
+
+real ProductQuantizer::mulcode(const Vector& x, const uint8_t* codes,
+                               int32_t t, real alpha) const {
+  real res = 0.0;
+  auto d = dsub_;
+  const uint8_t* code = codes + nsubq_ * t;
+  for (auto m = 0; m < nsubq_; m++) {
+    const real* c = get_centroids(m, code[m]);
+    if (m == nsubq_ - 1) {d = lastdsub_;}
+    for(auto n = 0; n < d; n++) {
+      res += x[m * dsub_ + n] * c[n];
+    }
+  }
+  return res * alpha;
+}
+
+void ProductQuantizer::addcode(Vector& x, const uint8_t* codes,
+                               int32_t t, real alpha) const {
+  auto d = dsub_;
+  const uint8_t* code = codes + nsubq_ * t;
+  for (auto m = 0; m < nsubq_; m++) {
+    const real* c = get_centroids(m, code[m]);
+    if (m == nsubq_ - 1) {d = lastdsub_;}
+    for(auto n = 0; n < d; n++) {
+      x[m * dsub_ + n] += alpha * c[n];
+    }
+  }
+}
+
+void ProductQuantizer::compute_code(const real* x, uint8_t* code) const {
+  auto d = dsub_;
+  for (auto m = 0; m < nsubq_; m++) {
+    if (m == nsubq_ - 1) {d = lastdsub_;}
+    assign_centroid(x + m * dsub_, get_centroids(m, 0), code + m, d);
+  }
+}
+
+void ProductQuantizer::compute_codes(const real* x, uint8_t* codes,
+                                     int32_t n) const {
+  for (auto i = 0; i < n; i++) {
+    compute_code(x + i * dim_, codes + i * nsubq_);
+  }
+}
+
+void ProductQuantizer::save(std::ostream& out) {
+  out.write((char*) &dim_, sizeof(dim_));
+  out.write((char*) &nsubq_, sizeof(nsubq_));
+  out.write((char*) &dsub_, sizeof(dsub_));
+  out.write((char*) &lastdsub_, sizeof(lastdsub_));
+  out.write((char*) centroids_.data(), centroids_.size() * sizeof(real));
+}
+
+void ProductQuantizer::load(std::istream& in) {
+  in.read((char*) &dim_, sizeof(dim_));
+  in.read((char*) &nsubq_, sizeof(nsubq_));
+  in.read((char*) &dsub_, sizeof(dsub_));
+  in.read((char*) &lastdsub_, sizeof(lastdsub_));
+  centroids_.resize(dim_ * ksub_);
+  for (size_t i = 0; i < centroids_.size(); i++) {
+    in.read((char*) &centroids_[i], sizeof(real));
+  }
+}
+
+}

+ 67 - 0
src/productquantizer.h

@@ -0,0 +1,67 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#ifndef FASTTEXT_PRODUCT_QUANTIZER_H
+#define FASTTEXT_PRODUCT_QUANTIZER_H
+
+#include <cstring>
+#include <istream>
+#include <ostream>
+#include <vector>
+#include <random>
+
+#include "real.h"
+#include "vector.h"
+
+namespace fasttext {
+
+class ProductQuantizer {
+  private:
+    const int32_t nbits_ = 8;
+    const int32_t ksub_ = 1 << nbits_;
+    const int32_t max_points_per_cluster_ = 256;
+    const int32_t max_points_ = max_points_per_cluster_ * ksub_;
+    const int32_t seed_ = 1234;
+    const int32_t niter_ = 25;
+    const real eps_ = 1e-7;
+
+    int32_t dim_;
+    int32_t nsubq_;
+    int32_t dsub_;
+    int32_t lastdsub_;
+
+    std::vector<real> centroids_;
+
+    std::minstd_rand rng;
+
+  public:
+    ProductQuantizer() {}
+    ProductQuantizer(int32_t, int32_t);
+
+    real* get_centroids (int32_t, uint8_t);
+    const real* get_centroids(int32_t, uint8_t) const;
+
+    real assign_centroid(const real*, const real*, uint8_t*, int32_t) const;
+    void Estep(const real*, const real*, uint8_t*, int32_t, int32_t) const;
+    void MStep(const real*, real*, const uint8_t*, int32_t, int32_t);
+    void kmeans(const real*, real*, int32_t, int32_t);
+    void train(int32_t, const real*);
+
+    real mulcode(const Vector&, const uint8_t*, int32_t, real) const;
+    void addcode(Vector&, const uint8_t*, int32_t, real) const;
+    void compute_code(const real*, uint8_t*)  const;
+    void compute_codes(const real*, uint8_t*, int32_t)  const;
+
+    void save(std::ostream&);
+    void load(std::istream&);
+};
+
+}
+
+#endif

+ 110 - 0
src/qmatrix.cc

@@ -0,0 +1,110 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#include "qmatrix.h"
+
+#include <assert.h>
+#include <cmath>
+#include <iostream>
+
+namespace fasttext {
+
+QMatrix::QMatrix() : codes_(nullptr), norm_codes_(nullptr),
+  qnorm_(false), m_(0), n_(0), codesize_(0) {}
+
+QMatrix::QMatrix(const Matrix& mat, int32_t dsub, bool qnorm)
+      : qnorm_(qnorm), m_(mat.m_), n_(mat.n_),
+        codesize_(m_ * ((n_ + dsub - 1) / dsub)) {
+  codes_ = new uint8_t[codesize_];
+  pq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer(n_, dsub));
+  if (qnorm_) {
+    norm_codes_ = new uint8_t[m_];
+    npq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer(1, 1));
+  }
+  quantize(mat);
+}
+
+QMatrix::~QMatrix() {
+  delete[] codes_;
+  if (qnorm_) { delete[] norm_codes_; }
+}
+
+void QMatrix::quantizeNorm(const Vector& norms) {
+  assert(qnorm_);
+  assert(norms.m_ == m_);
+  auto dataptr = norms.data_;
+  npq_->train(m_, dataptr);
+  npq_->compute_codes(dataptr, norm_codes_, m_);
+}
+
+void QMatrix::quantize(const Matrix& matrix) {
+  assert(n_ == matrix.n_);
+  assert(m_ == matrix.m_);
+  Matrix temp(matrix);
+  if (qnorm_) {
+    Vector norms(temp.m_);
+    temp.l2NormRow(norms);
+    temp.divideRow(norms);
+    quantizeNorm(norms);
+  }
+  auto dataptr = temp.data_;
+  pq_->train(m_, dataptr);
+  pq_->compute_codes(dataptr, codes_, m_);
+}
+
+void QMatrix::addToVector(Vector& x, int32_t t) const {
+  real norm = 1;
+  if (qnorm_) {
+    norm = npq_->get_centroids(0, norm_codes_[t])[0];
+  }
+  pq_->addcode(x, codes_, t, norm);
+}
+
+real QMatrix::dotRow(const Vector& vec, int64_t i) const {
+  assert(i >= 0);
+  assert(i < m_);
+  assert(vec.size() == n_);
+  real norm = 1;
+  if (qnorm_) {
+    norm = npq_->get_centroids(0, norm_codes_[i])[0];
+  }
+  return pq_->mulcode(vec, codes_, i, norm);
+}
+
+void QMatrix::save(std::ostream& out) {
+    out.write((char*) &qnorm_, sizeof(qnorm_));
+    out.write((char*) &m_, sizeof(m_));
+    out.write((char*) &n_, sizeof(n_));
+    out.write((char*) &codesize_, sizeof(codesize_));
+    out.write((char*) codes_, codesize_ * sizeof(uint8_t));
+    pq_->save(out);
+    if (qnorm_) {
+      out.write((char*) norm_codes_, m_ * sizeof(uint8_t));
+      npq_->save(out);
+    }
+}
+
+void QMatrix::load(std::istream& in) {
+    in.read((char*) &qnorm_, sizeof(qnorm_));
+    in.read((char*) &m_, sizeof(m_));
+    in.read((char*) &n_, sizeof(n_));
+    in.read((char*) &codesize_, sizeof(codesize_));
+    codes_ = new uint8_t[codesize_];
+    in.read((char*) codes_, codesize_ * sizeof(uint8_t));
+    pq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer());
+    pq_->load(in);
+    if (qnorm_) {
+      norm_codes_ = new uint8_t[m_];
+      in.read((char*) norm_codes_, m_ * sizeof(uint8_t));
+      npq_ = std::unique_ptr<ProductQuantizer>( new ProductQuantizer());
+      npq_->load(in);
+    }
+}
+
+}

+ 60 - 0
src/qmatrix.h

@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#ifndef FASTTEXT_QMATRIX_H
+#define FASTTEXT_QMATRIX_H
+
+#include <cstdint>
+#include <istream>
+#include <ostream>
+
+#include <vector>
+#include <memory>
+
+#include "real.h"
+
+#include "matrix.h"
+#include "vector.h"
+
+#include "productquantizer.h"
+
+namespace fasttext {
+
+class QMatrix {
+  private:
+    std::unique_ptr<ProductQuantizer> pq_;
+    std::unique_ptr<ProductQuantizer> npq_;
+
+    uint8_t* codes_;
+    uint8_t* norm_codes_;
+
+  public:
+    bool qnorm_;
+    int64_t m_;
+    int64_t n_;
+    int32_t codesize_;
+
+    QMatrix();
+    QMatrix(const Matrix&, int32_t, bool);
+    ~QMatrix();
+
+
+    void quantizeNorm(const Vector&);
+    void quantize(const Matrix&);
+
+    void addToVector(Vector& x, int32_t t) const;
+    real dotRow(const Vector&, int64_t) const;
+
+    void save(std::ostream&);
+    void load(std::istream&);
+};
+
+}
+
+#endif

+ 17 - 6
src/vector.cc

@@ -14,6 +14,7 @@
 #include <iomanip>
 
 #include "matrix.h"
+#include "qmatrix.h"
 
 namespace fasttext {
 
@@ -47,7 +48,7 @@ void Vector::addRow(const Matrix& A, int64_t i) {
   assert(i < A.m_);
   assert(m_ == A.n_);
   for (int64_t j = 0; j < A.n_; j++) {
-    data_[j] += A.data_[i * A.n_ + j];
+    data_[j] += A.at(i, j);
   }
 }
 
@@ -56,18 +57,28 @@ void Vector::addRow(const Matrix& A, int64_t i, real a) {
   assert(i < A.m_);
   assert(m_ == A.n_);
   for (int64_t j = 0; j < A.n_; j++) {
-    data_[j] += a * A.data_[i * A.n_ + j];
+    data_[j] += a * A.at(i, j);
   }
 }
 
+void Vector::addRow(const QMatrix& A, int64_t i) {
+  assert(i >= 0);
+  A.addToVector(*this, i);
+}
+
 void Vector::mul(const Matrix& A, const Vector& vec) {
   assert(A.m_ == m_);
   assert(A.n_ == vec.m_);
   for (int64_t i = 0; i < m_; i++) {
-    data_[i] = 0.0;
-    for (int64_t j = 0; j < A.n_; j++) {
-      data_[i] += A.data_[i * A.n_ + j] * vec.data_[j];
-    }
+    data_[i] = A.dotRow(vec, i);
+  }
+}
+
+void Vector::mul(const QMatrix& A, const Vector& vec) {
+  assert(A.m_ == m_);
+  assert(A.n_ == vec.m_);
+  for (int64_t i = 0; i < m_; i++) {
+    data_[i] = A.dotRow(vec, i);
   }
 }
 

+ 3 - 0
src/vector.h

@@ -18,6 +18,7 @@
 namespace fasttext {
 
 class Matrix;
+class QMatrix;
 
 class Vector {
 
@@ -35,7 +36,9 @@ class Vector {
     void zero();
     void mul(real);
     void addRow(const Matrix&, int64_t);
+    void addRow(const QMatrix&, int64_t);
     void addRow(const Matrix&, int64_t, real);
+    void mul(const QMatrix&, const Vector&);
     void mul(const Matrix&, const Vector&);
     int64_t argmax();
 };