1
0

loss.h 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  8. #pragma once
  9. #include <memory>
  10. #include <random>
  11. #include <vector>
  12. #include "matrix.h"
  13. #include "model.h"
  14. #include "real.h"
  15. #include "utils.h"
  16. #include "vector.h"
  17. namespace fasttext {
  18. class Loss {
  19. private:
  20. void findKBest(
  21. int32_t k,
  22. real threshold,
  23. Predictions& heap,
  24. const Vector& output) const;
  25. protected:
  26. std::vector<real> t_sigmoid_;
  27. std::vector<real> t_log_;
  28. std::shared_ptr<Matrix>& wo_;
  29. real log(real x) const;
  30. real sigmoid(real x) const;
  31. public:
  32. explicit Loss(std::shared_ptr<Matrix>& wo);
  33. virtual ~Loss() = default;
  34. virtual real forward(
  35. const std::vector<int32_t>& targets,
  36. int32_t targetIndex,
  37. Model::State& state,
  38. real lr,
  39. bool backprop) = 0;
  40. virtual void computeOutput(Model::State& state) const = 0;
  41. virtual void predict(
  42. int32_t /*k*/,
  43. real /*threshold*/,
  44. Predictions& /*heap*/,
  45. Model::State& /*state*/) const;
  46. };
  47. class BinaryLogisticLoss : public Loss {
  48. protected:
  49. real binaryLogistic(
  50. int32_t target,
  51. Model::State& state,
  52. bool labelIsPositive,
  53. real lr,
  54. bool backprop) const;
  55. public:
  56. explicit BinaryLogisticLoss(std::shared_ptr<Matrix>& wo);
  57. virtual ~BinaryLogisticLoss() noexcept override = default;
  58. void computeOutput(Model::State& state) const override;
  59. };
  60. class OneVsAllLoss : public BinaryLogisticLoss {
  61. public:
  62. explicit OneVsAllLoss(std::shared_ptr<Matrix>& wo);
  63. ~OneVsAllLoss() noexcept override = default;
  64. real forward(
  65. const std::vector<int32_t>& targets,
  66. int32_t targetIndex,
  67. Model::State& state,
  68. real lr,
  69. bool backprop) override;
  70. };
  71. class NegativeSamplingLoss : public BinaryLogisticLoss {
  72. protected:
  73. static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
  74. int neg_;
  75. std::vector<int32_t> negatives_;
  76. std::uniform_int_distribution<size_t> uniform_;
  77. int32_t getNegative(int32_t target, std::minstd_rand& rng);
  78. public:
  79. explicit NegativeSamplingLoss(
  80. std::shared_ptr<Matrix>& wo,
  81. int neg,
  82. const std::vector<int64_t>& targetCounts);
  83. ~NegativeSamplingLoss() noexcept override = default;
  84. real forward(
  85. const std::vector<int32_t>& targets,
  86. int32_t targetIndex,
  87. Model::State& state,
  88. real lr,
  89. bool backprop) override;
  90. };
  91. class HierarchicalSoftmaxLoss : public BinaryLogisticLoss {
  92. protected:
  93. struct Node {
  94. int32_t parent;
  95. int32_t left;
  96. int32_t right;
  97. int64_t count;
  98. bool binary;
  99. };
  100. std::vector<std::vector<int32_t>> paths_;
  101. std::vector<std::vector<bool>> codes_;
  102. std::vector<Node> tree_;
  103. int32_t osz_;
  104. void buildTree(const std::vector<int64_t>& counts);
  105. void dfs(
  106. int32_t k,
  107. real threshold,
  108. int32_t node,
  109. real score,
  110. Predictions& heap,
  111. const Vector& hidden) const;
  112. public:
  113. explicit HierarchicalSoftmaxLoss(
  114. std::shared_ptr<Matrix>& wo,
  115. const std::vector<int64_t>& counts);
  116. ~HierarchicalSoftmaxLoss() noexcept override = default;
  117. real forward(
  118. const std::vector<int32_t>& targets,
  119. int32_t targetIndex,
  120. Model::State& state,
  121. real lr,
  122. bool backprop) override;
  123. void predict(
  124. int32_t k,
  125. real threshold,
  126. Predictions& heap,
  127. Model::State& state) const override;
  128. };
  129. class SoftmaxLoss : public Loss {
  130. public:
  131. explicit SoftmaxLoss(std::shared_ptr<Matrix>& wo);
  132. ~SoftmaxLoss() noexcept override = default;
  133. real forward(
  134. const std::vector<int32_t>& targets,
  135. int32_t targetIndex,
  136. Model::State& state,
  137. real lr,
  138. bool backprop) override;
  139. void computeOutput(Model::State& state) const override;
  140. };
  141. } // namespace fasttext