fasttext.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <time.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "args.h"
#include "densematrix.h"
#include "dictionary.h"
#include "matrix.h"
#include "meter.h"
#include "model.h"
#include "real.h"
#include "utils.h"
#include "vector.h"
  27. namespace fasttext {
  28. class FastText {
  29. public:
  30. using TrainCallback =
  31. std::function<void(float, float, double, double, int64_t)>;
  32. protected:
  33. std::shared_ptr<Args> args_;
  34. std::shared_ptr<Dictionary> dict_;
  35. std::shared_ptr<Matrix> input_;
  36. std::shared_ptr<Matrix> output_;
  37. std::shared_ptr<Model> model_;
  38. std::atomic<int64_t> tokenCount_{};
  39. std::atomic<real> loss_{};
  40. std::chrono::steady_clock::time_point start_;
  41. bool quant_;
  42. int32_t version;
  43. std::unique_ptr<DenseMatrix> wordVectors_;
  44. std::exception_ptr trainException_;
  45. void signModel(std::ostream&);
  46. bool checkModel(std::istream&);
  47. void startThreads(const TrainCallback& callback = {});
  48. void addInputVector(Vector&, int32_t) const;
  49. void trainThread(int32_t, const TrainCallback& callback);
  50. std::vector<std::pair<real, std::string>> getNN(
  51. const DenseMatrix& wordVectors,
  52. const Vector& queryVec,
  53. int32_t k,
  54. const std::set<std::string>& banSet);
  55. void lazyComputeWordVectors();
  56. void printInfo(real, real, std::ostream&);
  57. std::shared_ptr<Matrix> getInputMatrixFromFile(const std::string&) const;
  58. std::shared_ptr<Matrix> createRandomMatrix() const;
  59. std::shared_ptr<Matrix> createTrainOutputMatrix() const;
  60. std::vector<int64_t> getTargetCounts() const;
  61. std::shared_ptr<Loss> createLoss(std::shared_ptr<Matrix>& output);
  62. void supervised(
  63. Model::State& state,
  64. real lr,
  65. const std::vector<int32_t>& line,
  66. const std::vector<int32_t>& labels);
  67. void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
  68. void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
  69. std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
  70. void precomputeWordVectors(DenseMatrix& wordVectors);
  71. bool keepTraining(const int64_t ntokens) const;
  72. void buildModel();
  73. std::tuple<int64_t, double, double> progressInfo(real progress);
  74. public:
  75. FastText();
  76. int32_t getWordId(const std::string& word) const;
  77. int32_t getSubwordId(const std::string& subword) const;
  78. int32_t getLabelId(const std::string& label) const;
  79. void getWordVector(Vector& vec, const std::string& word) const;
  80. void getSubwordVector(Vector& vec, const std::string& subword) const;
  81. inline void getInputVector(Vector& vec, int32_t ind) {
  82. vec.zero();
  83. addInputVector(vec, ind);
  84. }
  85. const Args getArgs() const;
  86. std::shared_ptr<const Dictionary> getDictionary() const;
  87. std::shared_ptr<const DenseMatrix> getInputMatrix() const;
  88. void setMatrices(
  89. const std::shared_ptr<DenseMatrix>& inputMatrix,
  90. const std::shared_ptr<DenseMatrix>& outputMatrix);
  91. std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
  92. void saveVectors(const std::string& filename);
  93. void saveModel(const std::string& filename);
  94. void saveOutput(const std::string& filename);
  95. void loadModel(std::istream& in);
  96. void loadModel(const std::string& filename);
  97. void getSentenceVector(std::istream& in, Vector& vec);
  98. void quantize(const Args& qargs, const TrainCallback& callback = {});
  99. std::tuple<int64_t, double, double>
  100. test(std::istream& in, int32_t k, real threshold = 0.0);
  101. void test(std::istream& in, int32_t k, real threshold, Meter& meter) const;
  102. void predict(
  103. int32_t k,
  104. const std::vector<int32_t>& words,
  105. Predictions& predictions,
  106. real threshold = 0.0) const;
  107. bool predictLine(
  108. std::istream& in,
  109. std::vector<std::pair<real, std::string>>& predictions,
  110. int32_t k,
  111. real threshold) const;
  112. std::vector<std::pair<std::string, Vector>> getNgramVectors(
  113. const std::string& word) const;
  114. std::vector<std::pair<real, std::string>> getNN(
  115. const std::string& word,
  116. int32_t k);
  117. std::vector<std::pair<real, std::string>> getAnalogies(
  118. int32_t k,
  119. const std::string& wordA,
  120. const std::string& wordB,
  121. const std::string& wordC);
  122. void train(const Args& args, const TrainCallback& callback = {});
  123. void abort();
  124. int getDimension() const;
  125. bool isQuant() const;
  126. class AbortError : public std::runtime_error {
  127. public:
  128. AbortError() : std::runtime_error("Aborted.") {}
  129. };
  130. };
  131. } // namespace fasttext