unsup_multialign.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright (c) 2019-present, Facebook, Inc.
  5. # All rights reserved.
  6. #
  7. # This source code is licensed under the license found in the
  8. # LICENSE file in the root directory of this source tree.
  9. import io, os, ot, argparse, random
  10. import numpy as np
  11. from utils import *
# Command-line configuration for unsupervised multilingual alignment.
# The first language in --lglist is the pivot: all other languages are
# mapped into its embedding space.
parser = argparse.ArgumentParser(description=' ')
parser.add_argument('--embdir', default='data/', type=str)
parser.add_argument('--outdir', default='output/', type=str)
parser.add_argument('--lglist', default='en-fr-es-it-pt-de-pl-ru-da-nl-cs', type=str,
                    help='list of languages. The first element is the pivot. Example: en-fr-es to align English, French and Spanish with English as the pivot.')
parser.add_argument('--maxload', default=20000, type=int, help='Max number of loaded vectors')
parser.add_argument('--uniform', action='store_true', help='switch to uniform probability of picking language pairs')
# optimization parameters for the square loss (stage 1)
parser.add_argument('--epoch', default=2, type=int, help='nb of epochs for square loss')
parser.add_argument('--niter', default=500, type=int, help='max number of iteration per epoch for square loss')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate for square loss')
parser.add_argument('--bsz', default=500, type=int, help='batch size for square loss')
# optimization parameters for the RCSLS loss (stage 2 refinement)
parser.add_argument('--altepoch', default=100, type=int, help='nb of epochs for RCSLS loss')
parser.add_argument('--altlr', default=25, type=float, help='learning rate for RCSLS loss')
parser.add_argument("--altbsz", type=int, default=1000, help="batch size for RCSLS")
args = parser.parse_args()
  29. ###### SPECIFIC FUNCTIONS ######
  30. def getknn(sc, x, y, k=10):
  31. sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
  32. ytopk = y[sidx.flatten(), :]
  33. ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
  34. f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
  35. df = np.dot(ytopk.sum(1).T, x)
  36. return f / k, df / k
  37. def rcsls(Xi, Xj, Zi, Zj, R, knn=10):
  38. X_trans = np.dot(Xi, R.T)
  39. f = 2 * np.sum(X_trans * Xj)
  40. df = 2 * np.dot(Xj.T, Xi)
  41. fk0, dfk0 = getknn(np.dot(X_trans, Zj.T), Xi, Zj, knn)
  42. fk1, dfk1 = getknn(np.dot(np.dot(Zi, R.T), Xj.T).T, Xj, Zi, knn)
  43. f = f - fk0 -fk1
  44. df = df - dfk0 - dfk1.T
  45. return -f / Xi.shape[0], -df.T / Xi.shape[0]
  46. def GWmatrix(emb0):
  47. N = np.shape(emb0)[0]
  48. N2 = .5* np.linalg.norm(emb0, axis=1).reshape(1, N)
  49. C2 = np.tile(N2.transpose(), (1, N)) + np.tile(N2, (N, 1))
  50. C2 -= np.dot(emb0,emb0.T)
  51. return C2
  52. def gromov_wasserstein(x_src, x_tgt, C2):
  53. N = x_src.shape[0]
  54. C1 = GWmatrix(x_src)
  55. M = ot.gromov_wasserstein(C1,C2,np.ones(N),np.ones(N),'square_loss',epsilon=0.55,max_iter=100,tol=1e-4)
  56. return procrustes(np.dot(M,x_tgt), x_src)
