1
0

eval.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2016-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the MIT license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. #
  10. from __future__ import absolute_import
  11. from __future__ import division
  12. from __future__ import print_function
  13. from __future__ import unicode_literals
  14. import numpy as np
  15. from scipy import stats
  16. import os
  17. import math
  18. import argparse
  def compat_splitting(line):
      """Decode one raw line as UTF-8 and split it into whitespace-separated tokens.

      ``line`` is expected to be a bytes object (e.g. a line read from a file
      opened in 'rb' mode). Invalid UTF-8 raises UnicodeDecodeError.
      """
      decoded = line.decode('utf8')
      return decoded.split()
  def similarity(v1, v2):
      """Return the cosine similarity of vectors ``v1`` and ``v2``.

      Computed as dot(v1, v2) / ||v1|| / ||v2||. Zero-norm inputs are not
      guarded against here.
      """
      dot_product = np.dot(v1, v2)
      norm_a, norm_b = np.linalg.norm(v1), np.linalg.norm(v2)
      return dot_product / norm_a / norm_b
  # Command-line interface: both the embedding file and the similarity
  # dataset are required.
  parser = argparse.ArgumentParser(
      description='Evaluate word vectors on a word-similarity dataset '
                  '(Spearman correlation and OOV rate).'
  )
  parser.add_argument(
      '--model',
      '-m',
      dest='modelPath',
      action='store',
      required=True,
      help='path to model (text file, one "word v1 v2 ..." entry per line)'
  )
  parser.add_argument(
      '--data',
      '-d',
      dest='dataPath',
      action='store',
      required=True,
      help='path to data (text file, one "word1 word2 score" entry per line)'
  )
  args = parser.parse_args()
  # Load the embedding file into a dict mapping word -> numpy float vector.
  # Each line is "word v1 v2 ..."; unparsable lines and zero vectors are
  # skipped, and the first occurrence of a word wins.
  # NOTE(review): if the model file begins with a "<count> <dim>" header line
  # (confirm for the files used), that line is stored as a 1-d vector keyed by
  # "<count>".
  vectors = {}
  with open(args.modelPath, 'rb') as fin:
      for line in fin:
          try:
              tab = compat_splitting(line)
              if not tab:
                  # Blank/whitespace-only line: there is no word to read, and
                  # tab[0] would raise an uncaught IndexError.
                  continue
              word = tab[0]
              vec = np.array(tab[1:], dtype=float)
          except ValueError:
              # Covers both non-numeric vector components and undecodable
              # bytes: UnicodeDecodeError is a subclass of ValueError, so a
              # separate handler for it would never be reached.
              continue
          # A zero vector has no direction; cosine similarity would divide by 0.
          if np.linalg.norm(vec) == 0:
              continue
          if word not in vectors:
              vectors[word] = vec
  # Score every "word1 word2 score" pair in the dataset. Pairs where either
  # (lowercased) word is missing from `vectors` count as OOV drops.
  mysim = []   # model cosine similarities for in-vocabulary pairs
  gold = []    # matching human-annotated scores
  drop = 0.0   # number of OOV pairs
  nwords = 0.0  # number of well-formed pairs seen
  with open(args.dataPath, 'rb') as fin:
      for line in fin:
          tline = compat_splitting(line)
          if len(tline) < 3:
              # Blank or malformed line without the three "word1 word2 score"
              # fields; indexing it would raise IndexError. Not counted.
              continue
          word1 = tline[0].lower()
          word2 = tline[1].lower()
          nwords = nwords + 1.0
          if (word1 in vectors) and (word2 in vectors):
              mysim.append(similarity(vectors[word1], vectors[word2]))
              gold.append(float(tline[2]))
          else:
              drop = drop + 1.0

  if mysim:
      corr = stats.spearmanr(mysim, gold)
  else:
      # No in-vocabulary pairs: the correlation is undefined.
      corr = (float('nan'), float('nan'))
  dataset = os.path.basename(args.dataPath)
  # An empty dataset would otherwise divide by zero; report 0% OOV instead.
  oov = math.ceil(drop / nwords * 100.0) if nwords else 0.0
  print(
      "{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)"
      .format(dataset, corr[0] * 100, oov)
  )