# eval.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2016-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the BSD-style license found in the
  8. # LICENSE file in the root directory of this source tree. An additional grant
  9. # of patent rights can be found in the PATENTS file in the same directory.
  10. #
  11. from __future__ import absolute_import
  12. from __future__ import division
  13. from __future__ import print_function
  14. from __future__ import unicode_literals
  15. import numpy as np
  16. import heapq
  17. from scipy import stats
  18. import sys
  19. import os
  20. import math
  21. import argparse
  22. parser = argparse.ArgumentParser(description='Process some integers.')
  23. parser.add_argument('--model', '-m', dest='modelPath', action='store', required=True, help='path to model')
  24. parser.add_argument('--data', '-d', dest='dataPath', action='store', required=True, help='path to data')
  25. args = parser.parse_args()
  26. try:
  27. f = open(args.modelPath, 'r')
  28. except IOError:
  29. sys.exit(0)
  30. embeds = {}
  31. for i, line in enumerate(f):
  32. try:
  33. tab = line.decode('utf8').split()
  34. vec = np.array(tab[1:], dtype=float)
  35. word = tab[0]
  36. #word = tab[0].replace('í', 'i').replace('á', 'a').replace('ó', 'o').replace('ñ', 'n').replace('é', 'e').replace('ú', 'u')
  37. if not word in embeds:
  38. embeds[word] = vec
  39. except ValueError:
  40. continue
  41. except UnicodeDecodeError:
  42. continue
  43. def levenshtein(s1, s2):
  44. if len(s1) < len(s2):
  45. return levenshtein(s2, s1)
  46. # len(s1) >= len(s2)
  47. if len(s2) == 0:
  48. return len(s1)
  49. previous_row = range(len(s2) + 1)
  50. for i, c1 in enumerate(s1):
  51. current_row = [i + 1]
  52. for j, c2 in enumerate(s2):
  53. insertions = previous_row[j + 1] + 1 # j+1 instead of j since previous_row and current_row are one character longer
  54. deletions = current_row[j] + 1 # than s2
  55. substitutions = previous_row[j] + (c1 != c2)
  56. current_row.append(min(insertions, deletions, substitutions))
  57. previous_row = current_row
  58. return previous_row[-1]
  59. def findNearest(query, embeds):
  60. me = 100
  61. for w,vec in embeds.iteritems():
  62. e = levenshtein(query, w)
  63. if e < me:
  64. me = e
  65. nw = w
  66. # print("{0:s} {1:s} {2:f}".format(query, w, e))
  67. return nw
  68. def similarity(v1, v2):
  69. n1 = np.linalg.norm(v1)
  70. n2 = np.linalg.norm(v2)
  71. dp = np.dot(v1, v2)
  72. d = dp / n1 / n2
  73. return d
  74. f = open(args.dataPath, 'r')
  75. doEdit = False
  76. mysim = []
  77. gold = []
  78. mysimDrop = []
  79. goldDrop = []
  80. drop = 0.0
  81. nwords = 0.0
  82. for line in f:
  83. zz = line.decode('utf8').split()
  84. z1 = zz[0].lower()
  85. z2 = zz[1].lower()
  86. score = float(zz[2])
  87. nwords = nwords + 1.0
  88. if (z1 in embeds) and (z2 in embeds):
  89. v1 = embeds[z1]
  90. v2 = embeds[z2]
  91. d = similarity(v1, v2)
  92. mysim.append(d)
  93. gold.append(float(zz[2]))
  94. elif (doEdit):
  95. if (z1 in embeds):
  96. w1 = z1
  97. else:
  98. w1 = findNearest(z1, embeds)
  99. if (z2 in embeds):
  100. w2 = z2
  101. else:
  102. w2 = findNearest(z2, embeds)
  103. v1 = embeds[w1]
  104. v2 = embeds[w2]
  105. d = similarity(v1, v2)
  106. mysimDrop.append(d)
  107. goldDrop.append(score)
  108. drop = drop + 1.0
  109. sys.stdout.write(str(drop) + " ")
  110. sys.stdout.flush()
  111. else:
  112. drop = drop + 1.0
  113. pr = stats.spearmanr(mysim, gold)
  114. dataset = os.path.basename(args.dataPath)
  115. print("{0:20s} & {2:2.0f}\% & {1:2.0f}".format(dataset, pr[0] * 100, math.ceil(drop / nwords * 100.0)))