Browse sources

Supervised alignement

Summary: Code for supervised alignment

Reviewed By: piotr-bojanowski

Differential Revision: D9554953

fbshipit-source-id: 12fa9677537e3baf551ed486107bec64ba35a359
Armand Joulin 7 years ago
parent
commit
99f23802d4
5 changed files with 435 additions and 0 deletions
  1. 26 0
      alignment/README
  2. 144 0
      alignment/align.py
  3. 60 0
      alignment/eval.py
  4. 51 0
      alignment/example.sh
  5. 154 0
      alignment/utils.py

+ 26 - 0
alignment/README

@@ -0,0 +1,26 @@
+## Supervised Alignment of Word Embeddings
+
+This code aligns word embeddings from two languages with a bilingual lexicon. The details of our approach can be found in [1].
+
+The code is in Python 3 and requires [NumPy](http://www.numpy.org/).
+
+The script `example.sh` shows how to use this code to learn and evaluate a bilingual alignment of word embeddings.
+
+The word embeddings used in [1] can be found on the [fastText project page](https://fasttext.cc) and the supervised bilingual lexicons on the [MUSE project page](https://github.com/facebookresearch/MUSE).
+
+### Download
+
+Wikipedia fastText embeddings aligned with our method can be found [here](https://fasttext.cc/doc/en/aligned_vectors).
+
+### References
+
+If you use this code, please cite:
+
+[1] A. Joulin, P. Bojanowski, T. Mikolov, H. Jegou, E. Grave, [*Loss in Translation: Learning Bilingual Word Mapping with a Retrieval Criterion*](https://arxiv.org/abs/1804.07745)
+
+@InProceedings{joulin2018loss,
+    title={Loss in Translation: Learning Bilingual Word Mapping with a Retrieval Criterion},
+    author={Joulin, Armand and Bojanowski, Piotr and Mikolov, Tomas and J\'egou, Herv\'e and Grave, Edouard},
+    year={2018},
+    booktitle={Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing},
+}

+ 144 - 0
alignment/align.py

@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2018-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import argparse
+from utils import *
+
+parser = argparse.ArgumentParser(description='RCSLS for supervised word alignment')
+
+parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
+parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
+parser.add_argument('--center', action='store_true', help='whether to center embeddings or not')
+
+parser.add_argument("--dico_train", type=str, default='', help="train dictionary")
+parser.add_argument("--dico_test", type=str, default='', help="validation dictionary")
+
+parser.add_argument("--output", type=str, default='', help="where to save aligned embeddings")
+
+parser.add_argument("--knn", type=int, default=10, help="number of nearest neighbors in RCSL/CSLS")
+parser.add_argument("--maxneg", type=int, default=200000, help="Maximum number of negatives for the Extended RCSLS")
+parser.add_argument("--maxsup", type=int, default=-1, help="Maximum number of training examples")
+parser.add_argument("--maxload", type=int, default=200000, help="Maximum number of loaded vectors")
+
+parser.add_argument("--model", type=str, default="none", help="Set of constraints: spectral or none")
+parser.add_argument("--reg", type=float, default=0.0 , help='regularization parameters')
+
+parser.add_argument("--lr", type=float, default=1.0, help='learning rate')
+parser.add_argument("--niter", type=int, default=10, help='number of iterations')
+parser.add_argument('--sgd', action='store_true', help='use sgd')
+parser.add_argument("--batchsize", type=int, default=10000, help="batch size for sgd")
+
+params = parser.parse_args()
+
+###### SPECIFIC FUNCTIONS ######
+# functions specific to RCSLS
+# the rest of the functions are in utils.py
+
+def getknn(sc, x, y, k=10):
+    sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
+    ytopk = y[sidx.flatten(), :]
+    ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
+    f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
+    df = np.dot(ytopk.sum(1).T, x)
+    return f / k, df / k
+
+
+def rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, knn=10):
+    X_trans = np.dot(X_src, R.T)
+    f = 2 * np.sum(X_trans * Y_tgt)
+    df = 2 * np.dot(Y_tgt.T, X_src)
+    fk0, dfk0 = getknn(np.dot(X_trans, Z_tgt.T), X_src, Z_tgt, knn)
+    fk1, dfk1 = getknn(np.dot(np.dot(Z_src, R.T), Y_tgt.T).T, Y_tgt, Z_src, knn)
+    f = f - fk0 -fk1
+    df = df - dfk0 - dfk1.T
+    return -f / X_src.shape[0], -df / X_src.shape[0]
+
+
+def proj_spectral(R):
+    U, s, V = np.linalg.svd(R)
+    s[s > 1] = 1
+    s[s < 0] = 0
+    return np.dot(U, np.dot(np.diag(s), V))
+
+
+###### MAIN ######
+
+# load word embeddings
+words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
+words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)
+
+# load validation bilingual lexicon
+src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)
+
+# word --> vector indices
+idx_src = idx(words_src)
+idx_tgt = idx(words_tgt)
+
+# load train bilingual lexicon
+pairs = load_pairs(params.dico_train, idx_src, idx_tgt)
+if params.maxsup > 0 and params.maxsup < len(pairs):
+    pairs = pairs[:params.maxsup]
+
+# selecting training vector  pairs
+X_src, Y_tgt = select_vectors_from_pairs(x_src, x_tgt, pairs)
+
+# adding negatives for RCSLS
+Z_src = x_src[:params.maxneg, :]
+Z_tgt = x_tgt[:params.maxneg, :]
+
+# initialization:
+R = procrustes(X_src, Y_tgt)
+nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
+print("[init -- Procrustes] NN: %.4f"%(nnacc))
+sys.stdout.flush()
+
+# optimization
+fold, Rold = 0, []
+niter, lr = params.niter, params.lr
+
+for it in range(0, niter + 1):
+    if lr < 1e-4:
+        break
+
+    if params.sgd:
+        indices = np.random.choice(X_src.shape[0], size=params.batchsize, replace=False)
+        f, df = rcsls(X_src[indices, :], Y_tgt[indices, :], Z_src, Z_tgt, R, params.knn)
+    else:
+        f, df = rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, params.knn)
+
+    if params.reg > 0:
+        R *= (1 - lr * params.reg)
+    R -= lr * df
+    if params.model == "spectral":
+        R = proj_spectral(R)
+
+    print("[it=%d] f = %.4f" % (it, f))
+    sys.stdout.flush()
+
+    if f > fold and it > 0 and not params.sgd:
+        lr /= 2
+        f, R = fold, Rold
+
+    fold, Rold = f, R
+
+    if (it > 0 and it % 10 == 0) or it == niter:
+        nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
+        print("[it=%d] NN = %.4f - Coverage = %.4f" % (it, nnacc, len(src2tgt) / lexicon_size))
+
+nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
+print("[final] NN = %.4f - Coverage = %.4f" % (nnacc, len(src2tgt) / lexicon_size))
+
+if params.output != "":
+    print("Saving all aligned vectors at %s" % params.output)
+    words_full, x_full = load_vectors(params.src_emb, maxload=-1, center=params.center, verbose=False)
+    x = np.dot(x_full, R.T)
+    x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
+    save_vectors(params.output, x, words_full)
+    save_matrix(params.output + "-mat",  R)

+ 60 - 0
alignment/eval.py

@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2018-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import numpy as np
+import argparse
+from utils import *
+
+parser = argparse.ArgumentParser(description='Evaluation of word alignment')
+parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
+parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
+parser.add_argument('--center', action='store_true', help='whether to center embeddings or not')
+parser.add_argument("--src_mat", type=str, default='', help="Load source alignment matrix. If none given, the aligment matrix is the identity.")
+parser.add_argument("--tgt_mat", type=str, default='', help="Load target alignment matrix. If none given, the aligment matrix is the identity.")
+parser.add_argument("--dico_test", type=str, default='', help="test dictionary")
+parser.add_argument("--maxload", type=int, default=200000)
+parser.add_argument("--nomatch", action='store_true', help="no exact match in lexicon")
+params = parser.parse_args()
+
+
+###### SPECIFIC FUNCTIONS ######
+# function specific to evaluation
+# the rest of the functions are in utils.py
+
+def load_transform(fname, d1=300, d2=300):
+    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
+    R = np.zeros([d1, d2])
+    for i, line in enumerate(fin):
+        tokens = line.split(' ')
+        R[i, :] = np.array(tokens[0:d2], dtype=float)
+    return R
+
+
+###### MAIN ######
+
+print("Evaluation of alignment on %s" % params.dico_test)
+if params.nomatch:
+    print("running without exact string matches")
+
+words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
+words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)
+
+if params.tgt_mat != "":
+    R_tgt = load_transform(params.tgt_mat)
+    x_tgt = np.dot(x_tgt, R_tgt)
+if params.src_mat != "":
+    R_src = load_transform(params.src_mat)
+    x_src = np.dot(x_src, R_src)
+
+src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)
+
+nnacc = compute_nn_accuracy(x_src, x_tgt, src2tgt, lexicon_size=lexicon_size)
+cslsproc = compute_csls_accuracy(x_src, x_tgt, src2tgt, lexicon_size=lexicon_size)
+print("NN = %.4f - CSLS = %.4f - Coverage = %.4f" % (nnacc, cslsproc, len(src2tgt) / lexicon_size))

