# eval.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2018-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. import io
  10. import numpy as np
  11. import argparse
  12. from utils import *
  13. parser = argparse.ArgumentParser(description='Evaluation of word alignment')
  14. parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
  15. parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
  16. parser.add_argument('--center', action='store_true', help='whether to center embeddings or not')
  17. parser.add_argument("--src_mat", type=str, default='', help="Load source alignment matrix. If none given, the aligment matrix is the identity.")
  18. parser.add_argument("--tgt_mat", type=str, default='', help="Load target alignment matrix. If none given, the aligment matrix is the identity.")
  19. parser.add_argument("--dico_test", type=str, default='', help="test dictionary")
  20. parser.add_argument("--maxload", type=int, default=200000)
  21. parser.add_argument("--nomatch", action='store_true', help="no exact match in lexicon")
  22. params = parser.parse_args()
  23. ###### SPECIFIC FUNCTIONS ######
  24. # function specific to evaluation
  25. # the rest of the functions are in utils.py
  26. def load_transform(fname, d1=300, d2=300):
  27. fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
  28. R = np.zeros([d1, d2])
  29. for i, line in enumerate(fin):
  30. tokens = line.split(' ')
  31. R[i, :] = np.array(tokens[0:d2], dtype=float)
  32. return R
  33. ###### MAIN ######
  34. print("Evaluation of alignment on %s" % params.dico_test)
  35. if params.nomatch:
  36. print("running without exact string matches")
  37. words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
  38. words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)
  39. if params.tgt_mat != "":
  40. R_tgt = load_transform(params.tgt_mat)
  41. x_tgt = np.dot(x_tgt, R_tgt)
  42. if params.src_mat != "":
  43. R_src = load_transform(params.src_mat)
  44. x_src = np.dot(x_src, R_src)
  45. src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)
  46. nnacc = compute_nn_accuracy(x_src, x_tgt, src2tgt, lexicon_size=lexicon_size)
  47. cslsproc = compute_csls_accuracy(x_src, x_tgt, src2tgt, lexicon_size=lexicon_size)
  48. print("NN = %.4f - CSLS = %.4f - Coverage = %.4f" % (nnacc, cslsproc, len(src2tgt) / lexicon_size))