def align(EMB, TRANS, lglist, args):
    """Jointly refine the per-language orthogonal maps TRANS, in place.

    Stage 1 minimizes a square (Sinkhorn/Procrustes) loss over randomly
    sampled language pairs; stage 2 refines with the RCSLS loss.
    EMB maps language index -> embedding matrix, TRANS maps language index
    -> mapping matrix; index 0 is the pivot and its map is never updated.
    Returns TRANS (mutated in place).
    """
    nmax, l = args.maxload, len(lglist)
    # create a list of language pairs to sample from
    # (default == higher probability to pick a language pair containing the pivot)
    # if --uniform: uniform probability of picking a language pair
    samples = []
    for i in range(l):
        for j in range(l):
            if j == i :
                continue
            # pivot pairs are appended once per (i, j) combination, which
            # over-weights them in random.choice unless --uniform is set
            if j > 0 and args.uniform == False:
                samples.append((0,j))
            if i > 0 and args.uniform == False:
                samples.append((i,0))
            samples.append((i,j))
    # optimization of the l2 loss
    print('start optimizing L2 loss')
    lr0, bsz, nepoch, niter = args.lr, args.bsz, args.epoch, args.niter
    for epoch in range(nepoch):
        print("start epoch %d / %d"%(epoch+1, nepoch))
        ones = np.ones(bsz)
        f, fold, nb, lr = 0.0, 0.0, 0.0, lr0
        for it in range(niter):
            # halve the learning rate when the loss stopped decreasing;
            # stop the epoch once it gets too small
            if it > 1 and f > fold + 1e-3:
                lr /= 2
            if lr < .05:
                break
            fold = f
            f, nb = 0.0, 0.0
            for k in range(100 * (l-1)):
                (i,j) = random.choice(samples)
                # fresh random batch of bsz vectors from each language
                embi = EMB[i][np.random.permutation(nmax)[:bsz], :]
                embj = EMB[j][np.random.permutation(nmax)[:bsz], :]
                # soft correspondence between the two batches in the shared space
                perm = ot.sinkhorn(ones, ones, np.linalg.multi_dot([embi, -TRANS[i], TRANS[j].T,embj.T]), reg = 0.025, stopThr = 1e-3)
                grad = np.linalg.multi_dot([embi.T, perm, embj])
                f -= np.trace(np.linalg.multi_dot([TRANS[i].T, grad, TRANS[j]])) / embi.shape[0]
                nb += 1
                # projected gradient step, kept on the orthogonal manifold;
                # the pivot (index 0) stays fixed at the identity
                if i > 0:
                    TRANS[i] = proj_ortho(TRANS[i] + lr * np.dot(grad, TRANS[j]))
                if j > 0:
                    TRANS[j] = proj_ortho(TRANS[j] + lr * np.dot(grad.transpose(), TRANS[i]))
            print("iter %d / %d - epoch %d - loss: %.5f lr: %.4f" % (it, niter, epoch+1, f / nb , lr))
        print("end of epoch %d - loss: %.5f - lr: %.4f" % (epoch+1, f / max(nb,1), lr))
        # fewer iterations but larger batches on each subsequent epoch
        niter, bsz = max(int(niter/2),2), min(1000, bsz * 2)
    #end for epoch in range(nepoch):
    # optimization of the RCSLS loss
    print('start optimizing RCSLS loss')
    f, fold, nb, lr = 0.0, 0.0, 0.0, args.altlr
    for epoch in range(args.altepoch):
        # relative-improvement test; decay lr and stop when it underflows
        if epoch > 1 and f-fold > -1e-4 * abs(fold):
            lr/= 2
        if lr < 1e-1:
            break
        fold = f
        f, nb = 0.0, 0.0
        for k in range(round(nmax / args.altbsz) * 10 * (l-1)):
            (i,j) = random.choice(samples)
            sgdidx = np.random.choice(nmax, size=args.altbsz, replace=False)
            embi = EMB[i][sgdidx, :]
            embj = EMB[j][:nmax, :]
            # crude alignment approximation: hard 1-NN assignment of each
            # source batch vector to its best-scoring target vector
            T = np.dot(TRANS[i], TRANS[j].T)
            scores = np.linalg.multi_dot([embi, T, embj.T])
            perm = np.zeros_like(scores)
            perm[np.arange(len(scores)), scores.argmax(1)] = 1
            embj = np.dot(perm, embj)
            # normalization over a subset of embeddings for speed up
            fi, grad = rcsls(embi, embj, embi, embj, T.T)
            f += fi
            nb += 1
            if i > 0:
                TRANS[i] = proj_ortho(TRANS[i] - lr * np.dot(grad, TRANS[j]))
            if j > 0:
                TRANS[j] = proj_ortho(TRANS[j] - lr * np.dot(grad.transpose(), TRANS[i]))
        print("epoch %d - loss: %.5f - lr: %.4f" % (epoch+1, f / max(nb,1), lr))
    #end for epoch in range(args.altepoch):
    return TRANS
  134. def convex_init(X, Y, niter=100, reg=0.05, apply_sqrt=False):
  135. n, d = X.shape
  136. K_X, K_Y = np.dot(X, X.T), np.dot(Y, Y.T)
  137. K_Y *= np.linalg.norm(K_X) / np.linalg.norm(K_Y)
  138. K2_X, K2_Y = np.dot(K_X, K_X), np.dot(K_Y, K_Y)
  139. P = np.ones([n, n]) / float(n)
  140. for it in range(1, niter + 1):
  141. G = np.dot(P, K2_X) + np.dot(K2_Y, P) - 2 * np.dot(K_Y, np.dot(P, K_X))
  142. q = ot.sinkhorn(np.ones(n), np.ones(n), G, reg, stopThr=1e-3)
  143. alpha = 2.0 / float(2.0 + it)
  144. P = alpha * q + (1.0 - alpha) * P
  145. return procrustes(np.dot(P, X), Y).T
  146. ###### MAIN ######
  147. lglist = args.lglist.split('-')
  148. l = len(lglist)
  149. # embs:
  150. EMB = {}
  151. for i in range(l):
  152. fn = args.embdir + '/wiki.' + lglist[i] + '.vec'
  153. _, vecs = load_vectors(fn, maxload=args.maxload)
  154. EMB[i] = vecs
  155. #init
  156. print("Computing initial bilingual apping with Gromov-Wasserstein...")
  157. TRANS={}
  158. maxinit = 2000
  159. emb0 = EMB[0][:maxinit,:]
  160. C0 = GWmatrix(emb0)
  161. TRANS[0] = np.eye(300)
  162. for i in range(1, l):
  163. print("init "+lglist[i])
  164. embi = EMB[i][:maxinit,:]
  165. TRANS[i] = gromov_wasserstein(embi, emb0, C0)
  166. # align
  167. align(EMB, TRANS, lglist, args)
  168. print('saving matrices in ' + args.outdir)
  169. languages=''.join(lglist)
  170. for i in range(l):
  171. save_matrix(args.outdir + '/W-' + languages + '-' + lglist[i], TRANS[i])