+ 51 - 0
alignment/example.sh

@@ -0,0 +1,51 @@
+#!/bin/usr/env sh
+# Copyright (c) 2018-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -e
+s=${1:-en}
+t=${2:-es}
+echo "Example based on the ${s}->${t} alignment"
+
+if [ ! -d data/ ]; then
+  mkdir -p data;
+fi
+
+if [ ! -d res/ ]; then
+  mkdir -p res;
+fi
+
+dico_train=data/${s}-${t}.0-5000.txt
+if [ ! -f "${dico_train}" ]; then
+  DICO=$(basename -- "${dico_train}")
+  wget -c "https://s3.amazonaws.com/arrival/dictionaries/${DICO}" -P data/
+fi
+
+dico_test=data/${s}-${t}.5000-6500.txt
+if [ ! -f "${dico_test}" ]; then
+  DICO=$(basename -- "${dico_test}")
+  wget -c "https://s3.amazonaws.com/arrival/dictionaries/${DICO}" -P data/
+fi
+
+src_emb=data/wiki.${s}.vec
+if [ ! -f "${src_emb}" ]; then
+  DICO=$(basename -- "${src_emb}")
+  wget -c "https://s3-us-west-1.amazonaws.com/fasttext-vectors/${DICO}" -P data/
+fi
+
+tgt_emb=data/wiki.${t}.vec
+if [ ! -f "${tgt_emb}" ]; then
+  DICO=$(basename -- "${tgt_emb}")
+  wget -c "https://s3-us-west-1.amazonaws.com/fasttext-vectors/${DICO}" -P data/
+fi
+
+output=res/wiki.${s}-${t}.vec
+
+python3 align.py --src_emb "${src_emb}" --tgt_emb "${tgt_emb}" \
+  --dico_train "${dico_train}" --dico_test "${dico_test}" --output "${output}" \
+  --lr 25 --niter 10
+python3 eval.py --src_emb "${output}" --tgt_emb "${tgt_emb}" \
+  --dico_test "${dico_test}"

