Jelajahi Sumber

Base class for matrices

Summary: This diff adds a base class Matrix

Reviewed By: EdouardGrave

Differential Revision: D13234636

fbshipit-source-id: 0309b3b9b42a88571de77cc5e5269ce8b962979d
Onur Çelebi 7 tahun lalu
induk
melakukan
c35edc3a26
18 mengubah file dengan 493 tambahan dan 444 penghapusan
  1. 4 2
      CMakeLists.txt
  2. 8 5
      Makefile
  3. 8 6
      python/fastText/pybind/fasttext_pybind.cc
  4. 155 0
      src/densematrix.cc
  5. 75 0
      src/densematrix.h
  6. 97 98
      src/fasttext.cc
  7. 14 12
      src/fasttext.h
  8. 7 115
      src/matrix.cc
  9. 12 46
      src/matrix.h
  10. 9 35
      src/model.cc
  11. 2 7
      src/model.h
  12. 1 1
      src/productquantizer.cc
  13. 1 1
      src/productquantizer.h
  14. 0 59
      src/qmatrix.h
  15. 36 32
      src/quantmatrix.cc
  16. 60 0
      src/quantmatrix.h
  17. 4 22
      src/vector.cc
  18. 0 3
      src/vector.h

+ 4 - 2
CMakeLists.txt

@@ -19,19 +19,21 @@ set(CMAKE_CXX_FLAGS " -pthread -std=c++11 -funroll-loops -O3 -march=native")
 
 set(HEADER_FILES
     src/args.h
+    src/densematrix.h
     src/dictionary.h
     src/fasttext.h
     src/matrix.h
     src/meter.h
     src/model.h
     src/productquantizer.h
-    src/qmatrix.h
+    src/quantmatrix.h
     src/real.h
     src/utils.h
     src/vector.h)
 
 set(SOURCE_FILES
     src/args.cc
+    src/densematrix.cc
     src/dictionary.cc
     src/fasttext.cc
     src/main.cc
@@ -39,7 +41,7 @@ set(SOURCE_FILES
     src/meter.cc
     src/model.cc
     src/productquantizer.cc
-    src/qmatrix.cc
+    src/quantmatrix.cc
     src/utils.cc
     src/vector.cc)
 

+ 8 - 5
Makefile

@@ -8,7 +8,7 @@
 
 CXX = c++
 CXXFLAGS = -pthread -std=c++0x -march=native
-OBJS = args.o dictionary.o productquantizer.o matrix.o qmatrix.o vector.o model.o utils.o meter.o fasttext.o
+OBJS = args.o matrix.o dictionary.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
 INCLUDES = -I.
 
 opt: CXXFLAGS += -O3 -funroll-loops
@@ -23,17 +23,20 @@ debug: fasttext
 args.o: src/args.cc src/args.h
 	$(CXX) $(CXXFLAGS) -c src/args.cc
 
+matrix.o: src/matrix.cc src/matrix.h
+	$(CXX) $(CXXFLAGS) -c src/matrix.cc
+
 dictionary.o: src/dictionary.cc src/dictionary.h src/args.h
 	$(CXX) $(CXXFLAGS) -c src/dictionary.cc
 
 productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h
 	$(CXX) $(CXXFLAGS) -c src/productquantizer.cc
 
-matrix.o: src/matrix.cc src/matrix.h src/utils.h
-	$(CXX) $(CXXFLAGS) -c src/matrix.cc
+densematrix.o: src/densematrix.cc src/densematrix.h src/utils.h src/matrix.h
+	$(CXX) $(CXXFLAGS) -c src/densematrix.cc
 
-qmatrix.o: src/qmatrix.cc src/qmatrix.h src/utils.h
-	$(CXX) $(CXXFLAGS) -c src/qmatrix.cc
+quantmatrix.o: src/quantmatrix.cc src/quantmatrix.h src/utils.h src/matrix.h
+	$(CXX) $(CXXFLAGS) -c src/quantmatrix.cc
 
 vector.o: src/vector.cc src/vector.h src/utils.h
 	$(CXX) $(CXXFLAGS) -c src/vector.cc

+ 8 - 6
python/fastText/pybind/fasttext_pybind.cc

@@ -7,8 +7,8 @@
  */
 
 #include <args.h>
+#include <densematrix.h>
 #include <fasttext.h>
-#include <matrix.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <real.h>
@@ -117,11 +117,11 @@ PYBIND11_MODULE(fasttext_pybind, m) {
             {sizeof(fasttext::real)});
       });
 
-  py::class_<fasttext::Matrix>(
-      m, "Matrix", py::buffer_protocol(), py::module_local())
+  py::class_<fasttext::DenseMatrix>(
+      m, "DenseMatrix", py::buffer_protocol(), py::module_local())
       .def(py::init<>())
       .def(py::init<ssize_t, ssize_t>())
