# unsup_align.py — Wasserstein Procrustes for unsupervised embedding alignment
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2018-present, Facebook, Inc.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the MIT license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import codecs, sys, time, math, argparse, ot
  8. import numpy as np
  9. from utils import *
  10. parser = argparse.ArgumentParser(description='Wasserstein Procrustes for Embedding Alignment')
  11. parser.add_argument('--model_src', type=str, help='Path to source word embeddings')
  12. parser.add_argument('--model_tgt', type=str, help='Path to target word embeddings')
  13. parser.add_argument('--lexicon', type=str, help='Path to the evaluation lexicon')
  14. parser.add_argument('--output_src', default='', type=str, help='Path to save the aligned source embeddings')
  15. parser.add_argument('--output_tgt', default='', type=str, help='Path to save the aligned target embeddings')
  16. parser.add_argument('--seed', default=1111, type=int, help='Random number generator seed')
  17. parser.add_argument('--nepoch', default=5, type=int, help='Number of epochs')
  18. parser.add_argument('--niter', default=5000, type=int, help='Initial number of iterations')
  19. parser.add_argument('--bsz', default=500, type=int, help='Initial batch size')
  20. parser.add_argument('--lr', default=500., type=float, help='Learning rate')
  21. parser.add_argument('--nmax', default=20000, type=int, help='Vocabulary size for learning the alignment')
  22. parser.add_argument('--reg', default=0.05, type=float, help='Regularization parameter for sinkhorn')
  23. args = parser.parse_args()
  24. def objective(X, Y, R, n=5000):
  25. Xn, Yn = X[:n], Y[:n]
  26. C = -np.dot(np.dot(Xn, R), Yn.T)
  27. P = ot.sinkhorn(np.ones(n), np.ones(n), C, 0.025, stopThr=1e-3)
  28. return 1000 * np.linalg.norm(np.dot(Xn, R) - np.dot(P, Yn)) / n
  29. def sqrt_eig(x):
  30. U, s, VT = np.linalg.svd(x, full_matrices=False)
  31. return np.dot(U, np.dot(np.diag(np.sqrt(s)), VT))
  32. def align(X, Y, R, lr=10., bsz=200, nepoch=5, niter=1000,
  33. nmax=10000, reg=0.05, verbose=True):
  34. for epoch in range(1, nepoch + 1):
  35. for _it in range(1, niter + 1):
  36. # sample mini-batch
  37. xt = X[np.random.permutation(nmax)[:bsz], :]
  38. yt = Y[np.random.permutation(nmax)[:bsz], :]
  39. # compute OT on minibatch
  40. C = -np.dot(np.dot(xt, R), yt.T)
  41. P = ot.sinkhorn(np.ones(bsz), np.ones(bsz), C, reg, stopThr=1e-3)
  42. # compute gradient
  43. G = - np.dot(xt.T, np.dot(P, yt))
  44. R -= lr / bsz * G
  45. # project on orthogonal matrices
  46. U, s, VT = np.linalg.svd(R)
  47. R = np.dot(U, VT)
  48. bsz *= 2
  49. niter //= 4
  50. if verbose:
  51. print("epoch: %d obj: %.3f" % (epoch, objective(X, Y, R)))
  52. return R
  53. def convex_init(X, Y, niter=100, reg=0.05, apply_sqrt=False):
  54. n, d = X.shape
  55. if apply_sqrt:
  56. X, Y = sqrt_eig(X), sqrt_eig(Y)
  57. K_X, K_Y = np.dot(X, X.T), np.dot(Y, Y.T)
  58. K_Y *= np.linalg.norm(K_X) / np.linalg.norm(K_Y)
  59. K2_X, K2_Y = np.dot(K_X, K_X), np.dot(K_Y, K_Y)
  60. P = np.ones([n, n]) / float(n)
  61. for it in range(1, niter + 1):
  62. G = np.dot(P, K2_X) + np.dot(K2_Y, P) - 2 * np.dot(K_Y, np.dot(P, K_X))
  63. q = ot.sinkhorn(np.ones(n), np.ones(n), G, reg, stopThr=1e-3)
  64. alpha = 2.0 / float(2.0 + it)
  65. P = alpha * q + (1.0 - alpha) * P
  66. obj = np.linalg.norm(np.dot(P, K_X) - np.dot(K_Y, P))
  67. print(obj)
  68. return procrustes(np.dot(P, X), Y).T
  69. print("\n*** Wasserstein Procrustes ***\n")
  70. np.random.seed(args.seed)
  71. maxload = 200000
  72. w_src, x_src = load_vectors(args.model_src, maxload, norm=True, center=True)
  73. w_tgt, x_tgt = load_vectors(args.model_tgt, maxload, norm=True, center=True)
  74. src2trg, _ = load_lexicon(args.lexicon, w_src, w_tgt)
  75. print("\nComputing initial mapping with convex relaxation...")
  76. t0 = time.time()
  77. R0 = convex_init(x_src[:2500], x_tgt[:2500], reg=args.reg, apply_sqrt=True)
  78. print("Done [%03d sec]" % math.floor(time.time() - t0))
  79. print("\nComputing mapping with Wasserstein Procrustes...")
  80. t0 = time.time()
  81. R = align(x_src, x_tgt, R0.copy(), bsz=args.bsz, lr=args.lr, niter=args.niter,
  82. nepoch=args.nepoch, reg=args.reg, nmax=args.nmax)
  83. print("Done [%03d sec]" % math.floor(time.time() - t0))
  84. acc = compute_nn_accuracy(x_src, np.dot(x_tgt, R.T), src2trg)
  85. print("\nPrecision@1: %.3f\n" % acc)
  86. if args.output_src != '':
  87. x_src = x_src / np.linalg.norm(x_src, 2, 1).reshape([-1, 1])
  88. save_vectors(args.output_src, x_src, w_src)
  89. if args.output_tgt != '':
  90. x_tgt = x_tgt / np.linalg.norm(x_tgt, 2, 1).reshape([-1, 1])
  91. save_vectors(args.output_tgt, np.dot(x_tgt, R.T), w_tgt)