# align.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys

import numpy as np

from utils import *
  12. parser = argparse.ArgumentParser(description='RCSLS for supervised word alignment')
  13. parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
  14. parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
  15. parser.add_argument('--center', action='store_true', help='whether to center embeddings or not')
  16. parser.add_argument("--dico_train", type=str, default='', help="train dictionary")
  17. parser.add_argument("--dico_test", type=str, default='', help="validation dictionary")
  18. parser.add_argument("--output", type=str, default='', help="where to save aligned embeddings")
  19. parser.add_argument("--knn", type=int, default=10, help="number of nearest neighbors in RCSL/CSLS")
  20. parser.add_argument("--maxneg", type=int, default=200000, help="Maximum number of negatives for the Extended RCSLS")
  21. parser.add_argument("--maxsup", type=int, default=-1, help="Maximum number of training examples")
  22. parser.add_argument("--maxload", type=int, default=200000, help="Maximum number of loaded vectors")
  23. parser.add_argument("--model", type=str, default="none", help="Set of constraints: spectral or none")
  24. parser.add_argument("--reg", type=float, default=0.0 , help='regularization parameters')
  25. parser.add_argument("--lr", type=float, default=1.0, help='learning rate')
  26. parser.add_argument("--niter", type=int, default=10, help='number of iterations')
  27. parser.add_argument('--sgd', action='store_true', help='use sgd')
  28. parser.add_argument("--batchsize", type=int, default=10000, help="batch size for sgd")
  29. params = parser.parse_args()
###### SPECIFIC FUNCTIONS ######
# functions specific to RCSLS
# the rest of the functions are in utils.py
  33. def getknn(sc, x, y, k=10):
  34. sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
  35. ytopk = y[sidx.flatten(), :]
  36. ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
  37. f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
  38. df = np.dot(ytopk.sum(1).T, x)
  39. return f / k, df / k
  40. def rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, knn=10):
  41. X_trans = np.dot(X_src, R.T)
  42. f = 2 * np.sum(X_trans * Y_tgt)
  43. df = 2 * np.dot(Y_tgt.T, X_src)
  44. fk0, dfk0 = getknn(np.dot(X_trans, Z_tgt.T), X_src, Z_tgt, knn)
  45. fk1, dfk1 = getknn(np.dot(np.dot(Z_src, R.T), Y_tgt.T).T, Y_tgt, Z_src, knn)
  46. f = f - fk0 -fk1
  47. df = df - dfk0 - dfk1.T
  48. return -f / X_src.shape[0], -df / X_src.shape[0]
  49. def proj_spectral(R):
  50. U, s, V = np.linalg.svd(R)
  51. s[s > 1] = 1
  52. s[s < 0] = 0
  53. return np.dot(U, np.dot(np.diag(s), V))
  54. ###### MAIN ######
  55. # load word embeddings
  56. words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
  57. words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)
  58. # load validation bilingual lexicon
  59. src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)
  60. # word --> vector indices
  61. idx_src = idx(words_src)
  62. idx_tgt = idx(words_tgt)
  63. # load train bilingual lexicon
  64. pairs = load_pairs(params.dico_train, idx_src, idx_tgt)
  65. if params.maxsup > 0 and params.maxsup < len(pairs):
  66. pairs = pairs[:params.maxsup]
  67. # selecting training vector pairs
  68. X_src, Y_tgt = select_vectors_from_pairs(x_src, x_tgt, pairs)
  69. # adding negatives for RCSLS
  70. Z_src = x_src[:params.maxneg, :]
  71. Z_tgt = x_tgt[:params.maxneg, :]
  72. # initialization:
  73. R = procrustes(X_src, Y_tgt)
  74. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  75. print("[init -- Procrustes] NN: %.4f"%(nnacc))
  76. sys.stdout.flush()
  77. # optimization
  78. fold, Rold = 0, []
  79. niter, lr = params.niter, params.lr
  80. for it in range(0, niter + 1):
  81. if lr < 1e-4:
  82. break
  83. if params.sgd:
  84. indices = np.random.choice(X_src.shape[0], size=params.batchsize, replace=False)
  85. f, df = rcsls(X_src[indices, :], Y_tgt[indices, :], Z_src, Z_tgt, R, params.knn)
  86. else:
  87. f, df = rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, params.knn)
  88. if params.reg > 0:
  89. R *= (1 - lr * params.reg)
  90. R -= lr * df
  91. if params.model == "spectral":
  92. R = proj_spectral(R)
  93. print("[it=%d] f = %.4f" % (it, f))
  94. sys.stdout.flush()
  95. if f > fold and it > 0 and not params.sgd:
  96. lr /= 2
  97. f, R = fold, Rold
  98. fold, Rold = f, R
  99. if (it > 0 and it % 10 == 0) or it == niter:
  100. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  101. print("[it=%d] NN = %.4f - Coverage = %.4f" % (it, nnacc, len(src2tgt) / lexicon_size))
  102. nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
  103. print("[final] NN = %.4f - Coverage = %.4f" % (nnacc, len(src2tgt) / lexicon_size))
  104. if params.output != "":
  105. print("Saving all aligned vectors at %s" % params.output)
  106. words_full, x_full = load_vectors(params.src_emb, maxload=-1, center=params.center, verbose=False)
  107. x = np.dot(x_full, R.T)
  108. x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
  109. save_vectors(params.output, x, words_full)
  110. save_matrix(params.output + "-mat", R)