loss.cc 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
#include "loss.h"
#include "utils.h"

#include <cassert>
#include <cmath>
  11. namespace fasttext {
// Size of the precomputed sigmoid lookup table (tables hold SIZE + 1 entries
// so both endpoints are representable).
constexpr int64_t SIGMOID_TABLE_SIZE = 512;
// Sigmoid is clamped to exactly 0 / 1 outside [-MAX_SIGMOID, MAX_SIGMOID].
constexpr int64_t MAX_SIGMOID = 8;
// Size of the precomputed log lookup table over (0, 1].
constexpr int64_t LOG_TABLE_SIZE = 512;
  15. bool comparePairs(
  16. const std::pair<real, int32_t>& l,
  17. const std::pair<real, int32_t>& r) {
  18. return l.first > r.first;
  19. }
  20. real std_log(real x) {
  21. return std::log(x + 1e-5);
  22. }
  23. Loss::Loss(std::shared_ptr<Matrix>& wo) : wo_(wo) {
  24. t_sigmoid_.reserve(SIGMOID_TABLE_SIZE + 1);
  25. for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
  26. real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
  27. t_sigmoid_.push_back(1.0 / (1.0 + std::exp(-x)));
  28. }
  29. t_log_.reserve(LOG_TABLE_SIZE + 1);
  30. for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
  31. real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
  32. t_log_.push_back(std::log(x));
  33. }
  34. }
