|
|
@@ -12,7 +12,6 @@ from __future__ import unicode_literals
|
|
|
import fasttext_pybind as fasttext
|
|
|
import numpy as np
|
|
|
import multiprocessing
|
|
|
-import sys
|
|
|
from itertools import chain
|
|
|
|
|
|
loss_name = fasttext.loss_name
|
|
|
@@ -98,10 +97,26 @@ class _FastText:
|
|
|
|
|
|
def set_args(self, args=None):
|
|
|
if args:
|
|
|
- arg_names = ['lr', 'dim', 'ws', 'epoch', 'minCount',
|
|
|
- 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams',
|
|
|
- 'loss', 'bucket', 'thread', 'lrUpdateRate', 't',
|
|
|
- 'label', 'verbose', 'pretrainedVectors']
|
|
|
+ arg_names = [
|
|
|
+ "lr",
|
|
|
+ "dim",
|
|
|
+ "ws",
|
|
|
+ "epoch",
|
|
|
+ "minCount",
|
|
|
+ "minCountLabel",
|
|
|
+ "minn",
|
|
|
+ "maxn",
|
|
|
+ "neg",
|
|
|
+ "wordNgrams",
|
|
|
+ "loss",
|
|
|
+ "bucket",
|
|
|
+ "thread",
|
|
|
+ "lrUpdateRate",
|
|
|
+ "t",
|
|
|
+ "label",
|
|
|
+ "verbose",
|
|
|
+ "pretrainedVectors",
|
|
|
+ ]
|
|
|
for arg_name in arg_names:
|
|
|
setattr(self, arg_name, getattr(args, arg_name))
|
|
|
|
|
|
@@ -127,21 +142,18 @@ class _FastText:
|
|
|
whitespace (space, newline, tab, vertical tab) and the control
|
|
|
characters carriage return, formfeed and the null character.
|
|
|
"""
|
|
|
- if text.find('\n') != -1:
|
|
|
- raise ValueError(
|
|
|
- "predict processes one line at a time (remove \'\\n\')"
|
|
|
- )
|
|
|
+ if text.find("\n") != -1:
|
|
|
+ raise ValueError("predict processes one line at a time (remove '\\n')")
|
|
|
text += "\n"
|
|
|
dim = self.get_dimension()
|
|
|
b = fasttext.Vector(dim)
|
|
|
self.f.getSentenceVector(b, text)
|
|
|
return np.array(b)
|
|
|
|
|
|
- def get_nearest_neighbors(self, word, k=10, on_unicode_error='strict'):
|
|
|
+ def get_nearest_neighbors(self, word, k=10, on_unicode_error="strict"):
|
|
|
return self.f.getNN(word, k, on_unicode_error)
|
|
|
|
|
|
- def get_analogies(self, wordA, wordB, wordC, k=10,
|
|
|
- on_unicode_error='strict'):
|
|
|
+ def get_analogies(self, wordA, wordB, wordC, k=10, on_unicode_error="strict"):
|
|
|
return self.f.getAnalogies(wordA, wordB, wordC, k, on_unicode_error)
|
|
|
|
|
|
def get_word_id(self, word):
|
|
|
@@ -164,7 +176,7 @@ class _FastText:
|
|
|
"""
|
|
|
return self.f.getSubwordId(subword)
|
|
|
|
|
|
- def get_subwords(self, word, on_unicode_error='strict'):
|
|
|
+ def get_subwords(self, word, on_unicode_error="strict"):
|
|
|
"""
|
|
|
Given a word, get the subwords and their indicies.
|
|
|
"""
|
|
|
@@ -180,7 +192,7 @@ class _FastText:
|
|
|
self.f.getInputVector(b, ind)
|
|
|
return np.array(b)
|
|
|
|
|
|
- def predict(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
|
|
|
+ def predict(self, text, k=1, threshold=0.0, on_unicode_error="strict"):
|
|
|
"""
|
|
|
Given a string, get a list of labels and a list of
|
|
|
corresponding probabilities. k controls the number
|
|
|
@@ -204,17 +216,16 @@ class _FastText:
|
|
|
"""
|
|
|
|
|
|
def check(entry):
|
|
|
- if entry.find('\n') != -1:
|
|
|
- raise ValueError(
|
|
|
- "predict processes one line at a time (remove \'\\n\')"
|
|
|
- )
|
|
|
+ if entry.find("\n") != -1:
|
|
|
+ raise ValueError("predict processes one line at a time (remove '\\n')")
|
|
|
entry += "\n"
|
|
|
return entry
|
|
|
|
|
|
if type(text) == list:
|
|
|
text = [check(entry) for entry in text]
|
|
|
all_labels, all_probs = self.f.multilinePredict(
|
|
|
- text, k, threshold, on_unicode_error)
|
|
|
+ text, k, threshold, on_unicode_error
|
|
|
+ )
|
|
|
|
|
|
return all_labels, all_probs
|
|
|
else:
|
|
|
@@ -245,7 +256,7 @@ class _FastText:
|
|
|
raise ValueError("Can't get quantized Matrix")
|
|
|
return np.array(self.f.getOutputMatrix())
|
|
|
|
|
|
- def get_words(self, include_freq=False, on_unicode_error='strict'):
|
|
|
+ def get_words(self, include_freq=False, on_unicode_error="strict"):
|
|
|
"""
|
|
|
Get the entire list of words of the dictionary optionally
|
|
|
including the frequency of the individual words. This
|
|
|
@@ -258,7 +269,7 @@ class _FastText:
|
|
|
else:
|
|
|
return pair[0]
|
|
|
|
|
|
- def get_labels(self, include_freq=False, on_unicode_error='strict'):
|
|
|
+ def get_labels(self, include_freq=False, on_unicode_error="strict"):
|
|
|
"""
|
|
|
Get the entire list of labels of the dictionary optionally
|
|
|
including the frequency of the individual labels. Unsupervised
|
|
|
@@ -276,17 +287,15 @@ class _FastText:
|
|
|
else:
|
|
|
return self.get_words(include_freq)
|
|
|
|
|
|
- def get_line(self, text, on_unicode_error='strict'):
|
|
|
+ def get_line(self, text, on_unicode_error="strict"):
|
|
|
"""
|
|
|
Split a line of text into words and labels. Labels must start with
|
|
|
the prefix used to create the model (__label__ by default).
|
|
|
"""
|
|
|
|
|
|
def check(entry):
|
|
|
- if entry.find('\n') != -1:
|
|
|
- raise ValueError(
|
|
|
- "get_line processes one line at a time (remove \'\\n\')"
|
|
|
- )
|
|
|
+ if entry.find("\n") != -1:
|
|
|
+ raise ValueError("get_line processes one line at a time (remove '\\n')")
|
|
|
entry += "\n"
|
|
|
return entry
|
|
|
|
|
|
@@ -332,7 +341,7 @@ class _FastText:
|
|
|
thread=None,
|
|
|
verbose=None,
|
|
|
dsub=2,
|
|
|
- qnorm=False
|
|
|
+ qnorm=False,
|
|
|
):
|
|
|
"""
|
|
|
Quantize the model reducing the size of the model and
|
|
|
@@ -352,8 +361,7 @@ class _FastText:
|
|
|
if input is None:
|
|
|
input = ""
|
|
|
self.f.quantize(
|
|
|
- input, qout, cutoff, retrain, epoch, lr, thread, verbose, dsub,
|
|
|
- qnorm
|
|
|
+ input, qout, cutoff, retrain, epoch, lr, thread, verbose, dsub, qnorm
|
|
|
)
|
|
|
|
|
|
def set_matrices(self, input_matrix, output_matrix):
|
|
|
@@ -361,8 +369,9 @@ class _FastText:
|
|
|
Set input and output matrices. This function assumes you know what you
|
|
|
are doing.
|
|
|
"""
|
|
|
- self.f.setMatrices(input_matrix.astype(np.float32),
|
|
|
- output_matrix.astype(np.float32))
|
|
|
+ self.f.setMatrices(
|
|
|
+ input_matrix.astype(np.float32), output_matrix.astype(np.float32)
|
|
|
+ )
|
|
|
|
|
|
@property
|
|
|
def words(self):
|
|
|
@@ -437,41 +446,41 @@ def load_model(path):
|
|
|
|
|
|
|
|
|
unsupervised_default = {
|
|
|
- 'model': "skipgram",
|
|
|
- 'lr': 0.05,
|
|
|
- 'dim': 100,
|
|
|
- 'ws': 5,
|
|
|
- 'epoch': 5,
|
|
|
- 'minCount': 5,
|
|
|
- 'minCountLabel': 0,
|
|
|
- 'minn': 3,
|
|
|
- 'maxn': 6,
|
|
|
- 'neg': 5,
|
|
|
- 'wordNgrams': 1,
|
|
|
- 'loss': "ns",
|
|
|
- 'bucket': 2000000,
|
|
|
- 'thread': multiprocessing.cpu_count() - 1,
|
|
|
- 'lrUpdateRate': 100,
|
|
|
- 't': 1e-4,
|
|
|
- 'label': "__label__",
|
|
|
- 'verbose': 2,
|
|
|
- 'pretrainedVectors': "",
|
|
|
- 'seed': 0,
|
|
|
- 'autotuneValidationFile': "",
|
|
|
- 'autotuneMetric': "f1",
|
|
|
- 'autotunePredictions': 1,
|
|
|
- 'autotuneDuration': 60 * 5, # 5 minutes
|
|
|
- 'autotuneModelSize': ""
|
|
|
+ "model": "skipgram",
|
|
|
+ "lr": 0.05,
|
|
|
+ "dim": 100,
|
|
|
+ "ws": 5,
|
|
|
+ "epoch": 5,
|
|
|
+ "minCount": 5,
|
|
|
+ "minCountLabel": 0,
|
|
|
+ "minn": 3,
|
|
|
+ "maxn": 6,
|
|
|
+ "neg": 5,
|
|
|
+ "wordNgrams": 1,
|
|
|
+ "loss": "ns",
|
|
|
+ "bucket": 2000000,
|
|
|
+ "thread": multiprocessing.cpu_count() - 1,
|
|
|
+ "lrUpdateRate": 100,
|
|
|
+ "t": 1e-4,
|
|
|
+ "label": "__label__",
|
|
|
+ "verbose": 2,
|
|
|
+ "pretrainedVectors": "",
|
|
|
+ "seed": 0,
|
|
|
+ "autotuneValidationFile": "",
|
|
|
+ "autotuneMetric": "f1",
|
|
|
+ "autotunePredictions": 1,
|
|
|
+ "autotuneDuration": 60 * 5, # 5 minutes
|
|
|
+ "autotuneModelSize": "",
|
|
|
}
|
|
|
|
|
|
|
|
|
def read_args(arg_list, arg_dict, arg_names, default_values):
|
|
|
param_map = {
|
|
|
- 'min_count': 'minCount',
|
|
|
- 'word_ngrams': 'wordNgrams',
|
|
|
- 'lr_update_rate': 'lrUpdateRate',
|
|
|
- 'label_prefix': 'label',
|
|
|
- 'pretrained_vectors': 'pretrainedVectors'
|
|
|
+ "min_count": "minCount",
|
|
|
+ "word_ngrams": "wordNgrams",
|
|
|
+ "lr_update_rate": "lrUpdateRate",
|
|
|
+ "label_prefix": "label",
|
|
|
+ "pretrained_vectors": "pretrainedVectors",
|
|
|
}
|
|
|
|
|
|
ret = {}
|
|
|
@@ -507,22 +516,45 @@ def train_supervised(*kargs, **kwargs):
|
|
|
repository such as the dataset pulled by classification-example.sh.
|
|
|
"""
|
|
|
supervised_default = unsupervised_default.copy()
|
|
|
- supervised_default.update({
|
|
|
- 'lr': 0.1,
|
|
|
- 'minCount': 1,
|
|
|
- 'minn': 0,
|
|
|
- 'maxn': 0,
|
|
|
- 'loss': "softmax",
|
|
|
- 'model': "supervised"
|
|
|
- })
|
|
|
-
|
|
|
- arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
|
|
|
- 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
|
|
|
- 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
|
|
|
- 'seed', 'autotuneValidationFile', 'autotuneMetric',
|
|
|
- 'autotunePredictions', 'autotuneDuration', 'autotuneModelSize']
|
|
|
- args, manually_set_args = read_args(kargs, kwargs, arg_names,
|
|
|
- supervised_default)
|
|
|
+ supervised_default.update(
|
|
|
+ {
|
|
|
+ "lr": 0.1,
|
|
|
+ "minCount": 1,
|
|
|
+ "minn": 0,
|
|
|
+ "maxn": 0,
|
|
|
+ "loss": "softmax",
|
|
|
+ "model": "supervised",
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ arg_names = [
|
|
|
+ "input",
|
|
|
+ "lr",
|
|
|
+ "dim",
|
|
|
+ "ws",
|
|
|
+ "epoch",
|
|
|
+ "minCount",
|
|
|
+ "minCountLabel",
|
|
|
+ "minn",
|
|
|
+ "maxn",
|
|
|
+ "neg",
|
|
|
+ "wordNgrams",
|
|
|
+ "loss",
|
|
|
+ "bucket",
|
|
|
+ "thread",
|
|
|
+ "lrUpdateRate",
|
|
|
+ "t",
|
|
|
+ "label",
|
|
|
+ "verbose",
|
|
|
+ "pretrainedVectors",
|
|
|
+ "seed",
|
|
|
+ "autotuneValidationFile",
|
|
|
+ "autotuneMetric",
|
|
|
+ "autotunePredictions",
|
|
|
+ "autotuneDuration",
|
|
|
+ "autotuneModelSize",
|
|
|
+ ]
|
|
|
+ args, manually_set_args = read_args(kargs, kwargs, arg_names, supervised_default)
|
|
|
a = _build_args(args, manually_set_args)
|
|
|
ft = _FastText(args=a)
|
|
|
fasttext.train(ft.f, a)
|
|
|
@@ -544,11 +576,29 @@ def train_unsupervised(*kargs, **kwargs):
|
|
|
dataset pulled by the example script word-vector-example.sh, which is
|
|
|
part of the fastText repository.
|
|
|
"""
|
|
|
- arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
|
|
|
- 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
|
|
|
- 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
|
|
|
- args, manually_set_args = read_args(kargs, kwargs, arg_names,
|
|
|
- unsupervised_default)
|
|
|
+ arg_names = [
|
|
|
+ "input",
|
|
|
+ "model",
|
|
|
+ "lr",
|
|
|
+ "dim",
|
|
|
+ "ws",
|
|
|
+ "epoch",
|
|
|
+ "minCount",
|
|
|
+ "minCountLabel",
|
|
|
+ "minn",
|
|
|
+ "maxn",
|
|
|
+ "neg",
|
|
|
+ "wordNgrams",
|
|
|
+ "loss",
|
|
|
+ "bucket",
|
|
|
+ "thread",
|
|
|
+ "lrUpdateRate",
|
|
|
+ "t",
|
|
|
+ "label",
|
|
|
+ "verbose",
|
|
|
+ "pretrainedVectors",
|
|
|
+ ]
|
|
|
+ args, manually_set_args = read_args(kargs, kwargs, arg_names, unsupervised_default)
|
|
|
a = _build_args(args, manually_set_args)
|
|
|
ft = _FastText(args=a)
|
|
|
fasttext.train(ft.f, a)
|
|
|
@@ -557,12 +607,18 @@ def train_unsupervised(*kargs, **kwargs):
|
|
|
|
|
|
|
|
|
def cbow(*kargs, **kwargs):
|
|
|
- raise Exception("`cbow` is not supported any more. Please use `train_unsupervised` with model=`cbow`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
|
|
|
+ raise Exception(
|
|
|
+ "`cbow` is not supported any more. Please use `train_unsupervised` with model=`cbow`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def skipgram(*kargs, **kwargs):
|
|
|
- raise Exception("`skipgram` is not supported any more. Please use `train_unsupervised` with model=`skipgram`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
|
|
|
+ raise Exception(
|
|
|
+ "`skipgram` is not supported any more. Please use `train_unsupervised` with model=`skipgram`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def supervised(*kargs, **kwargs):
|
|
|
- raise Exception("`supervised` is not supported any more. Please use `train_supervised`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
|
|
|
+ raise Exception(
|
|
|
+ "`supervised` is not supported any more. Please use `train_supervised`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
|
|
|
+ )
|