# utils.py
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2018-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import io
  8. import numpy as np
  9. import collections
  10. def load_vectors(fname, maxload=200000, norm=True, center=False, verbose=True):
  11. if verbose:
  12. print("Loading vectors from %s" % fname)
  13. fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
  14. n, d = map(int, fin.readline().split())
  15. if maxload > 0:
  16. n = min(n, maxload)
  17. x = np.zeros([n, d])
  18. words = []
  19. for i, line in enumerate(fin):
  20. if i >= n:
  21. break
  22. tokens = line.rstrip().split(' ')
  23. words.append(tokens[0])
  24. v = np.array(tokens[1:], dtype=float)
  25. x[i, :] = v
  26. if norm:
  27. x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
  28. if center:
  29. x -= x.mean(axis=0)[np.newaxis, :]
  30. x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
  31. if verbose:
  32. print("%d word vectors loaded" % (len(words)))
  33. return words, x
  34. def idx(words):
  35. w2i = {}
  36. for i, w in enumerate(words):
  37. if w not in w2i:
  38. w2i[w] = i
  39. return w2i
  40. def save_vectors(fname, x, words):
  41. n, d = x.shape
  42. fout = io.open(fname, 'w', encoding='utf-8')
  43. fout.write(u"%d %d\n" % (n, d))
  44. for i in range(n):
  45. fout.write(words[i] + " " + " ".join(map(lambda a: "%.4f" % a, x[i, :])) + "\n")
  46. fout.close()
  47. def save_matrix(fname, x):
  48. n, d = x.shape
  49. fout = io.open(fname, 'w', encoding='utf-8')
  50. fout.write(u"%d %d\n" % (n, d))
  51. for i in range(n):
  52. fout.write(" ".join(map(lambda a: "%.4f" % a, x[i, :])) + "\n")
  53. fout.close()
  54. def procrustes(X_src, Y_tgt):
  55. U, s, V = np.linalg.svd(np.dot(Y_tgt.T, X_src))
  56. return np.dot(U, V)
  57. def select_vectors_from_pairs(x_src, y_tgt, pairs):
  58. n = len(pairs)
  59. d = x_src.shape[1]
  60. x = np.zeros([n, d])
  61. y = np.zeros([n, d])
  62. for k, ij in enumerate(pairs):
  63. i, j = ij
  64. x[k, :] = x_src[i, :]
  65. y[k, :] = y_tgt[j, :]
  66. return x, y
  67. def load_lexicon(filename, words_src, words_tgt, verbose=True):
  68. f = io.open(filename, 'r', encoding='utf-8')
  69. lexicon = collections.defaultdict(set)
  70. idx_src , idx_tgt = idx(words_src), idx(words_tgt)
  71. vocab = set()
  72. for line in f:
  73. word_src, word_tgt = line.split()
  74. if word_src in idx_src and word_tgt in idx_tgt:
  75. lexicon[idx_src[word_src]].add(idx_tgt[word_tgt])
  76. vocab.add(word_src)
  77. if verbose:
  78. coverage = len(lexicon) / float(len(vocab))
  79. print("Coverage of source vocab: %.4f" % (coverage))
  80. return lexicon, float(len(vocab))
  81. def load_pairs(filename, idx_src, idx_tgt, verbose=True):
  82. f = io.open(filename, 'r', encoding='utf-8')
  83. pairs = []
  84. tot = 0
  85. for line in f:
  86. a, b = line.rstrip().split(' ')
  87. tot += 1
  88. if a in idx_src and b in idx_tgt:
  89. pairs.append((idx_src[a], idx_tgt[b]))
  90. if verbose:
  91. coverage = (1.0 * len(pairs)) / tot
  92. print("Found pairs for training: %d - Total pairs in file: %d - Coverage of pairs: %.4f" % (len(pairs), tot, coverage))
  93. return pairs
  94. def compute_nn_accuracy(x_src, x_tgt, lexicon, bsz=100, lexicon_size=-1):
  95. if lexicon_size < 0:
  96. lexicon_size = len(lexicon)
  97. idx_src = list(lexicon.keys())
  98. acc = 0.0
  99. x_src /= np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8
  100. x_tgt /= np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8
  101. for i in range(0, len(idx_src), bsz):
  102. e = min(i + bsz, len(idx_src))
  103. scores = np.dot(x_tgt, x_src[idx_src[i:e]].T)
  104. pred = scores.argmax(axis=0)
  105. for j in range(i, e):
  106. if pred[j - i] in lexicon[idx_src[j]]:
  107. acc += 1.0
  108. return acc / lexicon_size
  109. def compute_csls_accuracy(x_src, x_tgt, lexicon, lexicon_size=-1, k=10, bsz=1024):
  110. if lexicon_size < 0:
  111. lexicon_size = len(lexicon)
  112. idx_src = list(lexicon.keys())
  113. x_src /= np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8
  114. x_tgt /= np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8
  115. sr = x_src[list(idx_src)]
  116. sc = np.dot(sr, x_tgt.T)
  117. similarities = 2 * sc
  118. sc2 = np.zeros(x_tgt.shape[0])
  119. for i in range(0, x_tgt.shape[0], bsz):
  120. j = min(i + bsz, x_tgt.shape[0])
  121. sc_batch = np.dot(x_tgt[i:j, :], x_src.T)
  122. dotprod = np.partition(sc_batch, -k, axis=1)[:, -k:]
  123. sc2[i:j] = np.mean(dotprod, axis=1)
  124. similarities -= sc2[np.newaxis, :]
  125. nn = np.argmax(similarities, axis=1).tolist()
  126. correct = 0.0
  127. for k in range(0, len(lexicon)):
  128. if nn[k] in lexicon[idx_src[k]]:
  129. correct += 1.0
  130. return correct / lexicon_size