eval.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2016-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the BSD-style license found in the
  8. # LICENSE file in the root directory of this source tree. An additional grant
  9. # of patent rights can be found in the PATENTS file in the same directory.
  10. #
  11. from __future__ import absolute_import
  12. from __future__ import division
  13. from __future__ import print_function
  14. from __future__ import unicode_literals
  15. import numpy as np
  16. from scipy import stats
  17. import os
  18. import math
  19. import argparse
  20. def compat_splitting(line):
  21. return line.decode('utf8').split()
  22. def similarity(v1, v2):
  23. n1 = np.linalg.norm(v1)
  24. n2 = np.linalg.norm(v2)
  25. return np.dot(v1, v2) / n1 / n2
  26. parser = argparse.ArgumentParser(description='Process some integers.')
  27. parser.add_argument(
  28. '--model',
  29. '-m',
  30. dest='modelPath',
  31. action='store',
  32. required=True,
  33. help='path to model'
  34. )
  35. parser.add_argument(
  36. '--data',
  37. '-d',
  38. dest='dataPath',
  39. action='store',
  40. required=True,
  41. help='path to data'
  42. )
  43. args = parser.parse_args()
  44. vectors = {}
  45. fin = open(args.modelPath, 'rb')
  46. for _, line in enumerate(fin):
  47. try:
  48. tab = compat_splitting(line)
  49. vec = np.array(tab[1:], dtype=float)
  50. word = tab[0]
  51. if np.linalg.norm(vec) == 0:
  52. continue
  53. if not word in vectors:
  54. vectors[word] = vec
  55. except ValueError:
  56. continue
  57. except UnicodeDecodeError:
  58. continue
  59. fin.close()
  60. mysim = []
  61. gold = []
  62. drop = 0.0
  63. nwords = 0.0
  64. fin = open(args.dataPath, 'rb')
  65. for line in fin:
  66. tline = compat_splitting(line)
  67. word1 = tline[0].lower()
  68. word2 = tline[1].lower()
  69. nwords = nwords + 1.0
  70. if (word1 in vectors) and (word2 in vectors):
  71. v1 = vectors[word1]
  72. v2 = vectors[word2]
  73. d = similarity(v1, v2)
  74. mysim.append(d)
  75. gold.append(float(tline[2]))
  76. else:
  77. drop = drop + 1.0
  78. fin.close()
  79. corr = stats.spearmanr(mysim, gold)
  80. dataset = os.path.basename(args.dataPath)
  81. print(
  82. "{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
  83. .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0))
  84. )