eval.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2016-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the BSD-style license found in the
  8. # LICENSE file in the root directory of this source tree. An additional grant
  9. # of patent rights can be found in the PATENTS file in the same directory.
  10. #
  11. from __future__ import absolute_import
  12. from __future__ import division
  13. from __future__ import print_function
  14. from __future__ import unicode_literals
  15. import numpy as np
  16. from scipy import stats
  17. import sys
  18. import os
  19. import math
  20. import argparse
  21. def compat_splitting(line):
  22. return line.decode('utf8').split()
  23. def similarity(v1, v2):
  24. n1 = np.linalg.norm(v1)
  25. n2 = np.linalg.norm(v2)
  26. return np.dot(v1, v2) / n1 / n2
  27. parser = argparse.ArgumentParser(description='Process some integers.')
  28. parser.add_argument('--model', '-m', dest='modelPath', action='store', required=True, help='path to model')
  29. parser.add_argument('--data', '-d', dest='dataPath', action='store', required=True, help='path to data')
  30. args = parser.parse_args()
  31. vectors = {}
  32. fin = open(args.modelPath, 'rb')
  33. for i, line in enumerate(fin):
  34. try:
  35. tab = compat_splitting(line)
  36. vec = np.array(tab[1:], dtype=float)
  37. word = tab[0]
  38. if np.linalg.norm(vec) == 0:
  39. continue
  40. if not word in vectors:
  41. vectors[word] = vec
  42. except ValueError:
  43. continue
  44. except UnicodeDecodeError:
  45. continue
  46. fin.close()
  47. mysim = []
  48. gold = []
  49. drop = 0.0
  50. nwords = 0.0
  51. fin = open(args.dataPath, 'rb')
  52. for line in fin:
  53. tline = compat_splitting(line)
  54. word1 = tline[0].lower()
  55. word2 = tline[1].lower()
  56. nwords = nwords + 1.0
  57. if (word1 in vectors) and (word2 in vectors):
  58. v1 = vectors[word1]
  59. v2 = vectors[word2]
  60. d = similarity(v1, v2)
  61. mysim.append(d)
  62. gold.append(float(tline[2]))
  63. else:
  64. drop = drop + 1.0
  65. fin.close()
  66. corr = stats.spearmanr(mysim, gold)
  67. dataset = os.path.basename(args.dataPath)
  68. print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
  69. .format(dataset, corr[0] * 100, math.ceil(drop / nwords * 100.0)))