+ 154 - 0
alignment/utils.py

@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright (c) 2018-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import io
+import numpy as np
+import collections
+
+
+def load_vectors(fname, maxload=200000, norm=True, center=False, verbose=True):
+    if verbose:
+        print("Loading vectors from %s" % fname)
+    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
+    n, d = map(int, fin.readline().split())
+    if maxload > 0:
+        n = min(n, maxload)
+    x = np.zeros([n, d])
+    words = []
+    for i, line in enumerate(fin):
+        if i >= n:
+            break
+        tokens = line.rstrip().split(' ')
+        words.append(tokens[0])
+        v = np.array(tokens[1:], dtype=float)
+        x[i, :] = v
+    if norm:
+        x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
+    if center:
+        x -= x.mean(axis=0)[np.newaxis, :]
+        x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
+    if verbose:
+        print("%d word vectors loaded" % (len(words)))
+    return words, x
+
+
+def idx(words):
+    w2i = {}
+    for i, w in enumerate(words):
+        if w not in w2i:
+            w2i[w] = i
+    return w2i
+
+
+def save_vectors(fname, x, words):
+    n, d = x.shape
+    fout = io.open(fname, 'w', encoding='utf-8')
+    fout.write(u"%d %d\n" % (n, d))
+    for i in range(n):
+        fout.write(words[i] + " " + " ".join(map(lambda a: "%.4f" % a, x[i, :])) + "\n")
+    fout.close()
+
+
+def save_matrix(fname, x):
+    n, d = x.shape
+    fout = io.open(fname, 'w', encoding='utf-8')
+    fout.write(u"%d %d\n" % (n, d))
+    for i in range(n):
+        fout.write(" ".join(map(lambda a: "%.4f" % a, x[i, :])) + "\n")
+    fout.close()
+
+
+def procrustes(X_src, Y_tgt):
+    U, s, V = np.linalg.svd(np.dot(Y_tgt.T, X_src))
+    return np.dot(U, V)
+
+
+def select_vectors_from_pairs(x_src, y_tgt, pairs):
+    n = len(pairs)
+    d = x_src.shape[1]
+    x = np.zeros([n, d])
+    y = np.zeros([n, d])
+    for k, ij in enumerate(pairs):
+        i, j = ij
+        x[k, :] = x_src[i, :]
+        y[k, :] = y_tgt[j, :]
+    return x, y
+
+
+def load_lexicon(filename, words_src, words_tgt, verbose=True):
+    f = io.open(filename, 'r', encoding='utf-8')
+    lexicon = collections.defaultdict(set)
+    idx_src , idx_tgt = idx(words_src), idx(words_tgt)
+    vocab = set()
+    for line in f:
+        word_src, word_tgt = line.split()
+        if word_src in idx_src and word_tgt in idx_tgt:
+            lexicon[idx_src[word_src]].add(idx_tgt[word_tgt])
+        vocab.add(word_src)
+    if verbose:
+        coverage = len(lexicon) / float(len(vocab))
+        print("Coverage of source vocab: %.4f" % (coverage))
+    return lexicon, float(len(vocab))
+
+
+def load_pairs(filename, idx_src, idx_tgt, verbose=True):
+    f = io.open(filename, 'r', encoding='utf-8')
+    pairs = []
+    tot = 0
+    for line in f:
+        a, b = line.rstrip().split(' ')
+        tot += 1
+        if a in idx_src and b in idx_tgt:
+            pairs.append((idx_src[a], idx_tgt[b]))
+    if verbose:
+        coverage = (1.0 * len(pairs)) / tot
+        print("Found pairs for training: %d - Total pairs in file: %d - Coverage of pairs: %.4f" % (len(pairs), tot, coverage))
+    return pairs
+
+
+def compute_nn_accuracy(x_src, x_tgt, lexicon, bsz=100, lexicon_size=-1):
+    if lexicon_size < 0:
+        lexicon_size = len(lexicon)
+    idx_src = list(lexicon.keys())
+    acc = 0.0
+    x_src /= np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8
+    x_tgt /= np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8
+    for i in range(0, len(idx_src), bsz):
+        e = min(i + bsz, len(idx_src))
+        scores = np.dot(x_tgt, x_src[idx_src[i:e]].T)
+        pred = scores.argmax(axis=0)
+        for j in range(i, e):
+            if pred[j - i] in lexicon[idx_src[j]]:
+                acc += 1.0
+    return acc / lexicon_size
+
+
+def compute_csls_accuracy(x_src, x_tgt, lexicon, lexicon_size=-1, k=10, bsz=1024):
+    if lexicon_size < 0:
+        lexicon_size = len(lexicon)
+    idx_src = list(lexicon.keys())
+
+    x_src /= np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8
+    x_tgt /= np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8
+
+    sr = x_src[list(idx_src)]
+    sc = np.dot(sr, x_tgt.T)
+    similarities = 2 * sc
+    sc2 = np.zeros(x_tgt.shape[0])
+    for i in range(0, x_tgt.shape[0], bsz):
+        j = min(i + bsz, x_tgt.shape[0])
+        sc_batch = np.dot(x_tgt[i:j, :], x_src.T)
+        dotprod = np.partition(sc_batch, -k, axis=1)[:, -k:]
+        sc2[i:j] = np.mean(dotprod, axis=1)
+    similarities -= sc2[np.newaxis, :]
+
+    nn = np.argmax(similarities, axis=1).tolist()
+    correct = 0.0
+    for k in range(0, len(lexicon)):
+        if nn[k] in lexicon[idx_src[k]]:
+            correct += 1.0
+    return correct / lexicon_size