dictionary.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the BSD-style license found in the
  6. * LICENSE file in the root directory of this source tree. An additional grant
  7. * of patent rights can be found in the PATENTS file in the same directory.
  8. */
  9. #ifndef FASTTEXT_DICTIONARY_H
  10. #define FASTTEXT_DICTIONARY_H
  11. #include <vector>
  12. #include <string>
  13. #include <istream>
  14. #include <ostream>
  15. #include <random>
  16. #include <memory>
  17. #include <unordered_map>
  18. #include "args.h"
  19. #include "real.h"
  20. namespace fasttext {
  // Integer type aliased as id_type: a signed 32-bit integer.
  using id_type = int32_t;

  // Tag carried by each dictionary entry, stored in a single signed byte.
  // The numeric values are fixed explicitly (word = 0, label = 1).
  enum class entry_type : int8_t {
    word = 0,
    label = 1
  };
  23. struct entry {
  24. std::string word;
  25. int64_t count;
  26. entry_type type;
  27. std::vector<int32_t> subwords;
  28. };
  29. class Dictionary {
  30. private:
  31. static const int32_t MAX_VOCAB_SIZE = 30000000;
  32. static const int32_t MAX_LINE_SIZE = 1024;
  33. int32_t find(const std::string&) const;
  34. int32_t find(const std::string&, uint32_t h) const;
  35. void initTableDiscard();
  36. void initNgrams();
  37. void reset(std::istream&) const;
  38. void pushHash(std::vector<int32_t>&, int32_t) const;
  39. void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;
  40. std::shared_ptr<Args> args_;
  41. std::vector<int32_t> word2int_;
  42. std::vector<entry> words_;
  43. std::vector<real> pdiscard_;
  44. int32_t size_;
  45. int32_t nwords_;
  46. int32_t nlabels_;
  47. int64_t ntokens_;
  48. int64_t pruneidx_size_;
  49. std::unordered_map<int32_t, int32_t> pruneidx_;
  50. void addWordNgrams(
  51. std::vector<int32_t>& line,
  52. const std::vector<int32_t>& hashes,
  53. int32_t n) const;
  54. public:
  55. static const std::string EOS;
  56. static const std::string BOW;
  57. static const std::string EOW;
  58. explicit Dictionary(std::shared_ptr<Args>);
  59. int32_t nwords() const;
  60. int32_t nlabels() const;
  61. int64_t ntokens() const;
  62. int32_t getId(const std::string&) const;
  63. int32_t getId(const std::string&, uint32_t h) const;
  64. entry_type getType(int32_t) const;
  65. entry_type getType(const std::string&) const;
  66. bool discard(int32_t, real) const;
  67. std::string getWord(int32_t) const;
  68. const std::vector<int32_t>& getSubwords(int32_t) const;
  69. const std::vector<int32_t> getSubwords(const std::string&) const;
  70. void computeSubwords(const std::string&, std::vector<int32_t>&) const;
  71. void computeSubwords(
  72. const std::string&,
  73. std::vector<int32_t>&,
  74. std::vector<std::string>&) const;
  75. void getSubwords(
  76. const std::string&,
  77. std::vector<int32_t>&,
  78. std::vector<std::string>&) const;
  79. uint32_t hash(const std::string& str) const;
  80. void add(const std::string&);
  81. bool readWord(std::istream&, std::string&) const;
  82. void readFromFile(std::istream&);
  83. std::string getLabel(int32_t) const;
  84. void save(std::ostream&) const;
  85. void load(std::istream&);
  86. std::vector<int64_t> getCounts(entry_type) const;
  87. int32_t getLine(std::istream&, std::vector<int32_t>&,
  88. std::vector<int32_t>&, std::minstd_rand&) const;
  89. int32_t getLine(std::istream&, std::vector<int32_t>&,
  90. std::minstd_rand&) const;
  91. void threshold(int64_t, int64_t);
  92. void prune(std::vector<int32_t>&);
  93. bool isPruned() { return pruneidx_size_ >= 0; }
  94. };
  95. }
  96. #endif