// args.cc
  1. /**
  2. * Copyright (c) 2016-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  8. #include "args.h"
  9. #include <stdlib.h>
  10. #include <iostream>
  11. #include <stdexcept>
  12. #include <string>
  13. #include <unordered_map>
  14. namespace fasttext {
  15. Args::Args() {
  16. lr = 0.05;
  17. dim = 100;
  18. ws = 5;
  19. epoch = 5;
  20. minCount = 5;
  21. minCountLabel = 0;
  22. neg = 5;
  23. wordNgrams = 1;
  24. loss = loss_name::ns;
  25. model = model_name::sg;
  26. bucket = 2000000;
  27. minn = 3;
  28. maxn = 6;
  29. thread = 12;
  30. lrUpdateRate = 100;
  31. t = 1e-4;
  32. label = "__label__";
  33. verbose = 2;
  34. pretrainedVectors = "";
  35. saveOutput = false;
  36. seed = 0;
  37. qout = false;
  38. retrain = false;
  39. qnorm = false;
  40. cutoff = 0;
  41. dsub = 2;
  42. autotuneValidationFile = "";
  43. autotuneMetric = "f1";
  44. autotunePredictions = 1;
  45. autotuneDuration = 60 * 5; // 5 minutes
  46. autotuneModelSize = "";
  47. }
  48. std::string Args::lossToString(loss_name ln) const {
  49. switch (ln) {
  50. case loss_name::hs:
  51. return "hs";
  52. case loss_name::ns:
  53. return "ns";
  54. case loss_name::softmax:
  55. return "softmax";
  56. case loss_name::ova:
  57. return "one-vs-all";
  58. }
  59. return "Unknown loss!"; // should never happen
  60. }
  61. std::string Args::boolToString(bool b) const {
  62. if (b) {
  63. return "true";
  64. } else {
  65. return "false";
  66. }
  67. }
  68. std::string Args::modelToString(model_name mn) const {
  69. switch (mn) {
  70. case model_name::cbow:
  71. return "cbow";
  72. case model_name::sg:
  73. return "sg";
  74. case model_name::sup:
  75. return "sup";
  76. }
  77. return "Unknown model name!"; // should never happen
  78. }
  79. std::string Args::metricToString(metric_name mn) const {
  80. switch (mn) {
  81. case metric_name::f1score:
  82. return "f1score";
  83. case metric_name::f1scoreLabel:
  84. return "f1scoreLabel";
  85. case metric_name::precisionAtRecall:
  86. return "precisionAtRecall";
  87. case metric_name::precisionAtRecallLabel:
  88. return "precisionAtRecallLabel";
  89. case metric_name::recallAtPrecision:
  90. return "recallAtPrecision";
  91. case metric_name::recallAtPrecisionLabel:
  92. return "recallAtPrecisionLabel";
  93. }
  94. return "Unknown metric name!"; // should never happen
  95. }
  96. void Args::parseArgs(const std::vector<std::string>& args) {
  97. std::string command(args[1]);
  98. if (command == "supervised") {
  99. model = model_name::sup;
  100. loss = loss_name::softmax;
  101. minCount = 1;
  102. minn = 0;
  103. maxn = 0;
  104. lr = 0.1;
  105. } else if (command == "cbow") {
  106. model = model_name::cbow;
  107. }
  108. for (int ai = 2; ai < args.size(); ai += 2) {
  109. if (args[ai][0] != '-') {
  110. std::cerr << "Provided argument without a dash! Usage:" << std::endl;
  111. printHelp();
  112. exit(EXIT_FAILURE);
  113. }
  114. try {
  115. setManual(args[ai].substr(1));
  116. if (args[ai] == "-h") {
  117. std::cerr << "Here is the help! Usage:" << std::endl;
  118. printHelp();
  119. exit(EXIT_FAILURE);
  120. } else if (args[ai] == "-input") {
  121. input = std::string(args.at(ai + 1));
  122. } else if (args[ai] == "-output") {
  123. output = std::string(args.at(ai + 1));
  124. } else if (args[ai] == "-lr") {
  125. lr = std::stof(args.at(ai + 1));
  126. } else if (args[ai] == "-lrUpdateRate") {
  127. lrUpdateRate = std::stoi(args.at(ai + 1));
  128. } else if (args[ai] == "-dim") {
  129. dim = std::stoi(args.at(ai + 1));
  130. } else if (args[ai] == "-ws") {
  131. ws = std::stoi(args.at(ai + 1));
  132. } else if (args[ai] == "-epoch") {
  133. epoch = std::stoi(args.at(ai + 1));
  134. } else if (args[ai] == "-minCount") {
  135. minCount = std::stoi(args.at(ai + 1));
  136. } else if (args[ai] == "-minCountLabel") {
  137. minCountLabel = std::stoi(args.at(ai + 1));
  138. } else if (args[ai] == "-neg") {
  139. neg = std::stoi(args.at(ai + 1));
  140. } else if (args[ai] == "-wordNgrams") {
  141. wordNgrams = std::stoi(args.at(ai + 1));
  142. } else if (args[ai] == "-loss") {
  143. if (args.at(ai + 1) == "hs") {
  144. loss = loss_name::hs;
  145. } else if (args.at(ai + 1) == "ns") {
  146. loss = loss_name::ns;
  147. } else if (args.at(ai + 1) == "softmax") {
  148. loss = loss_name::softmax;
  149. } else if (
  150. args.at(ai + 1) == "one-vs-all" || args.at(ai + 1) == "ova") {
  151. loss = loss_name::ova;
  152. } else {
  153. std::cerr << "Unknown loss: " << args.at(ai + 1) << std::endl;
  154. printHelp();
  155. exit(EXIT_FAILURE);
  156. }
  157. } else if (args[ai] == "-bucket") {
  158. bucket = std::stoi(args.at(ai + 1));
  159. } else if (args[ai] == "-minn") {
  160. minn = std::stoi(args.at(ai + 1));
  161. } else if (args[ai] == "-maxn") {
  162. maxn = std::stoi(args.at(ai + 1));
  163. } else if (args[ai] == "-thread") {
  164. thread = std::stoi(args.at(ai + 1));
  165. } else if (args[ai] == "-t") {
  166. t = std::stof(args.at(ai + 1));
  167. } else if (args[ai] == "-label") {
  168. label = std::string(args.at(ai + 1));
  169. } else if (args[ai] == "-verbose") {
  170. verbose = std::stoi(args.at(ai + 1));
  171. } else if (args[ai] == "-pretrainedVectors") {
  172. pretrainedVectors = std::string(args.at(ai + 1));
  173. } else if (args[ai] == "-saveOutput") {
  174. saveOutput = true;
  175. ai--;
  176. } else if (args[ai] == "-seed") {
  177. seed = std::stoi(args.at(ai + 1));
  178. } else if (args[ai] == "-qnorm") {
  179. qnorm = true;
  180. ai--;
  181. } else if (args[ai] == "-retrain") {
  182. retrain = true;
  183. ai--;
  184. } else if (args[ai] == "-qout") {
  185. qout = true;
  186. ai--;
  187. } else if (args[ai] == "-cutoff") {
  188. cutoff = std::stoi(args.at(ai + 1));
  189. } else if (args[ai] == "-dsub") {
  190. dsub = std::stoi(args.at(ai + 1));
  191. } else if (args[ai] == "-autotune-validation") {
  192. autotuneValidationFile = std::string(args.at(ai + 1));
  193. } else if (args[ai] == "-autotune-metric") {
  194. autotuneMetric = std::string(args.at(ai + 1));
  195. getAutotuneMetric(); // throws exception if not able to parse
  196. getAutotuneMetricLabel(); // throws exception if not able to parse
  197. } else if (args[ai] == "-autotune-predictions") {
  198. autotunePredictions = std::stoi(args.at(ai + 1));
  199. } else if (args[ai] == "-autotune-duration") {
  200. autotuneDuration = std::stoi(args.at(ai + 1));
  201. } else if (args[ai] == "-autotune-modelsize") {
  202. autotuneModelSize = std::string(args.at(ai + 1));
  203. } else {
  204. std::cerr << "Unknown argument: " << args[ai] << std::endl;
  205. printHelp();
  206. exit(EXIT_FAILURE);
  207. }
  208. } catch (std::out_of_range) {
  209. std::cerr << args[ai] << " is missing an argument" << std::endl;
  210. printHelp();
  211. exit(EXIT_FAILURE);
  212. }
  213. }
  214. if (input.empty() || output.empty()) {
  215. std::cerr << "Empty input or output path." << std::endl;
  216. printHelp();
  217. exit(EXIT_FAILURE);
  218. }
  219. if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
  220. bucket = 0;
  221. }
  222. }
// Print the full usage message to stderr: each option group in a fixed
// order (call order determines output order, so do not reorder).
void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
  printAutotuneHelp();
  printQuantizationHelp();
}
// Print mandatory arguments and general options to stderr, showing the
// current (default) value of -verbose.
void Args::printBasicHelp() {
  std::cerr << "\nThe following arguments are mandatory:\n"
            << " -input training file path\n"
            << " -output output file path\n"
            << "\nThe following arguments are optional:\n"
            << " -verbose verbosity level [" << verbose << "]\n";
}
// Print dictionary-related options to stderr, with current (default)
// values in brackets.
void Args::printDictionaryHelp() {
  std::cerr << "\nThe following arguments for the dictionary are optional:\n"
            << " -minCount minimal number of word occurences ["
            << minCount << "]\n"
            << " -minCountLabel minimal number of label occurences ["
            << minCountLabel << "]\n"
            << " -wordNgrams max length of word ngram [" << wordNgrams
            << "]\n"
            << " -bucket number of buckets [" << bucket << "]\n"
            << " -minn min length of char ngram [" << minn
            << "]\n"
            << " -maxn max length of char ngram [" << maxn
            << "]\n"
            << " -t sampling threshold [" << t << "]\n"
            << " -label labels prefix [" << label << "]\n";
}
// Print training options to stderr, with current (default) values in
// brackets.
void Args::printTrainingHelp() {
  std::cerr
      << "\nThe following arguments for training are optional:\n"
      << " -lr learning rate [" << lr << "]\n"
      << " -lrUpdateRate change the rate of updates for the learning "
         "rate ["
      << lrUpdateRate << "]\n"
      << " -dim size of word vectors [" << dim << "]\n"
      << " -ws size of the context window [" << ws << "]\n"
      << " -epoch number of epochs [" << epoch << "]\n"
      << " -neg number of negatives sampled [" << neg << "]\n"
      << " -loss loss function {ns, hs, softmax, one-vs-all} ["
      << lossToString(loss) << "]\n"
      << " -thread number of threads (set to 1 to ensure "
         "reproducible results) ["
      << thread << "]\n"
      << " -pretrainedVectors pretrained word vectors for supervised "
         "learning ["
      << pretrainedVectors << "]\n"
      << " -saveOutput whether output params should be saved ["
      << boolToString(saveOutput) << "]\n"
      << " -seed random generator seed [" << seed << "]\n";
}
// Print autotune options to stderr, with current (default) values in
// brackets.
void Args::printAutotuneHelp() {
  std::cerr << "\nThe following arguments are for autotune:\n"
            << " -autotune-validation validation file to be used "
               "for evaluation\n"
            << " -autotune-metric metric objective {f1, "
               "f1:labelname} ["
            << autotuneMetric << "]\n"
            << " -autotune-predictions number of predictions used "
               "for evaluation ["
            << autotunePredictions << "]\n"
            << " -autotune-duration maximum duration in seconds ["
            << autotuneDuration << "]\n"
            << " -autotune-modelsize constraint model file size ["
            << autotuneModelSize << "] (empty = do not quantize)\n";
}
// Print quantization options to stderr, with current (default) values in
// brackets.
void Args::printQuantizationHelp() {
  std::cerr
      << "\nThe following arguments for quantization are optional:\n"
      << " -cutoff number of words and ngrams to retain ["
      << cutoff << "]\n"
      << " -retrain whether embeddings are finetuned if a cutoff "
         "is applied ["
      << boolToString(retrain) << "]\n"
      << " -qnorm whether the norm is quantized separately ["
      << boolToString(qnorm) << "]\n"
      << " -qout whether the classifier is quantized ["
      << boolToString(qout) << "]\n"
      << " -dsub size of each sub-vector [" << dsub << "]\n";
}
  305. void Args::save(std::ostream& out) {
  306. out.write((char*)&(dim), sizeof(int));
  307. out.write((char*)&(ws), sizeof(int));
  308. out.write((char*)&(epoch), sizeof(int));
  309. out.write((char*)&(minCount), sizeof(int));
  310. out.write((char*)&(neg), sizeof(int));
  311. out.write((char*)&(wordNgrams), sizeof(int));
  312. out.write((char*)&(loss), sizeof(loss_name));
  313. out.write((char*)&(model), sizeof(model_name));
  314. out.write((char*)&(bucket), sizeof(int));
  315. out.write((char*)&(minn), sizeof(int));
  316. out.write((char*)&(maxn), sizeof(int));
  317. out.write((char*)&(lrUpdateRate), sizeof(int));
  318. out.write((char*)&(t), sizeof(double));
  319. }
  320. void Args::load(std::istream& in) {
  321. in.read((char*)&(dim), sizeof(int));
  322. in.read((char*)&(ws), sizeof(int));
  323. in.read((char*)&(epoch), sizeof(int));
  324. in.read((char*)&(minCount), sizeof(int));
  325. in.read((char*)&(neg), sizeof(int));
  326. in.read((char*)&(wordNgrams), sizeof(int));
  327. in.read((char*)&(loss), sizeof(loss_name));
  328. in.read((char*)&(model), sizeof(model_name));
  329. in.read((char*)&(bucket), sizeof(int));
  330. in.read((char*)&(minn), sizeof(int));
  331. in.read((char*)&(maxn), sizeof(int));
  332. in.read((char*)&(lrUpdateRate), sizeof(int));
  333. in.read((char*)&(t), sizeof(double));
  334. }
  335. void Args::dump(std::ostream& out) const {
  336. out << "dim"
  337. << " " << dim << std::endl;
  338. out << "ws"
  339. << " " << ws << std::endl;
  340. out << "epoch"
  341. << " " << epoch << std::endl;
  342. out << "minCount"
  343. << " " << minCount << std::endl;
  344. out << "neg"
  345. << " " << neg << std::endl;
  346. out << "wordNgrams"
  347. << " " << wordNgrams << std::endl;
  348. out << "loss"
  349. << " " << lossToString(loss) << std::endl;
  350. out << "model"
  351. << " " << modelToString(model) << std::endl;
  352. out << "bucket"
  353. << " " << bucket << std::endl;
  354. out << "minn"
  355. << " " << minn << std::endl;
  356. out << "maxn"
  357. << " " << maxn << std::endl;
  358. out << "lrUpdateRate"
  359. << " " << lrUpdateRate << std::endl;
  360. out << "t"
  361. << " " << t << std::endl;
  362. }
  363. bool Args::hasAutotune() const {
  364. return !autotuneValidationFile.empty();
  365. }
  366. bool Args::isManual(const std::string& argName) const {
  367. return (manualArgs_.count(argName) != 0);
  368. }
  369. void Args::setManual(const std::string& argName) {
  370. manualArgs_.emplace(argName);
  371. }
  372. metric_name Args::getAutotuneMetric() const {
  373. if (autotuneMetric.substr(0, 3) == "f1:") {
  374. return metric_name::f1scoreLabel;
  375. } else if (autotuneMetric == "f1") {
  376. return metric_name::f1score;
  377. } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
  378. size_t semicolon = autotuneMetric.find(":", 18);
  379. if (semicolon != std::string::npos) {
  380. return metric_name::precisionAtRecallLabel;
  381. }
  382. return metric_name::precisionAtRecall;
  383. } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
  384. size_t semicolon = autotuneMetric.find(":", 18);
  385. if (semicolon != std::string::npos) {
  386. return metric_name::recallAtPrecisionLabel;
  387. }
  388. return metric_name::recallAtPrecision;
  389. }
  390. throw std::runtime_error("Unknown metric : " + autotuneMetric);
  391. }
  392. std::string Args::getAutotuneMetricLabel() const {
  393. metric_name metric = getAutotuneMetric();
  394. std::string label;
  395. if (metric == metric_name::f1scoreLabel) {
  396. label = autotuneMetric.substr(3);
  397. } else if (
  398. metric == metric_name::precisionAtRecallLabel ||
  399. metric == metric_name::recallAtPrecisionLabel) {
  400. size_t semicolon = autotuneMetric.find(":", 18);
  401. label = autotuneMetric.substr(semicolon + 1);
  402. } else {
  403. return label;
  404. }
  405. if (label.empty()) {
  406. throw std::runtime_error("Empty metric label : " + autotuneMetric);
  407. }
  408. return label;
  409. }
  410. double Args::getAutotuneMetricValue() const {
  411. metric_name metric = getAutotuneMetric();
  412. double value = 0.0;
  413. if (metric == metric_name::precisionAtRecallLabel ||
  414. metric == metric_name::precisionAtRecall ||
  415. metric == metric_name::recallAtPrecisionLabel ||
  416. metric == metric_name::recallAtPrecision) {
  417. size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
  418. size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
  419. const std::string valueStr =
  420. autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
  421. value = std::stof(valueStr) / 100.0;
  422. }
  423. return value;
  424. }
  425. int64_t Args::getAutotuneModelSize() const {
  426. std::string modelSize = autotuneModelSize;
  427. if (modelSize.empty()) {
  428. return Args::kUnlimitedModelSize;
  429. }
  430. std::unordered_map<char, int> units = {
  431. {'k', 1000},
  432. {'K', 1000},
  433. {'m', 1000000},
  434. {'M', 1000000},
  435. {'g', 1000000000},
  436. {'G', 1000000000},
  437. };
  438. uint64_t multiplier = 1;
  439. char lastCharacter = modelSize.back();
  440. if (units.count(lastCharacter)) {
  441. multiplier = units[lastCharacter];
  442. modelSize = modelSize.substr(0, modelSize.size() - 1);
  443. }
  444. uint64_t size = 0;
  445. size_t nonNumericCharacter = 0;
  446. bool parseError = false;
  447. try {
  448. size = std::stol(modelSize, &nonNumericCharacter);
  449. } catch (std::invalid_argument&) {
  450. parseError = true;
  451. }
  452. if (!parseError && nonNumericCharacter != modelSize.size()) {
  453. parseError = true;
  454. }
  455. if (parseError) {
  456. throw std::invalid_argument(
  457. "Unable to parse model size " + autotuneModelSize);
  458. }
  459. return size * multiplier;
  460. }
  461. } // namespace fasttext