-      .def_buffer([](fasttext::Matrix& m) -> py::buffer_info {
+      .def_buffer([](fasttext::DenseMatrix& m) -> py::buffer_info {
         return py::buffer_info(
             m.data(),
             sizeof(fasttext::real),
@@ -138,13 +138,15 @@ PYBIND11_MODULE(fasttext_pybind, m) {
       .def(
           "getInputMatrix",
           [](fasttext::FastText& m) {
-            std::shared_ptr<const fasttext::Matrix> mm = m.getInputMatrix();
+            std::shared_ptr<const fasttext::DenseMatrix> mm =
+                m.getInputMatrix();
             return *mm.get();
           })
       .def(
           "getOutputMatrix",
           [](fasttext::FastText& m) {
-            std::shared_ptr<const fasttext::Matrix> mm = m.getOutputMatrix();
+            std::shared_ptr<const fasttext::DenseMatrix> mm =
+                m.getOutputMatrix();
             return *mm.get();
           })
       .def(

+ 155 - 0
src/densematrix.cc

@@ -0,0 +1,155 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "densematrix.h"
+
+#include <exception>
+#include <random>
+#include <stdexcept>
+#include <utility>
+
+#include "utils.h"
+#include "vector.h"
+
+namespace fasttext {
+
+// Default constructor: delegates to the (rows, cols) constructor with a
+// 0 x 0 shape.
+DenseMatrix::DenseMatrix() : DenseMatrix(0, 0) {}
+
+// Builds an m x n matrix: the dimensions are handed to the Matrix base and
+// data_ is created with m * n value-initialized elements.
+DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
+
+// Move constructor: copies the dimensions of `other` and takes over its
+// storage. The source is then reset to an empty 0 x 0 matrix so that its
+// dimensions keep describing its (moved-from) data_.
+DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
+    : Matrix(other.m_, other.n_), data_(std::move(other.data_)) {
+  other.m_ = 0;
+  other.n_ = 0;
+  // A moved-from vector is only "valid but unspecified"; make it empty.
+  other.data_.clear();
+}
+
+// Overwrites every stored element with 0.
+void DenseMatrix::zero() {
+  for (real& value : data_) {
+    value = 0.0;
+  }
+}
+
+// Fills all m_ * n_ elements with samples of a uniform real distribution
+// over [-a, a). The std::minstd_rand engine is always seeded with 1.
+void DenseMatrix::uniform(real a) {
+  std::minstd_rand engine(1);
+  std::uniform_real_distribution<> dist(-a, a);
+  const int64_t count = m_ * n_;
+  for (int64_t k = 0; k < count; ++k) {
+    data_[k] = dist(engine);
+  }
+}
+
+// Scales rows [ib, ie) in place: row i is multiplied by nums[i - ib].
+// ie == -1 stands for "up to and including the last row". Rows whose factor
+// is exactly 0 are left untouched.
+void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
+  if (ie == -1) {
+    ie = m_;
+  }
+  assert(ib >= 0);
+  assert(ie <= m_);
+  // nums is indexed relative to ib, so it only needs ie - ib entries.
+  assert(ie - ib <= nums.size());
+  for (int64_t i = ib; i < ie; i++) {
+    const real factor = nums[i - ib];
+    if (factor != 0) {
+      for (int64_t j = 0; j < n_; j++) {
+        at(i, j) *= factor;
+      }
+    }
+  }
+}
+
+// Divides rows [ib, ie) in place: row i is divided by denoms[i - ib].
+// ie == -1 stands for "up to and including the last row". Rows whose
+// denominator is exactly 0 are left untouched, which avoids dividing by zero.
+void DenseMatrix::divideRow(const Vector& denoms, int64_t ib, int64_t ie) {
+  if (ie == -1) {
+    ie = m_;
+  }
+  assert(ib >= 0);
+  assert(ie <= m_);
+  // denoms is indexed relative to ib, so it only needs ie - ib entries.
+  assert(ie - ib <= denoms.size());
+  for (int64_t i = ib; i < ie; i++) {
+    const real denom = denoms[i - ib];
+    if (denom != 0) {
+      for (int64_t j = 0; j < n_; j++) {
+        at(i, j) /= denom;
+      }
+    }
+  }
+}
+
+// Returns the Euclidean (L2) norm of row i. The squared values are summed in
+// double precision. Throws std::runtime_error if that sum is NaN.
+real DenseMatrix::l2NormRow(int64_t i) const {
+  assert(i >= 0);
+  assert(i < m_);
+  auto norm = 0.0;
+  for (int64_t j = 0; j < n_; j++) {
+    const real value = at(i, j);
+    norm += value * value;
+  }
+  if (std::isnan(norm)) {
+    throw std::runtime_error("Encountered NaN.");
+  }
+  return std::sqrt(norm);
+}
+
+// Writes the L2 norm of every row into norms; norms must hold exactly one
+// entry per row.
+void DenseMatrix::l2NormRow(Vector& norms) const {
+  assert(norms.size() == m_);
+  // int64_t index: the row count m_ is an int64_t.
+  for (int64_t i = 0; i < m_; i++) {
+    norms[i] = l2NormRow(i);
+  }
+}
+
+// Returns the dot product of row i with vec, which must have one entry per
+// column. Throws std::runtime_error if the result is NaN.
+real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
+  assert(i >= 0 && i < m_);
+  assert(vec.size() == n_);
+  real dot = 0.0;
+  for (int64_t col = 0; col < n_; ++col) {
+    dot += at(i, col) * vec[col];
+  }
+  if (!std::isnan(dot)) {
+    return dot;
+  }
+  throw std::runtime_error("Encountered NaN.");
+}
+
+// Adds a * vec to row i in place; vec must have one entry per column.
+void DenseMatrix::addVectorToRow(const Vector& vec, int64_t i, real a) {
+  assert(i >= 0 && i < m_);
+  assert(vec.size() == n_);
+  const int64_t rowStart = i * n_;
+  for (int64_t col = 0; col < n_; ++col) {
+    data_[rowStart + col] += a * vec[col];
+  }
+}
+
+// Adds row i element-wise onto x; x must have one entry per column.
+void DenseMatrix::addRowToVector(Vector& x, int32_t i) const {
+  const int64_t rowCount = this->size(0);
+  const int64_t colCount = this->size(1);
+  assert(i >= 0);
+  assert(i < rowCount);
+  assert(x.size() == colCount);
+  for (int64_t col = 0; col < colCount; ++col) {
+    x[col] += at(i, col);
+  }
+}
+
+// Adds a times row i element-wise onto x; x must have one entry per column.
+void DenseMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
+  const int64_t rowCount = this->size(0);
+  const int64_t colCount = this->size(1);
+  assert(i >= 0);
+  assert(i < rowCount);
+  assert(x.size() == colCount);
+  for (int64_t col = 0; col < colCount; ++col) {
+    x[col] += a * at(i, col);
+  }
+}
+
+// Writes the matrix in binary form: m_ and n_ as raw int64_t values, followed
+// by the m_ * n_ raw real values held in data_.
+void DenseMatrix::save(std::ostream& out) const {
+  const char* rowsBytes = reinterpret_cast<const char*>(&m_);
+  const char* colsBytes = reinterpret_cast<const char*>(&n_);
+  const char* valueBytes = reinterpret_cast<const char*>(data_.data());
+  out.write(rowsBytes, sizeof(int64_t));
+  out.write(colsBytes, sizeof(int64_t));
+  out.write(valueBytes, m_ * n_ * sizeof(real));
+}
+
+// Reads a matrix in the layout produced by save(): the row count and the
+// column count as raw int64_t values, followed by rows * cols raw real
+// values. On success the current dimensions and contents are replaced.
+// Throws std::runtime_error if the dimensions cannot be read or are
+// negative, or if the stream fails before all values were read; in that
+// case the matrix is left unchanged.
+void DenseMatrix::load(std::istream& in) {
+  int64_t rows = 0;
+  int64_t cols = 0;
+  in.read(reinterpret_cast<char*>(&rows), sizeof(int64_t));
+  in.read(reinterpret_cast<char*>(&cols), sizeof(int64_t));
+  if (!in || rows < 0 || cols < 0) {
+    throw std::runtime_error("Cannot read matrix dimensions.");
+  }
+  // Read into a temporary so a truncated stream cannot leave the matrix
+  // with dimensions that disagree with its storage.
+  std::vector<real> values(rows * cols);
+  in.read(reinterpret_cast<char*>(values.data()), rows * cols * sizeof(real));
+  if (!in) {
+    throw std::runtime_error("Cannot read matrix data.");
+  }
+  m_ = rows;
+  n_ = cols;
+  data_ = std::move(values);
+}
+
+// Writes the matrix as text: a "rows cols" header line, then one line per
+// row with the values separated by single spaces. The stream is flushed
+// once at the end instead of after every line.
+void DenseMatrix::dump(std::ostream& out) const {
+  out << m_ << " " << n_ << '\n';
+  for (int64_t i = 0; i < m_; i++) {
+    for (int64_t j = 0; j < n_; j++) {
+      if (j > 0) {
+        out << " ";
+      }
+      out << at(i, j);
+    }
+    out << '\n';
+  }
+  out << std::flush;
+}
+
+} // namespace fasttext

+ 75 - 0
src/densematrix.h

@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <istream>
+#include <ostream>
+#include <vector>
+
+#include <assert.h>
+#include "matrix.h"
+#include "real.h"
+
+namespace fasttext {
+
+class Vector;
+
+// Matrix implementation that keeps all m_ x n_ values in one contiguous
+// std::vector<real>; element (i, j) is stored at index i * n_ + j.
+class DenseMatrix : public Matrix {
+ protected:
+  std::vector<real> data_;
+
+ public:
+  DenseMatrix();
+  explicit DenseMatrix(int64_t, int64_t);
+  DenseMatrix(const DenseMatrix&) = default;
+  DenseMatrix(DenseMatrix&&) noexcept;
+  DenseMatrix& operator=(const DenseMatrix&) = delete;
+  DenseMatrix& operator=(DenseMatrix&&) = delete;
+  virtual ~DenseMatrix() = default;
+
+  // Raw pointer to the first stored element.
+  inline real* data() {
+    return data_.data();
+  }
+  inline const real* data() const {
+    return data_.data();
+  }
+
+  // Element access. Both overloads assert that the flat index lies inside
+  // data_; the cast makes a negative index fail the check explicitly.
+  inline const real& at(int64_t i, int64_t j) const {
+    assert(
+        static_cast<std::vector<real>::size_type>(i * n_ + j) < data_.size());
+    return data_[i * n_ + j];
+  }
+  inline real& at(int64_t i, int64_t j) {
+    assert(
+        static_cast<std::vector<real>::size_type>(i * n_ + j) < data_.size());
+    return data_[i * n_ + j];
+  }
+
+  // Number of rows / columns.
+  inline int64_t rows() const {
+    return m_;
+  }
+  inline int64_t cols() const {
+    return n_;
+  }
+  // Sets every element to 0.
+  void zero();
+  // Fills the matrix with uniform samples from [-a, a).
+  void uniform(real);
+
+  // Multiply / divide rows [ib, ie) by the matching entry of the vector;
+  // ie == -1 means "through the last row".
+  void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
+  void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
+
+  // L2 norm of a single row, or of every row (written into norms).
+  real l2NormRow(int64_t i) const;
+  void l2NormRow(Vector& norms) const;
+
+  // Matrix interface.
+  real dotRow(const Vector&, int64_t) const override;
+  void addVectorToRow(const Vector&, int64_t, real) override;
+  void addRowToVector(Vector& x, int32_t i) const override;
+  void addRowToVector(Vector& x, int32_t i, real a) const override;
+  void save(std::ostream&) const override;
+  void load(std::istream&) override;
+  void dump(std::ostream&) const override;
+};
+} // namespace fasttext

+ 97 - 98
src/fasttext.cc

@@ -7,6 +7,7 @@
  */
 
 #include "fasttext.h"
+#include "quantmatrix.h"
 
 #include <algorithm>
 #include <iomanip>
@@ -30,11 +31,7 @@ bool comparePairs(
 FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
 
 void FastText::addInputVector(Vector& vec, int32_t ind) const {
-  if (quant_) {
-    vec.addRow(*qinput_, ind);
-  } else {
-    vec.addRow(*input_, ind);
-  }
+  vec.addRow(*input_, ind);
 }
 
 std::shared_ptr<const Dictionary> FastText::getDictionary() const {
@@ -45,12 +42,20 @@ const Args FastText::getArgs() const {
   return *args_.get();
 }
 
-std::shared_ptr<const Matrix> FastText::getInputMatrix() const {
-  return input_;
+std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
+  if (quant_) {
+    throw std::runtime_error("Can't export quantized matrix");
+  }
+  assert(input_.get());
+  return std::dynamic_pointer_cast<DenseMatrix>(input_);
 }
 
-std::shared_ptr<const Matrix> FastText::getOutputMatrix() const {
-  return output_;
+std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
+  if (quant_ && args_->qout) {
+    throw std::runtime_error("Can't export quantized matrix");
+  }
+  assert(output_.get());
+  return std::dynamic_pointer_cast<DenseMatrix>(output_);
 }
 
 int32_t FastText::getWordId(const std::string& word) const {
@@ -172,18 +177,10 @@ void FastText::saveModel(const std::string& filename) {
   dict_->save(ofs);
 
   ofs.write((char*)&(quant_), sizeof(bool));
-  if (quant_) {
-    qinput_->save(ofs);
-  } else {
-    input_->save(ofs);
-  }
+  input_->save(ofs);
 
   ofs.write((char*)&(args_->qout), sizeof(bool));
-  if (quant_ && args_->qout) {
-    qoutput_->save(ofs);
-  } else {
-    output_->save(ofs);
-  }
+  output_->save(ofs);
 
   ofs.close();
 }
@@ -200,12 +197,18 @@ void FastText::loadModel(const std::string& filename) {
   ifs.close();
 }
 
+// Returns the dictionary counts that are handed to Model as target counts:
+// label counts for supervised models, word counts for every other model.
+std::vector<int64_t> FastText::getTargetCounts() const {
+  const auto targetType = (args_->model == model_name::sup)
+      ? entry_type::label
+      : entry_type::word;
+  return dict_->getCounts(targetType);
+}
+
 void FastText::loadModel(std::istream& in) {
   args_ = std::make_shared<Args>();
-  input_ = std::make_shared<Matrix>();
-  output_ = std::make_shared<Matrix>();
-  qinput_ = std::make_shared<QMatrix>();
-  qoutput_ = std::make_shared<QMatrix>();
+  input_ = std::make_shared<DenseMatrix>();
+  output_ = std::make_shared<DenseMatrix>();
   args_->load(in);
   if (version == 11 && args_->model == model_name::sup) {
     // backward compatibility: old supervised models do not use char ngrams.
@@ -217,10 +220,9 @@ void FastText::loadModel(std::istream& in) {
   in.read((char*)&quant_input, sizeof(bool));
   if (quant_input) {
     quant_ = true;
-    qinput_->load(in);
-  } else {
-    input_->load(in);
+    input_ = std::make_shared<QuantMatrix>();
   }
+  input_->load(in);
 
   if (!quant_input && dict_->isPruned()) {
     throw std::invalid_argument(
@@ -231,20 +233,12 @@ void FastText::loadModel(std::istream& in) {
 
   in.read((char*)&args_->qout, sizeof(bool));
   if (quant_ && args_->qout) {
-    qoutput_->load(in);
-  } else {
-    output_->load(in);
+    output_ = std::make_shared<QuantMatrix>();
   }
+  output_->load(in);
 
-  model_ = std::make_shared<Model>(input_, output_, args_, 0);
-  model_->quant_ = quant_;
-  model_->setQuantizePointer(qinput_, qoutput_, args_->qout);
-
-  if (args_->model == model_name::sup) {
-    model_->setTargetCounts(dict_->getCounts(entry_type::label));
-  } else {
-    model_->setTargetCounts(dict_->getCounts(entry_type::word));
-  }
+  model_ =
+      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
 }
 
 void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
@@ -277,9 +271,11 @@ void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
 }
 
 std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
-  Vector norms(input_->size(0));
-  input_->l2NormRow(norms);
-  std::vector<int32_t> idx(input_->size(0), 0);
+  std::shared_ptr<DenseMatrix> input =
+      std::dynamic_pointer_cast<DenseMatrix>(input_);
+  Vector norms(input->size(0));
+  input->l2NormRow(norms);
+  std::vector<int32_t> idx(input->size(0), 0);
   std::iota(idx.begin(), idx.end(), 0);
   auto eosid = dict_->getId(Dictionary::EOS);
   std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
@@ -297,18 +293,22 @@ void FastText::quantize(const Args& qargs) {
   args_->input = qargs.input;
   args_->qout = qargs.qout;
   args_->output = qargs.output;
+  std::shared_ptr<DenseMatrix> input =
+      std::dynamic_pointer_cast<DenseMatrix>(input_);
+  std::shared_ptr<DenseMatrix> output =
+      std::dynamic_pointer_cast<DenseMatrix>(output_);
 
-  if (qargs.cutoff > 0 && qargs.cutoff < input_->size(0)) {
+  if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
     auto idx = selectEmbeddings(qargs.cutoff);
     dict_->prune(idx);
-    std::shared_ptr<Matrix> ninput =
-        std::make_shared<Matrix>(idx.size(), args_->dim);
+    std::shared_ptr<DenseMatrix> ninput =
+        std::make_shared<DenseMatrix>(idx.size(), args_->dim);
     for (auto i = 0; i < idx.size(); i++) {
       for (auto j = 0; j < args_->dim; j++) {
-        ninput->at(i, j) = input_->at(idx[i], j);
+        ninput->at(i, j) = input->at(idx[i], j);
       }
     }
-    input_ = ninput;
+    input = ninput;
     if (qargs.retrain) {
       args_->epoch = qargs.epoch;
       args_->lr = qargs.lr;
@@ -318,21 +318,17 @@ void FastText::quantize(const Args& qargs) {
     }
   }
 
-  qinput_ = std::make_shared<QMatrix>(*input_, qargs.dsub, qargs.qnorm);
+  input_ = std::make_shared<QuantMatrix>(
+      std::move(*(input.get())), qargs.dsub, qargs.qnorm);
 
   if (args_->qout) {
-    qoutput_ = std::make_shared<QMatrix>(*output_, 2, qargs.qnorm);
+    output_ = std::make_shared<QuantMatrix>(
+        std::move(*(output.get())), 2, qargs.qnorm);
   }
 
   quant_ = true;
-  model_ = std::make_shared<Model>(input_, output_, args_, 0);
-  model_->quant_ = quant_;
-  model_->setQuantizePointer(qinput_, qoutput_, args_->qout);
-  if (args_->model == model_name::sup) {
-    model_->setTargetCounts(dict_->getCounts(entry_type::label));
-  } else {
-    model_->setTargetCounts(dict_->getCounts(entry_type::word));
-  }
+  model_ =
+      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
 }
 
 void FastText::supervised(
@@ -490,11 +486,7 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
   for (int32_t i = 0; i < ngrams.size(); i++) {
     Vector vec(args_->dim);
     if (ngrams[i] >= 0) {
-      if (quant_) {
-        vec.addRow(*qinput_, ngrams[i]);
-      } else {
-        vec.addRow(*input_, ngrams[i]);
-      }
+      vec.addRow(*input_, ngrams[i]);
     }
     result.push_back(std::make_pair(substrings[i], std::move(vec)));
   }
@@ -511,7 +503,7 @@ void FastText::ngramVectors(std::string word) {
   }
 }
 
-void FastText::precomputeWordVectors(Matrix& wordVectors) {
+void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
   Vector vec(args_->dim);
   wordVectors.zero();
   for (int32_t i = 0; i < dict_->nwords(); i++) {
@@ -519,15 +511,15 @@ void FastText::precomputeWordVectors(Matrix& wordVectors) {
     getWordVector(vec, word);
     real norm = vec.norm();
     if (norm > 0) {
-      wordVectors.addRow(vec, i, 1.0 / norm);
+      wordVectors.addVectorToRow(vec, i, 1.0 / norm);
     }
   }
 }
 
 void FastText::lazyComputeWordVectors() {
   if (!wordVectors_) {
-    wordVectors_ =
-        std::unique_ptr<Matrix>(new Matrix(dict_->nwords(), args_->dim));
+    wordVectors_ = std::unique_ptr<DenseMatrix>(
+        new DenseMatrix(dict_->nwords(), args_->dim));
     precomputeWordVectors(*wordVectors_);
   }
 }
@@ -545,7 +537,7 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
 }
 
 std::vector<std::pair<real, std::string>> FastText::getNN(
-    const Matrix& wordVectors,
+    const DenseMatrix& wordVectors,
     const Vector& query,
     int32_t k,
     const std::set<std::string>& banSet) {
@@ -580,7 +572,7 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
 
 // deprecated. use getNN instead
 void FastText::findNN(
-    const Matrix& wordVectors,
+    const DenseMatrix& wordVectors,
     const Vector& query,
     int32_t k,
     const std::set<std::string>& banSet,
@@ -632,12 +624,7 @@ void FastText::trainThread(int32_t threadId) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
 
-  Model model(input_, output_, args_, threadId);
-  if (args_->model == model_name::sup) {
-    model.setTargetCounts(dict_->getCounts(entry_type::label));
-  } else {
-    model.setTargetCounts(dict_->getCounts(entry_type::word));
-  }
+  Model model(input_, output_, args_, getTargetCounts(), threadId);
 
   const int64_t ntokens = dict_->ntokens();
   int64_t localTokenCount = 0;
@@ -667,10 +654,11 @@ void FastText::trainThread(int32_t threadId) {
   ifs.close();
 }
 
-void FastText::loadVectors(const std::string& filename) {
+std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
+    const std::string& filename) const {
   std::ifstream in(filename);
   std::vector<std::string> words;
-  std::shared_ptr<Matrix> mat; // temp. matrix for pretrained vectors
+  std::shared_ptr<DenseMatrix> mat; // temp. matrix for pretrained vectors
   int64_t n, dim;
   if (!in.is_open()) {
     throw std::invalid_argument(filename + " cannot be opened for loading!");
@@ -681,7 +669,7 @@ void FastText::loadVectors(const std::string& filename) {
         "Dimension of pretrained vectors (" + std::to_string(dim) +
         ") does not match dimension (" + std::to_string(args_->dim) + ")!");
   }
-  mat = std::make_shared<Matrix>(n, dim);
+  mat = std::make_shared<DenseMatrix>(n, dim);
   for (size_t i = 0; i < n; i++) {
     std::string word;
     in >> word;
@@ -695,9 +683,9 @@ void FastText::loadVectors(const std::string& filename) {
 
   dict_->threshold(1, 0);
   dict_->init();
-  input_ =
-      std::make_shared<Matrix>(dict_->nwords() + args_->bucket, args_->dim);
-  input_->uniform(1.0 / args_->dim);
+  std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
+      dict_->nwords() + args_->bucket, args_->dim);
+  input->uniform(1.0 / args_->dim);
 
   for (size_t i = 0; i < n; i++) {
     int32_t idx = dict_->getId(words[i]);
@@ -705,9 +693,32 @@ void FastText::loadVectors(const std::string& filename) {
       continue;
     }
     for (size_t j = 0; j < dim; j++) {
-      input_->at(idx, j) = mat->at(i, j);
+      input->at(idx, j) = mat->at(i, j);
     }
   }
+  return input;
+}
+
+// Replaces input_ with the matrix built by getInputMatrixFromFile(filename).
+// This entry point is marked FASTTEXT_DEPRECATED in fasttext.h.
+void FastText::loadVectors(const std::string& filename) {
+  input_ = getInputMatrixFromFile(filename);
+}
+
+// Creates a dense (nwords + bucket) x dim matrix and fills it through
+// DenseMatrix::uniform(1.0 / dim). train() uses it as the input matrix when
+// no pretrained vectors are configured.
+std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
+  auto matrix = std::make_shared<DenseMatrix>(
+      dict_->nwords() + args_->bucket, args_->dim);
+  const auto bound = 1.0 / args_->dim;
+  matrix->uniform(bound);
+  return matrix;
+}
+
+// Creates the zero-filled dense output matrix used by train(): one row per
+// label for supervised models, one row per word otherwise, with dim columns.
+std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
+  int64_t rowCount;
+  if (args_->model == model_name::sup) {
+    rowCount = dict_->nlabels();
+  } else {
+    rowCount = dict_->nwords();
+  }
+  auto matrix = std::make_shared<DenseMatrix>(rowCount, args_->dim);
+  matrix->zero();
+  return matrix;
+}
 
 void FastText::train(const Args& args) {
@@ -725,27 +736,15 @@ void FastText::train(const Args& args) {
   dict_->readFromFile(ifs);
   ifs.close();
 
-  if (args_->pretrainedVectors.size() != 0) {
-    loadVectors(args_->pretrainedVectors);
-  } else {
-    input_ =
-        std::make_shared<Matrix>(dict_->nwords() + args_->bucket, args_->dim);
-    input_->uniform(1.0 / args_->dim);
-  }
-
-  if (args_->model == model_name::sup) {
-    output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
+  if (!args_->pretrainedVectors.empty()) {
+    input_ = getInputMatrixFromFile(args_->pretrainedVectors);
   } else {
-    output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
+    input_ = createRandomMatrix();
   }
-  output_->zero();
+  output_ = createTrainOutputMatrix();
   startThreads();
-  model_ = std::make_shared<Model>(input_, output_, args_, 0);
-  if (args_->model == model_name::sup) {
-    model_->setTargetCounts(dict_->getCounts(entry_type::label));
-  } else {
-    model_->setTargetCounts(dict_->getCounts(entry_type::word));
-  }
+  model_ =
+      std::make_shared<Model>(input_, output_, args_, getTargetCounts(), 0);
 }
 
 void FastText::startThreads() {

+ 14 - 12
src/fasttext.h

@@ -19,11 +19,11 @@
 #include <tuple>
 
 #include "args.h"
+#include "densematrix.h"
 #include "dictionary.h"
 #include "matrix.h"
 #include "meter.h"
 #include "model.h"
-#include "qmatrix.h"
 #include "real.h"
 #include "utils.h"
 #include "vector.h"
@@ -38,9 +38,6 @@ class FastText {
   std::shared_ptr<Matrix> input_;
   std::shared_ptr<Matrix> output_;
 
-  std::shared_ptr<QMatrix> qinput_;
-  std::shared_ptr<QMatrix> qoutput_;
-
   std::shared_ptr<Model> model_;
 
   std::atomic<int64_t> tokenCount_{};
@@ -53,16 +50,20 @@ class FastText {
   void addInputVector(Vector&, int32_t) const;
   void trainThread(int32_t);
   std::vector<std::pair<real, std::string>> getNN(
-      const Matrix& wordVectors,
+      const DenseMatrix& wordVectors,
       const Vector& queryVec,
       int32_t k,
       const std::set<std::string>& banSet);
   void lazyComputeWordVectors();
   void printInfo(real, real, std::ostream&);
+  std::shared_ptr<Matrix> getInputMatrixFromFile(const std::string&) const;
+  std::shared_ptr<Matrix> createRandomMatrix() const;
+  std::shared_ptr<Matrix> createTrainOutputMatrix() const;
+  std::vector<int64_t> getTargetCounts() const;
 
   bool quant_;
   int32_t version;
-  std::unique_ptr<Matrix> wordVectors_;
+  std::unique_ptr<DenseMatrix> wordVectors_;
 
  public:
   FastText();
@@ -84,9 +85,9 @@ class FastText {
 
   std::shared_ptr<const Dictionary> getDictionary() const;
 
-  std::shared_ptr<const Matrix> getInputMatrix() const;
+  std::shared_ptr<const DenseMatrix> getInputMatrix() const;
 
-  std::shared_ptr<const Matrix> getOutputMatrix() const;
+  std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
 
   void saveVectors(const std::string& filename);
 
@@ -134,12 +135,13 @@ class FastText {
 
   void train(const Args& args);
 
-  void loadVectors(const std::string& filename);
-
   int getDimension() const;
 
   bool isQuant() const;
 
+  FASTTEXT_DEPRECATED("loadVectors is being deprecated.")
+  void loadVectors(const std::string& filename);
+
   FASTTEXT_DEPRECATED(
       "getVector is being deprecated and replaced by getWordVector.")
   void getVector(Vector& vec, const std::string& word) const;
@@ -181,11 +183,11 @@ class FastText {
   void saveModel();
 
   FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
-  void precomputeWordVectors(Matrix& wordVectors);
+  void precomputeWordVectors(DenseMatrix& wordVectors);
 
   FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
   void findNN(
-      const Matrix& wordVectors,
+      const DenseMatrix& wordVectors,
       const Vector& query,
       int32_t k,
       const std::set<std::string>& banSet,

+ 7 - 115
src/matrix.cc

@@ -8,126 +8,18 @@
 
 #include "matrix.h"
 
-#include <exception>
-#include <random>
-#include <stdexcept>
-
-#include "utils.h"
-#include "vector.h"
-
 namespace fasttext {
 
-Matrix::Matrix() : Matrix(0, 0) {}
-
-Matrix::Matrix(int64_t m, int64_t n) : data_(m * n), m_(m), n_(n) {}
-
-void Matrix::zero() {
-  std::fill(data_.begin(), data_.end(), 0.0);
-}
-
-void Matrix::uniform(real a) {
-  std::minstd_rand rng(1);
-  std::uniform_real_distribution<> uniform(-a, a);
-  for (int64_t i = 0; i < (m_ * n_); i++) {
-    data_[i] = uniform(rng);
-  }
-}
-
-real Matrix::dotRow(const Vector& vec, int64_t i) const {
-  assert(i >= 0);
-  assert(i < m_);
-  assert(vec.size() == n_);
-  real d = 0.0;
-  for (int64_t j = 0; j < n_; j++) {
-    d += at(i, j) * vec[j];
-  }
-  if (std::isnan(d)) {
-    throw std::runtime_error("Encountered NaN.");
-  }
-  return d;
-}
-
-void Matrix::addRow(const Vector& vec, int64_t i, real a) {
-  assert(i >= 0);
-  assert(i < m_);
-  assert(vec.size() == n_);
-  for (int64_t j = 0; j < n_; j++) {
-    data_[i * n_ + j] += a * vec[j];
-  }
-}
-
-void Matrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
-  if (ie == -1) {
-    ie = m_;
-  }
-  assert(ie <= nums.size());
-  for (auto i = ib; i < ie; i++) {
-    real n = nums[i - ib];
-    if (n != 0) {
-      for (auto j = 0; j < n_; j++) {
-        at(i, j) *= n;
-      }
-    }
-  }
-}
-
-void Matrix::divideRow(const Vector& denoms, int64_t ib, int64_t ie) {
-  if (ie == -1) {
-    ie = m_;
-  }
-  assert(ie <= denoms.size());
-  for (auto i = ib; i < ie; i++) {
-    real n = denoms[i - ib];
-    if (n != 0) {
-      for (auto j = 0; j < n_; j++) {
-        at(i, j) /= n;
-      }
-    }
-  }
-}
+// Default constructor: a matrix with zero rows and zero columns.
+Matrix::Matrix() : m_(0), n_(0) {}
 
-real Matrix::l2NormRow(int64_t i) const {
-  auto norm = 0.0;
-  for (auto j = 0; j < n_; j++) {
-    norm += at(i, j) * at(i, j);
-  }
-  if (std::isnan(norm)) {
-    throw std::runtime_error("Encountered NaN.");
-  }
-  return std::sqrt(norm);
-}
+// Records the shape only (m rows, n columns); the base class holds no
+// element storage of its own.
+Matrix::Matrix(int64_t m, int64_t n) : m_(m), n_(n) {}
 
-void Matrix::l2NormRow(Vector& norms) const {
-  assert(norms.size() == m_);
-  for (auto i = 0; i < m_; i++) {
-    norms[i] = l2NormRow(i);
+int64_t Matrix::size(int64_t dim) const {
+  assert(dim == 0 || dim == 1);
+  if (dim == 0) {
+    return m_;
   }
+  return n_;
 }
 
-void Matrix::save(std::ostream& out) {
-  out.write((char*)&m_, sizeof(int64_t));
-  out.write((char*)&n_, sizeof(int64_t));
-  out.write((char*)data_.data(), m_ * n_ * sizeof(real));
-}
-
-void Matrix::load(std::istream& in) {
-  in.read((char*)&m_, sizeof(int64_t));
-  in.read((char*)&n_, sizeof(int64_t));
-  data_ = std::vector<real>(m_ * n_);
-  in.read((char*)data_.data(), m_ * n_ * sizeof(real));
-}
-
-void Matrix::dump(std::ostream& out) const {
-  out << m_ << " " << n_ << std::endl;
-  for (int64_t i = 0; i < m_; i++) {
-    for (int64_t j = 0; j < n_; j++) {
-      if (j > 0) {
-        out << " ";
-      }
-      out << at(i, j);
-    }
-    out << std::endl;
-  }
-};
-
 } // namespace fasttext

+ 12 - 46
src/matrix.h

@@ -22,57 +22,23 @@ class Vector;
 
 class Matrix {
  protected:
-  std::vector<real> data_;
-  const int64_t m_;
-  const int64_t n_;
+  int64_t m_;
+  int64_t n_;
 
  public:
   Matrix();
   explicit Matrix(int64_t, int64_t);
-  Matrix(const Matrix&) = default;
-  Matrix& operator=(const Matrix&) = delete;
+  virtual ~Matrix() = default;
 
-  inline real* data() {
-    return data_.data();
-  }
-  inline const real* data() const {
-    return data_.data();
-  }
+  int64_t size(int64_t dim) const;
 
-  inline const real& at(int64_t i, int64_t j) const {
-    return data_[i * n_ + j];
-  };
-  inline real& at(int64_t i, int64_t j) {
-    return data_[i * n_ + j];
-  };
-
-  inline int64_t size(int64_t dim) const {
-    assert(dim == 0 || dim == 1);
-    if (dim == 0) {
-      return m_;
-    }
-    return n_;
-  }
-  inline int64_t rows() const {
-    return m_;
-  }
-  inline int64_t cols() const {
-    return n_;
-  }
-  void zero();
-  void uniform(real);
-  real dotRow(const Vector&, int64_t) const;
-  void addRow(const Vector&, int64_t, real);
-
-  void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
-  void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
-
-  real l2NormRow(int64_t i) const;
-  void l2NormRow(Vector& norms) const;
-
-  void save(std::ostream&);
-  void load(std::istream&);
-
-  void dump(std::ostream&) const;
+  virtual real dotRow(const Vector&, int64_t) const = 0;
+  virtual void addVectorToRow(const Vector&, int64_t, real) = 0;
+  virtual void addRowToVector(Vector& x, int32_t i) const = 0;
+  virtual void addRowToVector(Vector& x, int32_t i, real a) const = 0;
+  virtual void save(std::ostream&) const = 0;
+  virtual void load(std::istream&) = 0;
+  virtual void dump(std::ostream&) const = 0;
 };
+
 } // namespace fasttext

+ 9 - 35
src/model.cc

@@ -23,12 +23,9 @@ Model::Model(
     std::shared_ptr<Matrix> wi,
     std::shared_ptr<Matrix> wo,
     std::shared_ptr<Args> args,
+    const std::vector<int64_t>& targetCounts,
     int32_t seed)
-    : hidden_(args->dim),
-      output_(wo->size(0)),
-      grad_(args->dim),
-      rng(seed),
-      quant_(false) {
+    : hidden_(args->dim), output_(wo->size(0)), grad_(args->dim), rng(seed) {
   wi_ = wi;
   wo_ = wo;
   args_ = args;
@@ -41,24 +38,14 @@ Model::Model(
   t_log_.reserve(LOG_TABLE_SIZE + 1);
   initSigmoid();
   initLog();
-}
-
-void Model::setQuantizePointer(
-    std::shared_ptr<QMatrix> qwi,
-    std::shared_ptr<QMatrix> qwo,
-    bool qout) {
-  qwi_ = qwi;
-  qwo_ = qwo;
-  if (qout) {
-    osz_ = qwo_->getM();
-  }
+  setTargetCounts(targetCounts);
 }
 
 real Model::binaryLogistic(int32_t target, bool label, real lr) {
   real score = sigmoid(wo_->dotRow(hidden_, target));
   real alpha = lr * (real(label) - score);
   grad_.addRow(*wo_, target, alpha);
-  wo_->addRow(hidden_, target, alpha);
+  wo_->addVectorToRow(hidden_, target, alpha);
   if (label) {
     return -log(score);
   } else {
@@ -91,11 +78,7 @@ real Model::hierarchicalSoftmax(int32_t target, real lr) {
 }
 
 void Model::computeOutput(Vector& hidden, Vector& output) const {
-  if (quant_ && args_->qout) {
-    output.mul(*qwo_, hidden);
-  } else {
-    output.mul(*wo_, hidden);
-  }
+  output.mul(*wo_, hidden);
 }
 
 void Model::computeOutputSigmoid(Vector& hidden, Vector& output) const {
@@ -131,7 +114,7 @@ real Model::softmax(int32_t target, real lr) {
     real label = (i == target) ? 1.0 : 0.0;
     real alpha = lr * (label - output_[i]);
     grad_.addRow(*wo_, i, alpha);
-    wo_->addRow(hidden_, i, alpha);
+    wo_->addVectorToRow(hidden_, i, alpha);
   }
   return -log(output_[target]);
 }
@@ -151,11 +134,7 @@ void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden)
   assert(hidden.size() == hsz_);
   hidden.zero();
   for (auto it = input.cbegin(); it != input.cend(); ++it) {
-    if (quant_) {
-      hidden.addRow(*qwi_, *it);
-    } else {
-      hidden.addRow(*wi_, *it);
-    }
+    hidden.addRow(*wi_, *it);
   }
   hidden.mul(1.0 / input.size());
 }
@@ -250,12 +229,7 @@ void Model::dfs(
     return;
   }
 
-  real f;
-  if (quant_ && args_->qout) {
-    f = qwo_->dotRow(hidden, node - osz_);
-  } else {
-    f = wo_->dotRow(hidden, node - osz_);
-  }
+  real f = wo_->dotRow(hidden, node - osz_);
   f = 1. / (1 + std::exp(-f));
 
   dfs(k, threshold, tree[node].left, score + std_log(1.0 - f), heap, hidden);
@@ -307,7 +281,7 @@ void Model::update(
     grad_.mul(1.0 / input.size());
   }
   for (auto it = input.cbegin(); it != input.cend(); ++it) {
-    wi_->addRow(grad_, *it, 1.0);
+    wi_->addVectorToRow(grad_, *it, 1.0);
   }
 }
 

+ 2 - 7
src/model.h

@@ -15,7 +15,6 @@
 
 #include "args.h"
 #include "matrix.h"
-#include "qmatrix.h"
 #include "real.h"
 #include "vector.h"
 
@@ -33,8 +32,6 @@ class Model {
  protected:
   std::shared_ptr<Matrix> wi_;
   std::shared_ptr<Matrix> wo_;
-  std::shared_ptr<QMatrix> qwi_;
-  std::shared_ptr<QMatrix> qwo_;
   std::shared_ptr<Args> args_;
   Vector hidden_;
   Vector output_;
@@ -61,6 +58,7 @@ class Model {
   void initSigmoid();
   void initLog();
   void computeOutput(Vector&, Vector&) const;
+  void setTargetCounts(const std::vector<int64_t>&);
 
   static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
 
@@ -69,6 +67,7 @@ class Model {
       std::shared_ptr<Matrix>,
       std::shared_ptr<Matrix>,
       std::shared_ptr<Args>,
+      const std::vector<int64_t>&,
       int32_t);
 
   real binaryLogistic(int32_t, bool, real);
@@ -113,7 +112,6 @@ class Model {
   void computeOutputSoftmax(Vector&, Vector&) const;
   void computeOutputSoftmax();
 
-  void setTargetCounts(const std::vector<int64_t>&);
   void initTableNegatives(const std::vector<int64_t>&);
   void buildTree(const std::vector<int64_t>&);
   real getLoss() const;
@@ -122,9 +120,6 @@ class Model {
   real std_log(real) const;
 
   std::minstd_rand rng;
-  bool quant_;
-  void
-  setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);
 
   static const int32_t kUnlimitedPredictions = -1;
   static const int32_t kAllLabelsAsTarget = -1;

+ 1 - 1
src/productquantizer.cc

@@ -229,7 +229,7 @@ void ProductQuantizer::compute_codes(const real* x, uint8_t* codes, int32_t n)
   }
 }
 
-void ProductQuantizer::save(std::ostream& out) {
+void ProductQuantizer::save(std::ostream& out) const {
   out.write((char*)&dim_, sizeof(dim_));
   out.write((char*)&nsubq_, sizeof(nsubq_));
   out.write((char*)&dsub_, sizeof(dsub_));

+ 1 - 1
src/productquantizer.h

@@ -56,7 +56,7 @@ class ProductQuantizer {
   void compute_code(const real*, uint8_t*) const;
   void compute_codes(const real*, uint8_t*, int32_t) const;
 
-  void save(std::ostream&);
+  void save(std::ostream&) const;
   void load(std::istream&);
 };
 

+ 0 - 59
src/qmatrix.h

@@ -1,59 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#pragma once
-
-#include <cstdint>
-#include <istream>
-#include <ostream>
-
-#include <memory>
-#include <vector>
-
-#include "real.h"
-
-#include "matrix.h"
-#include "vector.h"
-
-#include "productquantizer.h"
-
-namespace fasttext {
-
-class QMatrix {
- protected:
-  std::unique_ptr<ProductQuantizer> pq_;
-  std::unique_ptr<ProductQuantizer> npq_;
-
-  std::vector<uint8_t> codes_;
-  std::vector<uint8_t> norm_codes_;
-
-  bool qnorm_;
-
-  int64_t m_;
-  int64_t n_;
-
-  int32_t codesize_;
-
- public:
-  QMatrix();
-  QMatrix(const Matrix&, int32_t, bool);
-
-  int64_t getM() const;
-  int64_t getN() const;
-
-  void quantizeNorm(const Vector&);
-  void quantize(const Matrix&);
-
-  void addToVector(Vector& x, int32_t t) const;
-  real dotRow(const Vector&, int64_t) const;
-
-  void save(std::ostream&);
-  void load(std::istream&);
-};
-
-} // namespace fasttext

+ 36 - 32
src/qmatrix.cc → src/quantmatrix.cc

@@ -6,30 +6,29 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include "qmatrix.h"
+#include "quantmatrix.h"
 
 #include <assert.h>
 #include <iostream>
 
 namespace fasttext {
 
-QMatrix::QMatrix() : qnorm_(false), m_(0), n_(0), codesize_(0) {}
+QuantMatrix::QuantMatrix() : Matrix(), qnorm_(false), codesize_(0) {} // default state: no norm quantization, zero code size
 
-QMatrix::QMatrix(const Matrix& mat, int32_t dsub, bool qnorm)
-    : qnorm_(qnorm),
-      m_(mat.size(0)),
-      n_(mat.size(1)),
-      codesize_(m_ * ((n_ + dsub - 1) / dsub)) {
+QuantMatrix::QuantMatrix(DenseMatrix&& mat, int32_t dsub, bool qnorm) // sizes the code buffers, then quantizes mat
+    : Matrix(mat.size(0), mat.size(1)),
+      qnorm_(qnorm),
+      codesize_(mat.size(0) * ((mat.size(1) + dsub - 1) / dsub)) { // size(0) * ceil(size(1) / dsub)
   codes_.resize(codesize_);
   pq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer(n_, dsub));
   if (qnorm_) {
     norm_codes_.resize(m_);
     npq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer(1, 1));
   }
-  quantize(mat);
+  quantize(std::move(mat)); // mat is a named rvalue reference of a fixed type: std::move, not std::forward
 }
 
-void QMatrix::quantizeNorm(const Vector& norms) {
+void QuantMatrix::quantizeNorm(const Vector& norms) {
   assert(qnorm_);
   assert(norms.size() == m_);
   auto dataptr = norms.data();
@@ -37,30 +36,19 @@ void QMatrix::quantizeNorm(const Vector& norms) {
   npq_->compute_codes(dataptr, norm_codes_.data(), m_);
 }
 
-void QMatrix::quantize(const Matrix& matrix) {
-  assert(m_ == matrix.size(0));
-  assert(n_ == matrix.size(1));
-  Matrix temp(matrix);
+void QuantMatrix::quantize(DenseMatrix&& mat) { // NOTE(review): the m_/n_ shape asserts of the old version are dropped here
   if (qnorm_) {
-    Vector norms(temp.size(0));
-    temp.l2NormRow(norms);
-    temp.divideRow(norms);
+    Vector norms(mat.size(0));
+    mat.l2NormRow(norms);
+    mat.divideRow(norms); // operates on mat directly; the previous code worked on a copy (temp)
     quantizeNorm(norms);
   }
-  auto dataptr = temp.data();
+  auto dataptr = mat.data(); // same buffer is used to train pq_ and to compute codes_
   pq_->train(m_, dataptr);
   pq_->compute_codes(dataptr, codes_.data(), m_);
 }
 
-void QMatrix::addToVector(Vector& x, int32_t t) const {
-  real norm = 1;
-  if (qnorm_) {
-    norm = npq_->get_centroids(0, norm_codes_[t])[0];
-  }
-  pq_->addcode(x, codes_.data(), t, norm);
-}
-
-real QMatrix::dotRow(const Vector& vec, int64_t i) const {
+real QuantMatrix::dotRow(const Vector& vec, int64_t i) const {
   assert(i >= 0);
   assert(i < m_);
   assert(vec.size() == n_);
@@ -71,15 +59,27 @@ real QMatrix::dotRow(const Vector& vec, int64_t i) const {
   return pq_->mulcode(vec, codes_.data(), i, norm);
 }
 
-int64_t QMatrix::getM() const {
-  return m_;
+void QuantMatrix::addVectorToRow(const Vector&, int64_t, real) { // always throws: not supported by QuantMatrix
+  throw std::runtime_error("Operation not permitted on quantized matrices.");
 }
 
-int64_t QMatrix::getN() const {
-  return n_;
+void QuantMatrix::addRowToVector(Vector& x, int32_t i, real a) const { // adds row i, scaled by a, into x
+  real norm = 1;
+  if (qnorm_) {
+    norm = npq_->get_centroids(0, norm_codes_[i])[0]; // presumably the quantized norm of row i -- confirm
+  }
+  pq_->addcode(x, codes_.data(), i, a * norm); // the scale a is folded into the norm factor
 }
 
-void QMatrix::save(std::ostream& out) {
+void QuantMatrix::addRowToVector(Vector& x, int32_t i) const { // unscaled variant of the overload above
+  real norm = 1;
+  if (qnorm_) {
+    norm = npq_->get_centroids(0, norm_codes_[i])[0];
+  }
+  pq_->addcode(x, codes_.data(), i, norm); // same call as the 3-argument overload, without the factor a
+}
+
+void QuantMatrix::save(std::ostream& out) const {
   out.write((char*)&qnorm_, sizeof(qnorm_));
   out.write((char*)&m_, sizeof(m_));
   out.write((char*)&n_, sizeof(n_));
@@ -92,7 +92,7 @@ void QMatrix::save(std::ostream& out) {
   }
 }
 
-void QMatrix::load(std::istream& in) {
+void QuantMatrix::load(std::istream& in) {
   in.read((char*)&qnorm_, sizeof(qnorm_));
   in.read((char*)&m_, sizeof(m_));
   in.read((char*)&n_, sizeof(n_));
@@ -109,4 +109,8 @@ void QMatrix::load(std::istream& in) {
   }
 }
 
+void QuantMatrix::dump(std::ostream&) const { // always throws: not supported by QuantMatrix
+  throw std::runtime_error("Operation not permitted on quantized matrices.");
+} // NOTE(review): no <stdexcept> include is visible for std::runtime_error -- confirm a header provides it
+
 } // namespace fasttext

+ 60 - 0
src/quantmatrix.h

@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <istream>
+#include <ostream>
+
+#include <memory>
+#include <vector>
+
+#include "real.h"
+
+#include "densematrix.h"
+#include "matrix.h"
+#include "vector.h"
+
+#include "productquantizer.h"
+
+namespace fasttext {
+
+class QuantMatrix : public Matrix { // Matrix implementation that stores ProductQuantizer codes
+ protected:
+  std::unique_ptr<ProductQuantizer> pq_; // quantizer trained on the matrix data
+  std::unique_ptr<ProductQuantizer> npq_; // quantizer for the row norms; only created when qnorm_ is set
+
+  std::vector<uint8_t> codes_; // resized to codesize_ by the constructor
+  std::vector<uint8_t> norm_codes_; // resized to m_ by the constructor when qnorm_ is set
+
+  bool qnorm_; // whether row norms are quantized separately via npq_
+  int32_t codesize_; // number of bytes in codes_
+
+ public:
+  QuantMatrix();
+  QuantMatrix(DenseMatrix&&, int32_t, bool); // arguments: source matrix, dsub, qnorm
+  QuantMatrix(const QuantMatrix&) = delete; // copy and move are all deleted below
+  QuantMatrix(QuantMatrix&&) = delete;
+  QuantMatrix& operator=(const QuantMatrix&) = delete;
+  QuantMatrix& operator=(QuantMatrix&&) = delete;
+  virtual ~QuantMatrix() = default;
+
+  void quantizeNorm(const Vector&);
+  void quantize(DenseMatrix&& mat); // called by the constructor on its own argument
+
+  real dotRow(const Vector&, int64_t) const override;
+  void addVectorToRow(const Vector&, int64_t, real) override; // throws std::runtime_error
+  void addRowToVector(Vector& x, int32_t i) const override;
+  void addRowToVector(Vector& x, int32_t i, real a) const override;
+  void save(std::ostream&) const override;
+  void load(std::istream&) override;
+  void dump(std::ostream&) const override; // throws std::runtime_error
+};
+
+} // namespace fasttext

+ 4 - 22
src/vector.cc

@@ -15,7 +15,6 @@
 #include <utility>
 
 #include "matrix.h"
-#include "qmatrix.h"
 
 namespace fasttext {
 
@@ -60,27 +59,18 @@ void Vector::addVector(const Vector& source, real s) {
   }
 }
 
-void Vector::addRow(const Matrix& A, int64_t i) {
+void Vector::addRow(const Matrix& A, int64_t i, real a) { // adds a * (row i of A) via the matrix's own virtual
   assert(i >= 0);
   assert(i < A.size(0));
   assert(size() == A.size(1));
-  for (int64_t j = 0; j < A.size(1); j++) {
-    data_[j] += A.at(i, j);
-  }
+  A.addRowToVector(*this, i, a); // NOTE(review): i is int64_t here but int32_t in Matrix::addRowToVector
 }
 
-void Vector::addRow(const Matrix& A, int64_t i, real a) {
+void Vector::addRow(const Matrix& A, int64_t i) { // unscaled variant; same checks as the overload above
   assert(i >= 0);
   assert(i < A.size(0));
   assert(size() == A.size(1));
-  for (int64_t j = 0; j < A.size(1); j++) {
-    data_[j] += a * A.at(i, j);
-  }
-}
-
-void Vector::addRow(const QMatrix& A, int64_t i) {
-  assert(i >= 0);
-  A.addToVector(*this, i);
+  A.addRowToVector(*this, i); // replaces the element loop and the former QMatrix overload
 }
 
 void Vector::mul(const Matrix& A, const Vector& vec) {
@@ -91,14 +81,6 @@ void Vector::mul(const Matrix& A, const Vector& vec) {
   }
 }
 
-void Vector::mul(const QMatrix& A, const Vector& vec) {
-  assert(A.getM() == size());
-  assert(A.getN() == vec.size());
-  for (int64_t i = 0; i < size(); i++) {
-    data_[i] = A.dotRow(vec, i);
-  }
-}
-
 int64_t Vector::argmax() {
   real max = data_[0];
   int64_t argmax = 0;

+ 0 - 3
src/vector.h

@@ -17,7 +17,6 @@
 namespace fasttext {
 
 class Matrix;
-class QMatrix;
 
 class Vector {
  protected:
@@ -52,9 +51,7 @@ class Vector {
   void addVector(const Vector& source);
   void addVector(const Vector&, real);
   void addRow(const Matrix&, int64_t);
-  void addRow(const QMatrix&, int64_t);
   void addRow(const Matrix&, int64_t, real);
-  void mul(const QMatrix&, const Vector&);
   void mul(const Matrix&, const Vector&);
   int64_t argmax();
 };