1
0

eval.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. /**
  2. * Copyright (c) 2017-present, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under the MIT license found in the
  6. * LICENSE file in the root directory of this source tree.
  7. */
  #include <cstdlib>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <unordered_map>
  #include <vector>
  /// Sentinel token that readWord() yields for a line break.
  std::string EOS = "</s>";

  /**
   * Reads the next token from `in` into `word`.
   *
   * Tokens are separated by space, '\n', '\r', '\t', '\v', '\f' or '\0'.
   * A '\n' met while no token is being built is reported as the EOS token.
   * A '\n' that terminates a token is pushed back, so the following call
   * reports EOS for that same line break.
   *
   * @param in   stream to read from; its stream buffer is consumed directly.
   * @param word output; cleared first, then filled with the token (or EOS).
   * @return true if a token or EOS was produced, false when the input ended
   *         before any token character was read.
   */
  bool readWord(std::istream& in, std::string& word)
  {
    using Traits = std::istream::traits_type;

    std::streambuf& sb = *in.rdbuf();
    word.clear();

    // sbumpc() returns int_type, not char. Keeping the wider type is what
    // makes the end-of-input test reliable: a char would either confuse the
    // byte 0xFF with eof() (signed char) or never equal eof() (unsigned char).
    Traits::int_type c;
    while ((c = sb.sbumpc()) != Traits::eof()) {
      const bool isSeparator = c == ' ' || c == '\n' || c == '\r' ||
                               c == '\t' || c == '\v' || c == '\f' ||
                               c == '\0';
      if (!isSeparator) {
        word.push_back(Traits::to_char_type(c));
        continue;
      }

      if (word.empty()) {
        if (c == '\n') {
          word += EOS;
          return true;
        }
        // Any other leading separator is skipped.
        continue;
      }

      // A token just ended. Leave a terminating newline in the buffer so
      // the next call turns it into EOS.
      if (c == '\n') {
        sb.sungetc();
      }
      return true;
    }

    // The buffer was read directly, bypassing the stream's state flags; this
    // read at end of input lets the stream object itself register it.
    in.get();
    return !word.empty();
  }
  39. int main(int argc, char** argv) {
  40. int k = 10;
  41. if (argc < 4) {
  42. std::cerr<<"eval <pred> <gt> <kb> [<k>]"<<std::endl;
  43. exit(1);
  44. }
  45. if (argc == 5) { k = atoi(argv[4]);}
  46. std::string predfn(argv[1]);
  47. std::ifstream predf(predfn);
  48. std::string gtfn(argv[2]);
  49. std::ifstream gtf(gtfn);
  50. std::string kbfn(argv[3]);
  51. std::ifstream kbf(kbfn);
  52. if (!predf.is_open() || !gtf.is_open() || !kbf.is_open()) {
  53. std::cerr << "Files cannot be opened!" << std::endl;
  54. exit(EXIT_FAILURE);
  55. }
  56. std::unordered_map< std::string,
  57. std::unordered_map< std::string, bool > > KB;
  58. while (kbf.peek() != EOF) {
  59. std::string label, key, word;
  60. while (readWord(kbf, word)) {
  61. if (word == EOS) {break;}
  62. if (word.find("__label__") == 0) {label = word;}
  63. else {key += "|" + word;}
  64. }
  65. KB[key][label] = true;
  66. }
  67. kbf.close();
  68. double precision = 0.0;
  69. int32_t nexamples = 0;
  70. while (predf.peek() != EOF || gtf.peek() != EOF) {
  71. if (predf.peek() == EOF || gtf.peek() == EOF) {
  72. std::cerr<<"pred / gt files have diff sizes"<<std::endl;
  73. exit(1);
  74. }
  75. std::string label, key, word;
  76. while (readWord(gtf, word)) {
  77. if (word == EOS) {break;}
  78. if ( word.find("__label__") == 0) {label = word;}
  79. else {key += "|" + word;}
  80. }
  81. if (KB.find(key) == KB.end()) {
  82. std::cerr<<"empty key!"<<std::endl; exit(1);
  83. }
  84. int count = 0;bool eval = true;
  85. while (readWord(predf, word)) {
  86. if (word == EOS) {break;}
  87. if (!eval) {continue;}
  88. if (label == word) {precision += 1.0; eval = false;}
  89. else if (KB[key].find(word) == KB[key].end()) {count++;}
  90. if (count == k) {eval = false;}
  91. }
  92. nexamples++;
  93. }
  94. predf.close(); gtf.close();
  95. std::cout << "N:\t" << nexamples << std::endl;
  96. std::cout << "R@" << k << "\t" << precision / nexamples << std::endl;
  97. }