train_unsupervised.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #!/usr/bin/env python
  2. # Copyright (c) 2017-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree. An additional grant
  7. # of patent rights can be found in the PATENTS file in the same directory.
  8. from __future__ import absolute_import
  9. from __future__ import division
  10. from __future__ import print_function
  11. from __future__ import unicode_literals
  12. from __future__ import division, absolute_import, print_function
  13. from fastText import train_unsupervised
  14. import numpy as np
  15. import os
  16. from scipy import stats
  17. # Because of fasttext we don't need to account for OOV
  18. def compute_similarity(data_path):
  19. def similarity(v1, v2):
  20. n1 = np.linalg.norm(v1)
  21. n2 = np.linalg.norm(v2)
  22. return np.dot(v1, v2) / n1 / n2
  23. mysim = []
  24. gold = []
  25. with open(data_path, 'rb') as fin:
  26. for line in fin:
  27. tline = line.split()
  28. word1 = tline[0].lower()
  29. word2 = tline[1].lower()
  30. v1 = model.get_word_vector(word1)
  31. v2 = model.get_word_vector(word2)
  32. d = similarity(v1, v2)
  33. mysim.append(d)
  34. gold.append(float(tline[2]))
  35. corr = stats.spearmanr(mysim, gold)
  36. dataset = os.path.basename(data_path)
  37. correlation = corr[0] * 100
  38. return dataset, correlation, 0
  39. if __name__ == "__main__":
  40. model = train_unsupervised(
  41. input=os.path.join(os.getenv("DATADIR", ''), 'fil9'),
  42. model='skipgram',
  43. )
  44. model.save_model("fil9.bin")
  45. dataset, corr, oov = compute_similarity('rw.txt')
  46. print("{0:20s}: {1:2.0f} (OOV: {2:2.0f}%)".format(dataset, corr, 0))