fasttext.cc 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the BSD-style license found in the
  6. * LICENSE file in the root directory of this source tree. An additional grant
  7. * of patent rights can be found in the PATENTS file in the same directory.
  8. */
#include <fenv.h>
#include <math.h>
#include <time.h>

#include <atomic>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

#include "Args.h"
#include "Dictionary.h"
#include "Matrix.h"
#include "Model.h"
#include "Real.h"
#include "Utils.h"
#include "Vector.h"
// Global command-line configuration; written once by parseArgs in main and
// read (not mutated) by the training threads afterwards.
Args args;
// Shared training-progress counters, updated by all worker threads.
namespace info {
clock_t start;                     // CPU clock at the start of training
std::atomic<int64_t> allWords(0);  // tokens consumed so far across all threads
std::atomic<int64_t> allN(0);      // number of updates folded into allLoss
// NOTE(review): plain double incremented from every thread without
// synchronization (see thread_function) — a data race; confirm tolerated.
double allLoss(0.0);
}
  32. void saveVectors(Dictionary& dict, Matrix& input, Matrix& output) {
  33. int32_t N = dict.getNumWords();
  34. std::wofstream ofs(args.output + ".vec");
  35. if (ofs.is_open()) {
  36. ofs << N << ' ' << args.dim << std::endl;
  37. for (int32_t i = 0; i < N; i++) {
  38. ofs << dict.getWord(i) << ' ';
  39. Vector embedding(args.dim);
  40. embedding.zero();
  41. const std::vector<int32_t>& ngrams = dict.getNgrams(i);
  42. for (auto it = ngrams.begin(); it != ngrams.end(); ++it) {
  43. embedding.addRow(input, *it);
  44. }
  45. embedding.mul(1.0 / ngrams.size());
  46. embedding.writeToStream(ofs);
  47. ofs << std::endl;
  48. }
  49. ofs.close();
  50. } else {
  51. std::wcout << "Error opening file for writing" << std::endl;
  52. }
  53. }
  54. void saveModel(Dictionary& dict, Matrix& input, Matrix& output) {
  55. std::ofstream ofs(args.output + ".bin");
  56. args.save(ofs);
  57. dict.save(ofs);
  58. input.save(ofs);
  59. output.save(ofs);
  60. ofs.close();
  61. }
  62. void loadModel(Dictionary& dict, Matrix& input, Matrix& output) {
  63. std::ifstream ifs(args.output + ".bin");
  64. args.load(ifs);
  65. dict.load(ifs);
  66. input.load(ifs);
  67. output.load(ifs);
  68. ifs.close();
  69. }
  70. void printInfo(Model& model, long long numTokens) {
  71. real progress = real(info::allWords) / (args.epoch * numTokens);
  72. real avLoss = info::allLoss / info::allN;
  73. float time = float(clock() - info::start) / CLOCKS_PER_SEC;
  74. float wst = float(info::allWords) / time;
  75. int eta = int(time / progress * (1 - progress) / args.thread);
  76. int etah = eta / 3600;
  77. int etam = (eta - etah * 3600) / 60;
  78. std::wcout << std::fixed;
  79. std::wcout << "\rProgress: " << std::setprecision(1) << 100 * progress << "%";
  80. std::wcout << " words/sec/thread: " << std::setprecision(0) << wst;
  81. std::wcout << " lr: " << std::setprecision(6) << model.getLearningRate();
  82. std::wcout << " loss: " << std::setprecision(6) << avLoss;
  83. std::wcout << " eta: " << etah << "h" << etam << "m ";
  84. std::wcout << std::flush;
  85. }
  86. void supervised(Model& model,
  87. const std::vector<int32_t>& line,
  88. const std::vector<int32_t>& labels,
  89. double& loss, int32_t& N) {
  90. if (labels.size() == 0 || line.size() == 0) return;
  91. std::uniform_int_distribution<> uniform(0, labels.size() - 1);
  92. int32_t i = uniform(model.rng);
  93. model.update(line, labels[i], loss, N);
  94. }
  95. void cbow(Dictionary& dict, Model& model,
  96. const std::vector<int32_t>& line,
  97. double& loss, int32_t& N) {
  98. int32_t n = line.size();
  99. std::vector<int32_t> bow;
  100. std::uniform_int_distribution<> uniform(1, args.ws);
  101. for (int32_t w = 0; w < n; w++) {
  102. int32_t wb = uniform(model.rng);
  103. bow.clear();
  104. for (int32_t c = -wb; c <= wb; c++) {
  105. if (c != 0 && w + c >= 0 && w + c < n) {
  106. const std::vector<int32_t>& ngrams = dict.getNgrams(line[w + c]);
  107. for (auto it = ngrams.cbegin(); it != ngrams.cend(); ++it) {
  108. bow.push_back(*it);
  109. }
  110. }
  111. }
  112. model.update(bow, line[w], loss, N);
  113. }
  114. }
  115. void skipGram(Dictionary& dict, Model& model,
  116. const std::vector<int32_t>& line,
  117. double& loss, int32_t& N) {
  118. int32_t n = line.size();
  119. std::uniform_int_distribution<> uniform(1, args.ws);
  120. for (int32_t w = 0; w < n; w++) {
  121. int32_t wb = uniform(model.rng);
  122. const std::vector<int32_t>& ngrams = dict.getNgrams(line[w]);
  123. for (int32_t c = -wb; c <= wb; c++) {
  124. if (c != 0 && w + c >= 0 && w + c < n) {
  125. int32_t target = line[w + c];
  126. model.update(ngrams, target, loss, N);
  127. }
  128. }
  129. }
  130. }
  131. void test(Dictionary& dict, Model& model) {
  132. int32_t N = 0;
  133. double precision = 0.0;
  134. std::vector<int32_t> line, labels;
  135. std::wifstream ifs(args.test);
  136. while (!ifs.eof()) {
  137. dict.getLine(ifs, line, labels, model.rng);
  138. dict.addNgrams(line, args.wordNgrams);
  139. if (labels.size() > 0 && line.size() > 0) {
  140. int32_t i = model.predict(line);
  141. for (auto& t : labels) {
  142. if (i == t) {
  143. precision += 1.0;
  144. break;
  145. }
  146. }
  147. N++;
  148. }
  149. }
  150. ifs.close();
  151. std::wcout << std::setprecision(3) << "P@1: " << precision / N << std::endl;
  152. std::wcout << std::setprecision(3) << "Sentences: " << N << std::endl;
  153. }
