compute_accuracy.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #!/usr/bin/env python
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the MIT license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. from __future__ import division, absolute_import, print_function
  12. from fastText import load_model
  13. from fastText import util
  14. import argparse
  15. import numpy as np
  16. def process_question(question, cossims, model, words, vectors):
  17. correct = 0
  18. num_qs = 0
  19. num_lines = 0
  20. for line in question:
  21. num_lines += 1
  22. qwords = line.split()
  23. # We lowercase all words to correspond to the preprocessing
  24. # we applied to our data.
  25. qwords = [x.lower().strip() for x in qwords]
  26. # If one of the words is not in the vocabulary we skip this question
  27. found = True
  28. for w in qwords:
  29. if w not in words:
  30. found = False
  31. break
  32. if not found:
  33. continue
  34. # The first three words form the query
  35. # We retrieve their word vectors and normalize them
  36. query = qwords[:3]
  37. query = [model.get_word_vector(x) for x in query]
  38. query = [x / np.linalg.norm(x) for x in query]
  39. # Get the query vector. Example:
  40. # Germany - Berlin + France
  41. query = query[1] - query[0] + query[2]
  42. # We don't need to rank all the words, only until we found
  43. # the first word not equal to our set of query words.
  44. ban_set = list(map(lambda x: words.index(x), qwords[:3]))
  45. if words[util.find_nearest_neighbor(
  46. query, vectors, ban_set, cossims=cossims
  47. )] == qwords[3]:
  48. correct += 1
  49. num_qs += 1
  50. return correct, num_qs, num_lines
  51. # We use the same conventions as within compute-accuracy
  52. def print_compute_accuracy_score(
  53. question, correct, num_qs, total_accuracy, semantic_accuracy,
  54. syntactic_accuracy
  55. ):
  56. print(
  57. (
  58. "{0:>30}: ACCURACY TOP1: {3:.2f} % ({1} / {2})\t Total accuracy: {4:.2f} % Semantic accuracy: {5:.2f} % Syntactic accuracy: {6:.2f} %"
  59. ).format(
  60. question,
  61. correct,
  62. num_qs,
  63. correct / float(num_qs) * 100 if num_qs > 0 else 0,
  64. total_accuracy * 100,
  65. semantic_accuracy * 100,
  66. syntactic_accuracy * 100,
  67. )
  68. )
  69. if __name__ == "__main__":
  70. parser = argparse.ArgumentParser(
  71. description=(
  72. "compute_accuracy equivalent in Python. "
  73. "See https://github.com/tmikolov/word2vec/blob/master/demo-word-accuracy.sh"
  74. )
  75. )
  76. parser.add_argument(
  77. "model",
  78. help="Model to use",
  79. )
  80. parser.add_argument(
  81. "question_words",
  82. help="word questions similar to tmikolov's file (see help for link)",
  83. )
  84. parser.add_argument(
  85. "threshold",
  86. help="threshold used to limit number of words used",
  87. )
  88. args = parser.parse_args()
  89. args.threshold = int(args.threshold)
  90. # Retrieve list of normalized word vectors for the first words up
  91. # until the threshold count.
  92. f = load_model(args.model)
  93. # Gets words with associated frequeny sorted by default by descending order
  94. words, freq = f.get_words(include_freq=True)
  95. words = words[:args.threshold]
  96. vectors = np.zeros((len(words), f.get_dimension()), dtype=float)
  97. for i in range(len(words)):
  98. wv = f.get_word_vector(words[i])
  99. wv = wv / np.linalg.norm(wv)
  100. vectors[i] = wv
  101. total_correct = 0
  102. total_qs = 0
  103. total_num_lines = 0
  104. total_se_correct = 0
  105. total_se_qs = 0
  106. total_sy_correct = 0
  107. total_sy_qs = 0
  108. qid = 0
  109. questions = []
  110. with open(args.question_words, 'r') as fqw:
  111. questions = fqw.read().split(':')[1:]
  112. # For efficiency preallocate the memory to calculate cosine similarities
  113. cossims = np.zeros(len(words), dtype=float)
  114. for question in questions:
  115. quads = question.split('\n')
  116. question = quads[0].strip()
  117. quads = quads[1:-1]
  118. correct, num_qs, num_lines = process_question(
  119. quads, cossims, f, words, vectors
  120. )
  121. total_qs += num_qs
  122. total_correct += correct
  123. total_num_lines += num_lines
  124. if (qid < 5):
  125. total_se_correct += correct
  126. total_se_qs += num_qs
  127. else:
  128. total_sy_correct += correct
  129. total_sy_qs += num_qs
  130. print_compute_accuracy_score(
  131. question,
  132. correct,
  133. num_qs,
  134. total_correct / float(total_qs) if total_qs > 0 else 0,
  135. total_se_correct / float(total_se_qs) if total_se_qs > 0 else 0,
  136. total_sy_correct / float(total_sy_qs) if total_sy_qs > 0 else 0,
  137. )
  138. qid += 1
  139. print(
  140. "Questions seen / total: {0} {1} {2:.2f} %".
  141. format(
  142. total_qs,
  143. total_num_lines,
  144. total_qs / total_num_lines * 100 if total_num_lines > 0 else 0,
  145. )
  146. )