/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
// Sentinel token that readWord() stores in `word` for an end-of-line, so that
// callers can tell where one example stops and the next begins.
std::string EOS = "</s>";

/**
 * Reads the next token from `in` into `word`.
 *
 * Tokens are separated by space, '\n', '\r', '\t', '\v', '\f' or NUL.
 * Separators in front of a token are skipped, except that a '\n' met before
 * any token character is reported as the EOS token. A '\n' that terminates a
 * token is pushed back, so the following call reports EOS for that line end.
 *
 * @param in   stream to read from; its buffer is consumed directly.
 * @param word receives the token (always cleared first).
 * @return true if a token or EOS was stored in `word`; false if the input
 *         ended before any token character was read.
 */
bool readWord(std::istream& in, std::string& word)
{
  using Traits = std::char_traits<char>;

  std::streambuf& sb = *in.rdbuf();
  word.clear();

  // Keep the value returned by sbumpc() as int_type until it has been compared
  // with eof(). Narrowing it to char first makes byte 0xFF indistinguishable
  // from EOF where char is signed, and makes EOF undetectable (an endless
  // loop) where char is unsigned.
  for (Traits::int_type next = sb.sbumpc();
       !Traits::eq_int_type(next, Traits::eof());
       next = sb.sbumpc()) {
    const char c = Traits::to_char_type(next);
    const bool isSeparator = c == ' ' || c == '\n' || c == '\r' ||
                             c == '\t' || c == '\v' || c == '\f' || c == '\0';

    if (!isSeparator) {
      word.push_back(c);
      continue;
    }
    if (word.empty()) {
      if (c == '\n') {
        word += EOS;
        return true;
      }
      continue;  // separator before any token character: skip it
    }
    if (c == '\n') {
      sb.sungetc();  // leave the newline so the next call yields EOS
    }
    return true;
  }

  // The buffer was read directly, so the stream state was never updated;
  // this read at end of input sets eofbit on `in`.
  in.get();
  return !word.empty();
}
- int main(int argc, char** argv) {
- int k = 10;
- if (argc < 4) {
- std::cerr<<"eval <pred> <gt> <kb> [<k>]"<<std::endl;
- exit(1);
- }
- if (argc == 5) { k = atoi(argv[4]);}
- std::string predfn(argv[1]);
- std::ifstream predf(predfn);
- std::string gtfn(argv[2]);
- std::ifstream gtf(gtfn);
- std::string kbfn(argv[3]);
- std::ifstream kbf(kbfn);
- if (!predf.is_open() || !gtf.is_open() || !kbf.is_open()) {
- std::cerr << "Files cannot be opened!" << std::endl;
- exit(EXIT_FAILURE);
- }
- std::unordered_map< std::string,
- std::unordered_map< std::string, bool > > KB;
- while (kbf.peek() != EOF) {
- std::string label, key, word;
- while (readWord(kbf, word)) {
- if (word == EOS) {break;}
- if (word.find("__label__") == 0) {label = word;}
- else {key += "|" + word;}
- }
- KB[key][label] = true;
- }
- kbf.close();
- double precision = 0.0;
- int32_t nexamples = 0;
- while (predf.peek() != EOF || gtf.peek() != EOF) {
- if (predf.peek() == EOF || gtf.peek() == EOF) {
- std::cerr<<"pred / gt files have diff sizes"<<std::endl;
- exit(1);
- }
- std::string label, key, word;
- while (readWord(gtf, word)) {
- if (word == EOS) {break;}
- if ( word.find("__label__") == 0) {label = word;}
- else {key += "|" + word;}
- }
- if (KB.find(key) == KB.end()) {
- std::cerr<<"empty key!"<<std::endl; exit(1);
- }
- int count = 0;bool eval = true;
- while (readWord(predf, word)) {
- if (word == EOS) {break;}
- if (!eval) {continue;}
- if (label == word) {precision += 1.0; eval = false;}
- else if (KB[key].find(word) == KB[key].end()) {count++;}
- if (count == k) {eval = false;}
- }
- nexamples++;
- }
- predf.close(); gtf.close();
- std::cout << "N:\t" << nexamples << std::endl;
- std::cout << "R@" << k << "\t" << precision / nexamples << std::endl;
- }
|