vector.cc 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the BSD-style license found in the
  6. * LICENSE file in the root directory of this source tree. An additional grant
  7. * of patent rights can be found in the PATENTS file in the same directory.
  8. */
  #include "vector.h"

  #include <assert.h>

  #include <cmath>
  #include <iomanip>

  #include "matrix.h"
  #include "qmatrix.h"
  14. namespace fasttext {
  15. Vector::Vector(int64_t m) {
  16. m_ = m;
  17. data_ = new real[m];
  18. }
  19. Vector::~Vector() {
  20. delete[] data_;
  21. }
  22. int64_t Vector::size() const {
  23. return m_;
  24. }
  25. void Vector::zero() {
  26. for (int64_t i = 0; i < m_; i++) {
  27. data_[i] = 0.0;
  28. }
  29. }
  30. real Vector::norm() {
  31. real sum = 0;
  32. for (int64_t i = 0; i < m_; i++) {
  33. sum += data_[i] * data_[i];
  34. }
  35. return std::sqrt(sum);
  36. }
  37. void Vector::mul(real a) {
  38. for (int64_t i = 0; i < m_; i++) {
  39. data_[i] *= a;
  40. }
  41. }
  42. void Vector::addVector(const Vector& source) {
  43. assert(m_ == source.m_);
  44. for (int64_t i = 0; i < m_; i++) {
  45. data_[i] += source.data_[i];
  46. }
  47. }
  48. void Vector::addRow(const Matrix& A, int64_t i) {
  49. assert(i >= 0);
  50. assert(i < A.m_);
  51. assert(m_ == A.n_);
  52. for (int64_t j = 0; j < A.n_; j++) {
  53. data_[j] += A.at(i, j);
  54. }
  55. }
  56. void Vector::addRow(const Matrix& A, int64_t i, real a) {
  57. assert(i >= 0);
  58. assert(i < A.m_);
  59. assert(m_ == A.n_);
  60. for (int64_t j = 0; j < A.n_; j++) {
  61. data_[j] += a * A.at(i, j);
  62. }
  63. }
  64. void Vector::addRow(const QMatrix& A, int64_t i) {
  65. assert(i >= 0);
  66. A.addToVector(*this, i);
  67. }
  68. void Vector::mul(const Matrix& A, const Vector& vec) {
  69. assert(A.m_ == m_);
  70. assert(A.n_ == vec.m_);
  71. for (int64_t i = 0; i < m_; i++) {
  72. data_[i] = A.dotRow(vec, i);
  73. }
  74. }
  75. void Vector::mul(const QMatrix& A, const Vector& vec) {
  76. assert(A.m_ == m_);
  77. assert(A.n_ == vec.m_);
  78. for (int64_t i = 0; i < m_; i++) {
  79. data_[i] = A.dotRow(vec, i);
  80. }
  81. }
  82. int64_t Vector::argmax() {
  83. real max = data_[0];
  84. int64_t argmax = 0;
  85. for (int64_t i = 1; i < m_; i++) {
  86. if (data_[i] > max) {
  87. max = data_[i];
  88. argmax = i;
  89. }
  90. }
  91. return argmax;
  92. }
  93. real& Vector::operator[](int64_t i) {
  94. return data_[i];
  95. }
  96. const real& Vector::operator[](int64_t i) const {
  97. return data_[i];
  98. }
  99. std::ostream& operator<<(std::ostream& os, const Vector& v)
  100. {
  101. os << std::setprecision(5);
  102. for (int64_t j = 0; j < v.m_; j++) {
  103. os << v.data_[j] << ' ';
  104. }
  105. return os;
  106. }
  107. }