فهرست منبع

python bindings for nearest neighbors and analogies

Summary: FastText in C++ exposes the functions getNN and getAnalogies, but those were not available in python. This commit brings them in the python side as well.

Reviewed By: EdouardGrave

Differential Revision: D17093787

fbshipit-source-id: 369c0f3921291ddf73db1b0ddb1280649a44ba6d
Onur Çelebi 6 سال پیش
والد
کامیت
8acd5d360c
2فایلهای تغییر یافته به همراه62 افزوده شده و 43 حذف شده
  1. 50 43
      python/fasttext_module/fasttext/FastText.py
  2. 12 0
      python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

+ 50 - 43
python/fasttext_module/fasttext/FastText.py

@@ -84,6 +84,12 @@ class _FastText(object):
         self.f.getSentenceVector(b, text)
         return np.array(b)
 
+    def get_nearest_neighbors(self, word, k=10):
+        return self.f.getNN(word, k)
+
+    def get_analogies(self, wordA, wordB, wordC, k=10):
+        return self.f.getAnalogies(wordA, wordB, wordC, k)
+
     def get_word_id(self, word):
         """
         Given a word, get the word id within the dictionary.
@@ -146,7 +152,8 @@ class _FastText(object):
 
         if type(text) == list:
             text = [check(entry) for entry in text]
-            predictions = self.f.multilinePredict(text, k, threshold, on_unicode_error)
+            predictions = self.f.multilinePredict(
+                text, k, threshold, on_unicode_error)
             dt = np.dtype([('probability', 'float64'), ('label', 'object')])
             result_as_pair = np.array(predictions, dtype=dt)
 
@@ -356,41 +363,41 @@ def load_model(path):
 
 
 unsupervised_default = {
-    'model' : "skipgram",
-    'lr' : 0.05,
-    'dim' : 100,
-    'ws' : 5,
-    'epoch' : 5,
-    'minCount' : 5,
-    'minCountLabel' : 0,
-    'minn' : 3,
-    'maxn' : 6,
-    'neg' : 5,
-    'wordNgrams' : 1,
-    'loss' : "ns",
-    'bucket' : 2000000,
-    'thread' : multiprocessing.cpu_count() - 1,
-    'lrUpdateRate' : 100,
-    't' : 1e-4,
-    'label' : "__label__",
-    'verbose' : 2,
-    'pretrainedVectors' : "",
-    'seed' : 0,
-    'autotuneValidationFile' : "",
-    'autotuneMetric' : "f1",
-    'autotunePredictions' : 1,
-    'autotuneDuration' : 60 * 5,  # 5 minutes
-    'autotuneModelSize' : ""
+    'model': "skipgram",
+    'lr': 0.05,
+    'dim': 100,
+    'ws': 5,
+    'epoch': 5,
+    'minCount': 5,
+    'minCountLabel': 0,
+    'minn': 3,
+    'maxn': 6,
+    'neg': 5,
+    'wordNgrams': 1,
+    'loss': "ns",
+    'bucket': 2000000,
+    'thread': multiprocessing.cpu_count() - 1,
+    'lrUpdateRate': 100,
+    't': 1e-4,
+    'label': "__label__",
+    'verbose': 2,
+    'pretrainedVectors': "",
+    'seed': 0,
+    'autotuneValidationFile': "",
+    'autotuneMetric': "f1",
+    'autotunePredictions': 1,
+    'autotuneDuration': 60 * 5,  # 5 minutes
+    'autotuneModelSize': ""
 }
 
 
 def read_args(arg_list, arg_dict, arg_names, default_values):
     param_map = {
-        'min_count' : 'minCount',
-        'word_ngrams' : 'wordNgrams',
-        'lr_update_rate' : 'lrUpdateRate',
-        'label_prefix' : 'label',
-        'pretrained_vectors' : 'pretrainedVectors'
+        'min_count': 'minCount',
+        'word_ngrams': 'wordNgrams',
+        'lr_update_rate': 'lrUpdateRate',
+        'label_prefix': 'label',
+        'pretrained_vectors': 'pretrainedVectors'
     }
 
     ret = {}
@@ -427,19 +434,19 @@ def train_supervised(*kargs, **kwargs):
     """
     supervised_default = unsupervised_default.copy()
     supervised_default.update({
-        'lr' : 0.1,
-        'minCount' : 1,
-        'minn' : 0,
-        'maxn' : 0,
-        'loss' : "softmax",
-        'model' : "supervised"
+        'lr': 0.1,
+        'minCount': 1,
+        'minn': 0,
+        'maxn': 0,
+        'loss': "softmax",
+        'model': "supervised"
     })
 
     arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
-        'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
-        'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
-        'seed', 'autotuneValidationFile', 'autotuneMetric',
-        'autotunePredictions', 'autotuneDuration', 'autotuneModelSize']
+                 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
+                 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
+                 'seed', 'autotuneValidationFile', 'autotuneMetric',
+                 'autotunePredictions', 'autotuneDuration', 'autotuneModelSize']
     args, manually_set_args = read_args(kargs, kwargs, arg_names,
                                         supervised_default)
     a = _build_args(args, manually_set_args)
@@ -463,8 +470,8 @@ def train_unsupervised(*kargs, **kwargs):
     part of the fastText repository.
     """
     arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
-        'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
-        'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
+                 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
+                 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
     args, manually_set_args = read_args(kargs, kwargs, arg_names,
                                         unsupervised_default)
     a = _build_args(args, manually_set_args)

+ 12 - 0
python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

@@ -396,6 +396,18 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m,
              fasttext::Vector& vec,
              const std::string word) { m.getWordVector(vec, word); })
+      .def(
+          "getNN",
+          [](fasttext::FastText& m, const std::string& word, int32_t k) {
+            return m.getNN(word, k);
+          })
+      .def(
+          "getAnalogies",
+          [](fasttext::FastText& m,
+             const std::string& wordA,
+             const std::string& wordB,
+             const std::string& wordC,
+             int32_t k) { return m.getAnalogies(k, wordA, wordB, wordC); })
       .def(
           "getSubwords",
           [](fasttext::FastText& m,