# align.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2018-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. import numpy as np
  10. import argparse
  11. from utils import *
  12. import sys
  13. parser = argparse.ArgumentParser(description='RCSLS for supervised word alignment')
  14. parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
  15. parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
  16. parser.add_argument('--center', action='store_true', help='whether to center embeddings or not')
  17. parser.add_argument("--dico_train", type=str, default='', help="train dictionary")
  18. parser.add_argument("--dico_test", type=str, default='', help="validation dictionary")
  19. parser.add_argument("--output", type=str, default='', help="where to save aligned embeddings")
  20. parser.add_argument("--knn", type=int, default=10, help="number of nearest neighbors in RCSL/CSLS")
  21. parser.add_argument("--maxneg", type=int, default=200000, help="Maximum number of negatives for the Extended RCSLS")
  22. parser.add_argument("--maxsup", type=int, default=-1, help="Maximum number of training examples")
  23. parser.add_argument("--maxload", type=int, default=200000, help="Maximum number of loaded vectors")
  24. parser.add_argument("--model", type=str, default="none", help="Set of constraints: spectral or none")
  25. parser.add_argument("--reg", type=float, default=0.0 , help='regularization parameters')
  26. parser.add_argument("--lr", type=float, default=1.0, help='learning rate')
  27. parser.add_argument("--niter", type=int, default=10, help='number of iterations')
  28. parser.add_argument('--sgd', action='store_true', help='use sgd')
  29. parser.add_argument("--batchsize", type=int, default=10000, help="batch size for sgd")
  30. params = parser.parse_args()
  31. ###### SPECIFIC FUNCTIONS ######
  32. # functions specific to RCSLS
  33. # the rest of the functions are in utils.py
  34. def getknn(sc, x, y, k=10):
  35. sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
  36. ytopk = y[sidx.flatten(), :]
  37. ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
  38. f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
  39. df = np.dot(ytopk.sum(1).T, x)
  40. return f / k, df / k
  41. def rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, knn=10):
  42. X_trans = np.dot(X_src, R.T)
  43. f = 2 * np.sum(X_trans * Y_tgt)
  44. df = 2 * np.dot(Y_tgt.T, X_src)
  45. fk0, dfk0 = getknn(np.dot(X_trans, Z_tgt.T), X_src, Z_tgt, knn)
  46. fk1, dfk1 = getknn(np.dot(np.dot(Z_src, R.T), Y_tgt.T).T, Y_tgt, Z_src, knn)
  47. f = f - fk0 -fk1
  48. df = df - dfk0 - dfk1.T
  49. return -f / X_src.shape[0], -df / X_src.shape[0]
  50. def proj_spectral(R):
  51. U, s, V = np.linalg.svd(R)
  52. s[s > 1] = 1
  53. s[s < 0] = 0
  54. return np.dot(U, np.dot(np.diag(s), V))
  55. ###### MAIN ######
  56. # load word embeddings
  57. words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
  58. words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)
  59. # load validation bilingual lexicon
  60. src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)
  61. # word --> vector indices
  62. idx_src = idx(words_src)
  63. idx_tgt = idx(words_tgt)
  64. # load train bilingual lexicon
  65. pairs = load_pairs(params.dico_train, idx_src, idx_tgt)
  66. if params.maxsup > 0 and params.maxsup < len(pairs):
  67. pairs = pairs[:params.maxsup]
  68. # selecting training vector pairs
  69. X_src, Y_tgt = select_vectors_from_pairs(x_src, x_tgt, pairs)
  70. # adding negatives for RCSLS
  71. Z_src = x_src[:params.maxneg, :]
  72. Z_tgt = x_tgt[:params.maxneg, :]
  73. # initialization:
  74. R = procrustes(X_src, Y_tgt)
  75. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  76. print("[init -- Procrustes] NN: %.4f"%(nnacc))
  77. sys.stdout.flush()
  78. # optimization
  79. fold, Rold = 0, []
  80. niter, lr = params.niter, params.lr
  81. for it in range(0, niter + 1):
  82. if lr < 1e-4:
  83. break
  84. if params.sgd:
  85. indices = np.random.choice(X_src.shape[0], size=params.batchsize, replace=False)
  86. f, df = rcsls(X_src[indices, :], Y_tgt[indices, :], Z_src, Z_tgt, R, params.knn)
  87. else:
  88. f, df = rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, params.knn)
  89. if params.reg > 0:
  90. R *= (1 - lr * params.reg)
  91. R -= lr * df
  92. if params.model == "spectral":
  93. R = proj_spectral(R)
  94. print("[it=%d] f = %.4f" % (it, f))
  95. sys.stdout.flush()
  96. if f > fold and it > 0 and not params.sgd:
  97. lr /= 2
  98. f, R = fold, Rold
  99. fold, Rold = f, R
  100. if (it > 0 and it % 10 == 0) or it == niter:
  101. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  102. print("[it=%d] NN = %.4f - Coverage = %.4f" % (it, nnacc, len(src2tgt) / lexicon_size))
  103. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  104. print("[final] NN = %.4f - Coverage = %.4f" % (nnacc, len(src2tgt) / lexicon_size))
  105. if params.output != "":
  106. print("Saving all aligned vectors at %s" % params.output)
  107. words_full, x_full = load_vectors(params.src_emb, maxload=-1, center=params.center, verbose=False)
  108. x = np.dot(x_full, R.T)
  109. x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
  110. save_vectors(params.output, x, words_full)
  111. save_matrix(params.output + "-mat", R)