// Per-thread training loop. Each worker opens its own stream on the input
// file and seeks to its own byte slice, so the threads read disjoint regions
// of the corpus.
void thread_function(Dictionary& dict, Matrix& input, Matrix& output,
                     int32_t threadId) {
  std::wifstream ifs(args.input);
  // Carve the file into args.thread equal byte ranges; this thread starts at
  // the beginning of its range.
  utils::seek(ifs, threadId * utils::size(ifs) / args.thread);
  Model model(input, output, args.dim, args.lr, threadId);
  if (args.model == model_name::sup) {
    model.setLabelFreq(dict.getLabelFreq());
  } else {
    model.setLabelFreq(dict.getWordFreq());
  }
  const int64_t ntokens = dict.getNumTokens();
  int64_t tokenCount = 0;      // tokens read by this thread
  int64_t prevTokenCount = 0;  // tokens already flushed to info::allWords
  double loss = 0.0;           // loss accumulated since the last flush
  int32_t N = 0;               // updates accumulated since the last flush
  std::vector<int32_t> line, labels;
  // Train until all threads together have consumed args.epoch passes over
  // the corpus (measured in tokens).
  while (info::allWords < args.epoch * ntokens) {
    tokenCount += dict.getLine(ifs, line, labels, model.rng);
    if (args.model == model_name::sup) {
      dict.addNgrams(line, args.wordNgrams);
      supervised(model, line, labels, loss, N);
    } else if (args.model == model_name::cbow) {
      cbow(dict, model, line, loss, N);
    } else if (args.model == model_name::sg) {
      skipGram(dict, model, line, loss, N);
    }
    // Flush local counters to the shared ones roughly every 10k tokens to
    // keep contention on the atomics low.
    if (tokenCount - prevTokenCount > 10000) {
      info::allWords += tokenCount - prevTokenCount;
      prevTokenCount = tokenCount;
      // NOTE(review): info::allLoss is a plain double updated here by every
      // thread with no synchronization — a data race; confirm it is an
      // accepted approximation or guard it.
      info::allLoss += loss;
      info::allN += N;
      loss = 0.0;
      N = 0;
      // Decay the learning rate linearly with global progress.
      real progress = real(info::allWords) / (args.epoch * ntokens);
      model.setLearningRate(args.lr * (1.0 - progress));
      if (threadId == 0) printInfo(model, ntokens);
    }
  }
  if (threadId == 0) {
    printInfo(model, ntokens);
    std::wcout << std::endl;
  }
  // Thread 0 evaluates the supervised model on the test set after training.
  if (args.model == model_name::sup && threadId == 0) {
    test(dict, model);
  }
  ifs.close();
}
  201. int main(int argc, char** argv) {
  202. std::locale::global(std::locale(""));
  203. args.parseArgs(argc, argv);
  204. utils::initTables();
  205. Dictionary dict;
  206. dict.readFromFile(args.input);
  207. Matrix input(dict.getNumWords() + args.bucket, args.dim);
  208. Matrix output;
  209. if (args.model == model_name::sup) {
  210. output = Matrix(dict.getNumLabels(), args.dim);
  211. } else {
  212. output = Matrix(dict.getNumWords(), args.dim);
  213. }
  214. input.uniform(1.0 / args.dim);
  215. output.zero();
  216. info::start = clock();
  217. std::vector<std::thread> threads;
  218. for (int32_t i = 0; i < args.thread; i++) {
  219. threads.push_back(std::thread(&thread_function, std::ref(dict),
  220. std::ref(input), std::ref(output), i));
  221. }
  222. for (auto it = threads.begin(); it != threads.end(); ++it) {
  223. it->join();
  224. }
  225. std::wcout << "training took: "
  226. << float(clock() - info::start) / CLOCKS_PER_SEC / args.thread
  227. << " s" << std::endl;
  228. if (args.output.size() != 0) {
  229. saveModel(dict, input, output);
  230. saveVectors(dict, input, output);
  231. }
  232. utils::freeTables();
  233. return 0;
  234. }