// Table-based approximation of log(x) for x in [0, 1]; values above 1.0
// clamp to log(1) = 0.
// NOTE(review): a negative x would produce a negative table index (out of
// bounds) — callers only pass sigmoid/softmax outputs, so x >= 0 in practice.
real Loss::log(real x) const {
  if (x > 1.0) {
    return 0.0;
  }
  // x in [0, 1] maps to an index in [0, LOG_TABLE_SIZE]; the table holds
  // LOG_TABLE_SIZE + 1 entries, so x == 1.0 is in range.
  int64_t i = int64_t(x * LOG_TABLE_SIZE);
  return t_log_[i];
}
// Table-based approximation of the logistic sigmoid; exact 0 / 1 outside
// [-MAX_SIGMOID, MAX_SIGMOID].
real Loss::sigmoid(real x) const {
  if (x < -MAX_SIGMOID) {
    return 0.0;
  } else if (x > MAX_SIGMOID) {
    return 1.0;
  } else {
    // Map x from [-MAX_SIGMOID, MAX_SIGMOID] onto a table index in
    // [0, SIGMOID_TABLE_SIZE] (table holds SIGMOID_TABLE_SIZE + 1 entries).
    int64_t i =
        int64_t((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
    return t_sigmoid_[i];
  }
}
// Generic top-k prediction: compute all output scores, keep the k best
// whose probability is at least `threshold`, then sort best-first.
void Loss::predict(
    int32_t k,
    real threshold,
    Predictions& heap,
    Model::State& state) const {
  computeOutput(state);
  findKBest(k, threshold, heap, state.output);
  // findKBest leaves `heap` as a valid heap under comparePairs; sort_heap
  // turns it into a sequence ordered by descending score.
  std::sort_heap(heap.begin(), heap.end(), comparePairs);
}
  62. void Loss::findKBest(
  63. int32_t k,
  64. real threshold,
  65. Predictions& heap,
  66. const Vector& output) const {
  67. for (int32_t i = 0; i < output.size(); i++) {
  68. if (output[i] < threshold) {
  69. continue;
  70. }
  71. if (heap.size() == k && std_log(output[i]) < heap.front().first) {
  72. continue;
  73. }
  74. heap.push_back(std::make_pair(std_log(output[i]), i));
  75. std::push_heap(heap.begin(), heap.end(), comparePairs);
  76. if (heap.size() > k) {
  77. std::pop_heap(heap.begin(), heap.end(), comparePairs);
  78. heap.pop_back();
  79. }
  80. }
  81. }
// Base class for losses composed of independent per-row binary logistic
// units (negative sampling, hierarchical softmax, one-vs-all).
BinaryLogisticLoss::BinaryLogisticLoss(std::shared_ptr<Matrix>& wo)
    : Loss(wo) {}
  84. real BinaryLogisticLoss::binaryLogistic(
  85. int32_t target,
  86. Model::State& state,
  87. bool labelIsPositive,
  88. real lr,
  89. bool backprop) const {
  90. real score = sigmoid(wo_->dotRow(state.hidden, target));
  91. if (backprop) {
  92. real alpha = lr * (real(labelIsPositive) - score);
  93. state.grad.addRow(*wo_, target, alpha);
  94. wo_->addVectorToRow(state.hidden, target, alpha);
  95. }
  96. if (labelIsPositive) {
  97. return -log(score);
  98. } else {
  99. return -log(1.0 - score);
  100. }
  101. }
  102. void BinaryLogisticLoss::computeOutput(Model::State& state) const {
  103. Vector& output = state.output;
  104. output.mul(*wo_, state.hidden);
  105. int32_t osz = output.size();
  106. for (int32_t i = 0; i < osz; i++) {
  107. output[i] = sigmoid(output[i]);
  108. }
  109. }
// One-vs-all: every output unit is trained as an independent binary
// classifier, suitable for multi-label targets.
OneVsAllLoss::OneVsAllLoss(std::shared_ptr<Matrix>& wo)
    : BinaryLogisticLoss(wo) {}
  112. real OneVsAllLoss::forward(
  113. const std::vector<int32_t>& targets,
  114. int32_t /* we take all targets here */,
  115. Model::State& state,
  116. real lr,
  117. bool backprop) {
  118. real loss = 0.0;
  119. int32_t osz = state.output.size();
  120. for (int32_t i = 0; i < osz; i++) {
  121. bool isMatch = utils::contains(targets, i);
  122. loss += binaryLogistic(i, state, isMatch, lr, backprop);
  123. }
  124. return loss;
  125. }
// Builds the negative-sampling table: each label id is repeated roughly
// proportionally to count^0.5, normalized so the whole table has about
// NEGATIVE_TABLE_SIZE entries. Sampling uniformly from the table then
// approximates the smoothed unigram distribution.
NegativeSamplingLoss::NegativeSamplingLoss(
    std::shared_ptr<Matrix>& wo,
    int neg,
    const std::vector<int64_t>& targetCounts)
    : BinaryLogisticLoss(wo), neg_(neg), negatives_(), uniform_() {
  // z normalizes the count^0.5 weights.
  real z = 0.0;
  for (size_t i = 0; i < targetCounts.size(); i++) {
    z += pow(targetCounts[i], 0.5);
  }
  for (size_t i = 0; i < targetCounts.size(); i++) {
    real c = pow(targetCounts[i], 0.5);
    // Integer j compared against a real bound: effectively floor-like;
    // very rare labels may receive zero table entries.
    for (size_t j = 0; j < c * NegativeSamplingLoss::NEGATIVE_TABLE_SIZE / z;
         j++) {
      negatives_.push_back(i);
    }
  }
  // NOTE(review): assumes negatives_ ends up non-empty (non-empty
  // targetCounts with positive counts); otherwise size() - 1 underflows.
  uniform_ = std::uniform_int_distribution<size_t>(0, negatives_.size() - 1);
}
  144. real NegativeSamplingLoss::forward(
  145. const std::vector<int32_t>& targets,
  146. int32_t targetIndex,
  147. Model::State& state,
  148. real lr,
  149. bool backprop) {
  150. assert(targetIndex >= 0);
  151. assert(targetIndex < targets.size());
  152. int32_t target = targets[targetIndex];
  153. real loss = binaryLogistic(target, state, true, lr, backprop);
  154. for (int32_t n = 0; n < neg_; n++) {
  155. auto negativeTarget = getNegative(target, state.rng);
  156. loss += binaryLogistic(negativeTarget, state, false, lr, backprop);
  157. }
  158. return loss;
  159. }
  160. int32_t NegativeSamplingLoss::getNegative(
  161. int32_t target,
  162. std::minstd_rand& rng) {
  163. int32_t negative;
  164. do {
  165. negative = negatives_[uniform_(rng)];
  166. } while (target == negative);
  167. return negative;
  168. }
// Hierarchical softmax: replaces the flat softmax with a binary tree built
// from label counts (see buildTree); osz_ is the number of labels/leaves.
HierarchicalSoftmaxLoss::HierarchicalSoftmaxLoss(
    std::shared_ptr<Matrix>& wo,
    const std::vector<int64_t>& targetCounts)
    : BinaryLogisticLoss(wo),
      paths_(),
      codes_(),
      tree_(),
      osz_(targetCounts.size()) {
  buildTree(targetCounts);
}
// Builds a Huffman tree over the label counts so frequent labels get short
// codes. Leaves occupy indices [0, osz_); internal nodes occupy
// [osz_, 2*osz_ - 1), with the root at 2*osz_ - 2. Afterwards, records for
// every leaf its path of internal nodes and its binary code.
void HierarchicalSoftmaxLoss::buildTree(const std::vector<int64_t>& counts) {
  tree_.resize(2 * osz_ - 1);
  for (int32_t i = 0; i < 2 * osz_ - 1; i++) {
    tree_[i].parent = -1;
    tree_[i].left = -1;
    tree_[i].right = -1;
    tree_[i].count = 1e15;  // sentinel "not yet built" — larger than any count
    tree_[i].binary = false;
  }
  for (int32_t i = 0; i < osz_; i++) {
    tree_[i].count = counts[i];
  }
  // Two-cursor merge: `leaf` walks leaves right-to-left, `node` walks
  // internal nodes left-to-right. NOTE(review): this only yields the two
  // globally smallest nodes if `counts` is sorted in decreasing order —
  // presumably guaranteed by the dictionary; confirm against the caller.
  int32_t leaf = osz_ - 1;
  int32_t node = osz_;
  for (int32_t i = osz_; i < 2 * osz_ - 1; i++) {
    // Pick the two lowest-count nodes among the remaining leaves and the
    // already-built internal nodes.
    int32_t mini[2] = {0};
    for (int32_t j = 0; j < 2; j++) {
      if (leaf >= 0 && tree_[leaf].count < tree_[node].count) {
        mini[j] = leaf--;
      } else {
        mini[j] = node++;
      }
    }
    tree_[i].left = mini[0];
    tree_[i].right = mini[1];
    tree_[i].count = tree_[mini[0]].count + tree_[mini[1]].count;
    tree_[mini[0]].parent = i;
    tree_[mini[1]].parent = i;
    tree_[mini[1]].binary = true;  // right child carries code bit 1
  }
  // Walk each leaf up to the root, storing internal-node ids re-based to
  // start at 0 (matching wo_ rows used in dfs/forward) and the code bits.
  for (int32_t i = 0; i < osz_; i++) {
    std::vector<int32_t> path;
    std::vector<bool> code;
    int32_t j = i;
    while (tree_[j].parent != -1) {
      path.push_back(tree_[j].parent - osz_);
      code.push_back(tree_[j].binary);
      j = tree_[j].parent;
    }
    paths_.push_back(path);
    codes_.push_back(code);
  }
}
  222. real HierarchicalSoftmaxLoss::forward(
  223. const std::vector<int32_t>& targets,
  224. int32_t targetIndex,
  225. Model::State& state,
  226. real lr,
  227. bool backprop) {
  228. real loss = 0.0;
  229. int32_t target = targets[targetIndex];
  230. const std::vector<bool>& binaryCode = codes_[target];
  231. const std::vector<int32_t>& pathToRoot = paths_[target];
  232. for (int32_t i = 0; i < pathToRoot.size(); i++) {
  233. loss += binaryLogistic(pathToRoot[i], state, binaryCode[i], lr, backprop);
  234. }
  235. return loss;
  236. }
// Top-k prediction without materializing the full output distribution:
// depth-first search from the root (index 2*osz_ - 2), pruning subtrees
// whose accumulated probability cannot reach the heap, then sort best-first.
void HierarchicalSoftmaxLoss::predict(
    int32_t k,
    real threshold,
    Predictions& heap,
    Model::State& state) const {
  dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, state.hidden);
  std::sort_heap(heap.begin(), heap.end(), comparePairs);
}
// Recursive tree descent accumulating log-probability `score`; prunes
// branches below the threshold or unable to beat the current k-th best,
// and pushes leaves (label ids) onto the bounded min-heap.
void HierarchicalSoftmaxLoss::dfs(
    int32_t k,
    real threshold,
    int32_t node,
    real score,
    Predictions& heap,
    const Vector& hidden) const {
  if (score < std_log(threshold)) {
    return;  // accumulated probability already below the cutoff
  }
  // Scores only decrease as we descend, so this subtree cannot improve
  // on the weakest kept candidate once the heap is full.
  if (heap.size() == k && score < heap.front().first) {
    return;
  }
  // A node with no children is a leaf, i.e. an actual label id.
  if (tree_[node].left == -1 && tree_[node].right == -1) {
    heap.push_back(std::make_pair(score, node));
    std::push_heap(heap.begin(), heap.end(), comparePairs);
    if (heap.size() > k) {
      std::pop_heap(heap.begin(), heap.end(), comparePairs);
      heap.pop_back();
    }
    return;
  }
  // Internal nodes are offset by osz_ relative to rows of wo_.
  real f = wo_->dotRow(hidden, node - osz_);
  f = 1. / (1 + std::exp(-f));  // exact sigmoid here, not the lookup table
  dfs(k, threshold, tree_[node].left, score + std_log(1.0 - f), heap, hidden);
  dfs(k, threshold, tree_[node].right, score + std_log(f), heap, hidden);
}
// Full (flat) softmax over all output rows; no sampling or hierarchy.
SoftmaxLoss::SoftmaxLoss(std::shared_ptr<Matrix>& wo) : Loss(wo) {}
  273. void SoftmaxLoss::computeOutput(Model::State& state) const {
  274. Vector& output = state.output;
  275. output.mul(*wo_, state.hidden);
  276. real max = output[0], z = 0.0;
  277. int32_t osz = output.size();
  278. for (int32_t i = 0; i < osz; i++) {
  279. max = std::max(output[i], max);
  280. }
  281. for (int32_t i = 0; i < osz; i++) {
  282. output[i] = exp(output[i] - max);
  283. z += output[i];
  284. }
  285. for (int32_t i = 0; i < osz; i++) {
  286. output[i] /= z;
  287. }
  288. }
  289. real SoftmaxLoss::forward(
  290. const std::vector<int32_t>& targets,
  291. int32_t targetIndex,
  292. Model::State& state,
  293. real lr,
  294. bool backprop) {
  295. computeOutput(state);
  296. assert(targetIndex >= 0);
  297. assert(targetIndex < targets.size());
  298. int32_t target = targets[targetIndex];
  299. if (backprop) {
  300. int32_t osz = wo_->size(0);
  301. for (int32_t i = 0; i < osz; i++) {
  302. real label = (i == target) ? 1.0 : 0.0;
  303. real alpha = lr * (label - state.output[i]);
  304. state.grad.addRow(*wo_, i, alpha);
  305. wo_->addVectorToRow(state.hidden, i, alpha);
  306. }
  307. }
  308. return -log(state.output[target]);
  309. };
  310. } // namespace fasttext