train_unsupervised.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the MIT license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. from __future__ import division, absolute_import, print_function
  12. from fastText import train_unsupervised
  13. import numpy as np
  14. import os
  15. from scipy import stats
# fastText builds vectors for unseen words from character n-grams, so
# out-of-vocabulary words need no special handling below.
  17. def compute_similarity(data_path):
  18. def similarity(v1, v2):
  19. n1 = np.linalg.norm(v1)
  20. n2 = np.linalg.norm(v2)
  21. return np.dot(v1, v2) / n1 / n2
  22. mysim = []
  23. gold = []
  24. with open(data_path, 'rb') as fin:
  25. for line in fin:
  26. tline = line.split()
  27. word1 = tline[0].lower()
  28. word2 = tline[1].lower()
  29. v1 = model.get_word_vector(word1)
  30. v2 = model.get_word_vector(word2)
  31. d = similarity(v1, v2)
  32. mysim.append(d)
  33. gold.append(float(tline[2]))
  34. corr = stats.spearmanr(mysim, gold)
  35. dataset = os.path.basename(data_path)
  36. correlation = corr[0] * 100
  37. return dataset, correlation, 0
  38. if __name__ == "__main__":
  39. model = train_unsupervised(
  40. input=os.path.join(os.getenv("DATADIR", ''), 'fil9'),
  41. model='skipgram',
  42. )
  43. model.save_model("fil9.bin")
  44. dataset, corr, oov = compute_similarity('rw.txt')
  45. print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)".format(dataset, corr, 0))