瀏覽代碼

deeplearning/fastText 2/2

Reviewed By: azad-meta

Differential Revision: D53908330

fbshipit-source-id: b2215f0522c32a82cd876633210befefe9317d76
generatedunixname89002005287564 2 年之前
父節點
當前提交
ae1fe80e9f
共有 3 個文件被更改,包括 217 次插入和 157 次刪除
  1. 5 5
      python/benchmarks/get_word_vector.py
  2. 141 85
      python/fasttext_module/fasttext/FastText.py
  3. 71 67
      setup.py

+ 5 - 5
python/benchmarks/get_word_vector.py

@@ -13,14 +13,13 @@ from fasttext import load_model
 from fasttext import tokenize
 import sys
 import time
-import tempfile
 import argparse
 
 
 def get_word_vector(data, model):
     t1 = time.time()
     print("Reading")
-    with open(data, 'r') as f:
+    with open(data, "r") as f:
         tokens = tokenize(f.read())
     t2 = time.time()
     print("Read TIME: " + str(t2 - t1))
@@ -43,8 +42,9 @@ def get_word_vector(data, model):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Simple benchmark for get_word_vector.')
-    parser.add_argument('model', help='A model file to use for benchmarking.')
-    parser.add_argument('data', help='A data file to use for benchmarking.')
+        description="Simple benchmark for get_word_vector."
+    )
+    parser.add_argument("model", help="A model file to use for benchmarking.")
+    parser.add_argument("data", help="A data file to use for benchmarking.")
     args = parser.parse_args()
     get_word_vector(args.data, args.model)

+ 141 - 85
python/fasttext_module/fasttext/FastText.py

@@ -12,7 +12,6 @@ from __future__ import unicode_literals
 import fasttext_pybind as fasttext
 import numpy as np
 import multiprocessing
-import sys
 from itertools import chain
 
 loss_name = fasttext.loss_name
@@ -98,10 +97,26 @@ class _FastText:
 
     def set_args(self, args=None):
         if args:
-            arg_names = ['lr', 'dim', 'ws', 'epoch', 'minCount',
-                         'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams',
-                         'loss', 'bucket', 'thread', 'lrUpdateRate', 't',
-                         'label', 'verbose', 'pretrainedVectors']
+            arg_names = [
+                "lr",
+                "dim",
+                "ws",
+                "epoch",
+                "minCount",
+                "minCountLabel",
+                "minn",
+                "maxn",
+                "neg",
+                "wordNgrams",
+                "loss",
+                "bucket",
+                "thread",
+                "lrUpdateRate",
+                "t",
+                "label",
+                "verbose",
+                "pretrainedVectors",
+            ]
             for arg_name in arg_names:
                 setattr(self, arg_name, getattr(args, arg_name))
 
@@ -127,21 +142,18 @@ class _FastText:
         whitespace (space, newline, tab, vertical tab) and the control
         characters carriage return, formfeed and the null character.
         """
-        if text.find('\n') != -1:
-            raise ValueError(
-                "predict processes one line at a time (remove \'\\n\')"
-            )
+        if text.find("\n") != -1:
+            raise ValueError("predict processes one line at a time (remove '\\n')")
         text += "\n"
         dim = self.get_dimension()
         b = fasttext.Vector(dim)
         self.f.getSentenceVector(b, text)
         return np.array(b)
 
-    def get_nearest_neighbors(self, word, k=10, on_unicode_error='strict'):
+    def get_nearest_neighbors(self, word, k=10, on_unicode_error="strict"):
         return self.f.getNN(word, k, on_unicode_error)
 
-    def get_analogies(self, wordA, wordB, wordC, k=10,
-                      on_unicode_error='strict'):
+    def get_analogies(self, wordA, wordB, wordC, k=10, on_unicode_error="strict"):
         return self.f.getAnalogies(wordA, wordB, wordC, k, on_unicode_error)
 
     def get_word_id(self, word):
@@ -164,7 +176,7 @@ class _FastText:
         """
         return self.f.getSubwordId(subword)
 
-    def get_subwords(self, word, on_unicode_error='strict'):
+    def get_subwords(self, word, on_unicode_error="strict"):
         """
         Given a word, get the subwords and their indicies.
         """
@@ -180,7 +192,7 @@ class _FastText:
         self.f.getInputVector(b, ind)
         return np.array(b)
 
-    def predict(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
+    def predict(self, text, k=1, threshold=0.0, on_unicode_error="strict"):
         """
         Given a string, get a list of labels and a list of
         corresponding probabilities. k controls the number
@@ -204,17 +216,16 @@ class _FastText:
         """
 
         def check(entry):
-            if entry.find('\n') != -1:
-                raise ValueError(
-                    "predict processes one line at a time (remove \'\\n\')"
-                )
+            if entry.find("\n") != -1:
+                raise ValueError("predict processes one line at a time (remove '\\n')")
             entry += "\n"
             return entry
 
         if type(text) == list:
             text = [check(entry) for entry in text]
             all_labels, all_probs = self.f.multilinePredict(
-                text, k, threshold, on_unicode_error)
+                text, k, threshold, on_unicode_error
+            )
 
             return all_labels, all_probs
         else:
@@ -245,7 +256,7 @@ class _FastText:
             raise ValueError("Can't get quantized Matrix")
         return np.array(self.f.getOutputMatrix())
 
-    def get_words(self, include_freq=False, on_unicode_error='strict'):
+    def get_words(self, include_freq=False, on_unicode_error="strict"):
         """
         Get the entire list of words of the dictionary optionally
         including the frequency of the individual words. This
@@ -258,7 +269,7 @@ class _FastText:
         else:
             return pair[0]
 
-    def get_labels(self, include_freq=False, on_unicode_error='strict'):
+    def get_labels(self, include_freq=False, on_unicode_error="strict"):
         """
         Get the entire list of labels of the dictionary optionally
         including the frequency of the individual labels. Unsupervised
@@ -276,17 +287,15 @@ class _FastText:
         else:
             return self.get_words(include_freq)
 
-    def get_line(self, text, on_unicode_error='strict'):
+    def get_line(self, text, on_unicode_error="strict"):
         """
         Split a line of text into words and labels. Labels must start with
         the prefix used to create the model (__label__ by default).
         """
 
         def check(entry):
-            if entry.find('\n') != -1:
-                raise ValueError(
-                    "get_line processes one line at a time (remove \'\\n\')"
-                )
+            if entry.find("\n") != -1:
+                raise ValueError("get_line processes one line at a time (remove '\\n')")
             entry += "\n"
             return entry
 
@@ -332,7 +341,7 @@ class _FastText:
         thread=None,
         verbose=None,
         dsub=2,
-        qnorm=False
+        qnorm=False,
     ):
         """
         Quantize the model reducing the size of the model and
@@ -352,8 +361,7 @@ class _FastText:
         if input is None:
             input = ""
         self.f.quantize(
-            input, qout, cutoff, retrain, epoch, lr, thread, verbose, dsub,
-            qnorm
+            input, qout, cutoff, retrain, epoch, lr, thread, verbose, dsub, qnorm
         )
 
     def set_matrices(self, input_matrix, output_matrix):
@@ -361,8 +369,9 @@ class _FastText:
         Set input and output matrices. This function assumes you know what you
         are doing.
         """
-        self.f.setMatrices(input_matrix.astype(np.float32),
-                           output_matrix.astype(np.float32))
+        self.f.setMatrices(
+            input_matrix.astype(np.float32), output_matrix.astype(np.float32)
+        )
 
     @property
     def words(self):
@@ -437,41 +446,41 @@ def load_model(path):
 
 
 unsupervised_default = {
-    'model': "skipgram",
-    'lr': 0.05,
-    'dim': 100,
-    'ws': 5,
-    'epoch': 5,
-    'minCount': 5,
-    'minCountLabel': 0,
-    'minn': 3,
-    'maxn': 6,
-    'neg': 5,
-    'wordNgrams': 1,
-    'loss': "ns",
-    'bucket': 2000000,
-    'thread': multiprocessing.cpu_count() - 1,
-    'lrUpdateRate': 100,
-    't': 1e-4,
-    'label': "__label__",
-    'verbose': 2,
-    'pretrainedVectors': "",
-    'seed': 0,
-    'autotuneValidationFile': "",
-    'autotuneMetric': "f1",
-    'autotunePredictions': 1,
-    'autotuneDuration': 60 * 5,  # 5 minutes
-    'autotuneModelSize': ""
+    "model": "skipgram",
+    "lr": 0.05,
+    "dim": 100,
+    "ws": 5,
+    "epoch": 5,
+    "minCount": 5,
+    "minCountLabel": 0,
+    "minn": 3,
+    "maxn": 6,
+    "neg": 5,
+    "wordNgrams": 1,
+    "loss": "ns",
+    "bucket": 2000000,
+    "thread": multiprocessing.cpu_count() - 1,
+    "lrUpdateRate": 100,
+    "t": 1e-4,
+    "label": "__label__",
+    "verbose": 2,
+    "pretrainedVectors": "",
+    "seed": 0,
+    "autotuneValidationFile": "",
+    "autotuneMetric": "f1",
+    "autotunePredictions": 1,
+    "autotuneDuration": 60 * 5,  # 5 minutes
+    "autotuneModelSize": "",
 }
 
 
 def read_args(arg_list, arg_dict, arg_names, default_values):
     param_map = {
-        'min_count': 'minCount',
-        'word_ngrams': 'wordNgrams',
-        'lr_update_rate': 'lrUpdateRate',
-        'label_prefix': 'label',
-        'pretrained_vectors': 'pretrainedVectors'
+        "min_count": "minCount",
+        "word_ngrams": "wordNgrams",
+        "lr_update_rate": "lrUpdateRate",
+        "label_prefix": "label",
+        "pretrained_vectors": "pretrainedVectors",
     }
 
     ret = {}
@@ -507,22 +516,45 @@ def train_supervised(*kargs, **kwargs):
     repository such as the dataset pulled by classification-example.sh.
     """
     supervised_default = unsupervised_default.copy()
-    supervised_default.update({
-        'lr': 0.1,
-        'minCount': 1,
-        'minn': 0,
-        'maxn': 0,
-        'loss': "softmax",
-        'model': "supervised"
-    })
-
-    arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
-                 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
-                 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
-                 'seed', 'autotuneValidationFile', 'autotuneMetric',
-                 'autotunePredictions', 'autotuneDuration', 'autotuneModelSize']
-    args, manually_set_args = read_args(kargs, kwargs, arg_names,
-                                        supervised_default)
+    supervised_default.update(
+        {
+            "lr": 0.1,
+            "minCount": 1,
+            "minn": 0,
+            "maxn": 0,
+            "loss": "softmax",
+            "model": "supervised",
+        }
+    )
+
+    arg_names = [
+        "input",
+        "lr",
+        "dim",
+        "ws",
+        "epoch",
+        "minCount",
+        "minCountLabel",
+        "minn",
+        "maxn",
+        "neg",
+        "wordNgrams",
+        "loss",
+        "bucket",
+        "thread",
+        "lrUpdateRate",
+        "t",
+        "label",
+        "verbose",
+        "pretrainedVectors",
+        "seed",
+        "autotuneValidationFile",
+        "autotuneMetric",
+        "autotunePredictions",
+        "autotuneDuration",
+        "autotuneModelSize",
+    ]
+    args, manually_set_args = read_args(kargs, kwargs, arg_names, supervised_default)
     a = _build_args(args, manually_set_args)
     ft = _FastText(args=a)
     fasttext.train(ft.f, a)
@@ -544,11 +576,29 @@ def train_unsupervised(*kargs, **kwargs):
     dataset pulled by the example script word-vector-example.sh, which is
     part of the fastText repository.
     """
-    arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
-                 'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
-                 'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
-    args, manually_set_args = read_args(kargs, kwargs, arg_names,
-                                        unsupervised_default)
+    arg_names = [
+        "input",
+        "model",
+        "lr",
+        "dim",
+        "ws",
+        "epoch",
+        "minCount",
+        "minCountLabel",
+        "minn",
+        "maxn",
+        "neg",
+        "wordNgrams",
+        "loss",
+        "bucket",
+        "thread",
+        "lrUpdateRate",
+        "t",
+        "label",
+        "verbose",
+        "pretrainedVectors",
+    ]
+    args, manually_set_args = read_args(kargs, kwargs, arg_names, unsupervised_default)
     a = _build_args(args, manually_set_args)
     ft = _FastText(args=a)
     fasttext.train(ft.f, a)
@@ -557,12 +607,18 @@ def train_unsupervised(*kargs, **kwargs):
 
 
 def cbow(*kargs, **kwargs):
-    raise Exception("`cbow` is not supported any more. Please use `train_unsupervised` with model=`cbow`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
+    raise Exception(
+        "`cbow` is not supported any more. Please use `train_unsupervised` with model=`cbow`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
+    )
 
 
 def skipgram(*kargs, **kwargs):
-    raise Exception("`skipgram` is not supported any more. Please use `train_unsupervised` with model=`skipgram`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
+    raise Exception(
+        "`skipgram` is not supported any more. Please use `train_unsupervised` with model=`skipgram`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
+    )
 
 
 def supervised(*kargs, **kwargs):
-    raise Exception("`supervised` is not supported any more. Please use `train_supervised`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module")
+    raise Exception(
+        "`supervised` is not supported any more. Please use `train_supervised`. For more information please refer to https://fasttext.cc/blog/2019/06/25/blog-post.html#2-you-were-using-the-unofficial-fasttext-module"
+    )

+ 71 - 67
setup.py

@@ -21,33 +21,36 @@ import subprocess
 import platform
 import io
 
-__version__ = '0.9.2'
+__version__ = "0.9.2"
 FASTTEXT_SRC = "src"
 
 # Based on https://github.com/pybind/python_example
 
+
 class get_pybind_include:
     """Helper class to determine the pybind11 include path
 
     The purpose of this class is to postpone importing pybind11
     until it is actually installed, so that the ``get_include()``
-    method can be invoked. """
+    method can be invoked."""
 
     def __init__(self, user=False):
         try:
-            import pybind11
+            pass
         except ImportError:
-            if subprocess.call([sys.executable, '-m', 'pip', 'install', 'pybind11']):
-                raise RuntimeError('pybind11 install failed.')
+            if subprocess.call([sys.executable, "-m", "pip", "install", "pybind11"]):
+                raise RuntimeError("pybind11 install failed.")
 
         self.user = user
 
     def __str__(self):
         import pybind11
+
         return pybind11.get_include(self.user)
 
+
 try:
-    coverage_index = sys.argv.index('--coverage')
+    coverage_index = sys.argv.index("--coverage")
 except ValueError:
     coverage = False
 else:
@@ -55,7 +58,7 @@ else:
     coverage = True
 
 fasttext_src_files = map(str, os.listdir(FASTTEXT_SRC))
-fasttext_src_cc = list(filter(lambda x: x.endswith('.cc'), fasttext_src_files))
+fasttext_src_cc = list(filter(lambda x: x.endswith(".cc"), fasttext_src_files))
 
 fasttext_src_cc = list(
     map(lambda x: str(os.path.join(FASTTEXT_SRC, x)), fasttext_src_cc)
@@ -63,10 +66,11 @@ fasttext_src_cc = list(
 
 ext_modules = [
     Extension(
-        str('fasttext_pybind'),
+        str("fasttext_pybind"),
         [
-            str('python/fasttext_module/fasttext/pybind/fasttext_pybind.cc'),
-        ] + fasttext_src_cc,
+            str("python/fasttext_module/fasttext/pybind/fasttext_pybind.cc"),
+        ]
+        + fasttext_src_cc,
         include_dirs=[
             # Path to pybind11 headers
             get_pybind_include(),
@@ -74,9 +78,12 @@ ext_modules = [
             # Path to fasttext source code
             FASTTEXT_SRC,
         ],
-        language='c++',
-        extra_compile_args=["-O0 -fno-inline -fprofile-arcs -pthread -march=native" if coverage else
-                            "-O3 -funroll-loops -pthread -march=native"],
+        language="c++",
+        extra_compile_args=[
+            "-O0 -fno-inline -fprofile-arcs -pthread -march=native"
+            if coverage
+            else "-O3 -funroll-loops -pthread -march=native"
+        ],
     ),
 ]
 
@@ -88,8 +95,9 @@ def has_flag(compiler, flags):
     the specified compiler.
     """
     import tempfile
-    with tempfile.NamedTemporaryFile('w', suffix='.cpp') as f:
-        f.write('int main (int argc, char **argv) { return 0; }')
+
+    with tempfile.NamedTemporaryFile("w", suffix=".cpp") as f:
+        f.write("int main (int argc, char **argv) { return 0; }")
         try:
             compiler.compile([f.name], extra_postargs=flags)
         except setuptools.distutils.errors.CompileError:
@@ -98,57 +106,53 @@ def has_flag(compiler, flags):
 
 
 def cpp_flag(compiler):
-    """Return the -std=c++17 compiler flag.
-    """
-    standards = ['-std=c++17']
+    """Return the -std=c++17 compiler flag."""
+    standards = ["-std=c++17"]
     for standard in standards:
         if has_flag(compiler, [standard]):
             return standard
-    raise RuntimeError(
-        'Unsupported compiler -- at least C++17 support '
-        'is needed!'
-    )
+    raise RuntimeError("Unsupported compiler -- at least C++17 support " "is needed!")
 
 
 class BuildExt(build_ext):
     """A custom build extension for adding compiler-specific options."""
+
     c_opts = {
-        'msvc': ['/EHsc'],
-        'unix': [],
+        "msvc": ["/EHsc"],
+        "unix": [],
     }
 
     def build_extensions(self):
-        if sys.platform == 'darwin':
-            mac_osx_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]))
-            os.environ['MACOSX_DEPLOYMENT_TARGET'] = str(mac_osx_version)
-            all_flags = ['-stdlib=libc++', '-mmacosx-version-min=10.7']
+        if sys.platform == "darwin":
+            mac_osx_version = float(".".join(platform.mac_ver()[0].split(".")[:2]))
+            os.environ["MACOSX_DEPLOYMENT_TARGET"] = str(mac_osx_version)
+            all_flags = ["-stdlib=libc++", "-mmacosx-version-min=10.7"]
             if has_flag(self.compiler, [all_flags[0]]):
-                self.c_opts['unix'] += [all_flags[0]]
+                self.c_opts["unix"] += [all_flags[0]]
             elif has_flag(self.compiler, all_flags):
-                self.c_opts['unix'] += all_flags
+                self.c_opts["unix"] += all_flags
             else:
                 raise RuntimeError(
-                    'libc++ is needed! Failed to compile with {} and {}.'.
-                    format(" ".join(all_flags), all_flags[0])
+                    "libc++ is needed! Failed to compile with {} and {}.".format(
+                        " ".join(all_flags), all_flags[0]
+                    )
                 )
         ct = self.compiler.compiler_type
         opts = self.c_opts.get(ct, [])
         extra_link_args = []
 
         if coverage:
-            coverage_option = '--coverage'
+            coverage_option = "--coverage"
             opts.append(coverage_option)
             extra_link_args.append(coverage_option)
 
-        if ct == 'unix':
+        if ct == "unix":
             opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version())
             opts.append(cpp_flag(self.compiler))
-            if has_flag(self.compiler, ['-fvisibility=hidden']):
-                opts.append('-fvisibility=hidden')
-        elif ct == 'msvc':
-            opts.append(
-                '/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version()
-            )
+            if has_flag(self.compiler, ["-fvisibility=hidden"]):
+                opts.append("-fvisibility=hidden")
+        elif ct == "msvc":
+            opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version())
         for ext in self.extensions:
             ext.extra_compile_args = opts
             ext.extra_link_args = extra_link_args
@@ -160,43 +164,43 @@ def _get_readme():
     Use pandoc to generate rst from md.
     pandoc --from=markdown --to=rst --output=python/README.rst python/README.md
     """
-    with io.open("python/README.rst", encoding='utf-8') as fid:
+    with io.open("python/README.rst", encoding="utf-8") as fid:
         return fid.read()
 
 
 setup(
-    name='fasttext',
+    name="fasttext",
     version=__version__,
-    author='Onur Celebi',
-    author_email='[email protected]',
-    description='fasttext Python bindings',
+    author="Onur Celebi",
+    author_email="[email protected]",
+    description="fasttext Python bindings",
     long_description=_get_readme(),
     ext_modules=ext_modules,
-    url='https://github.com/facebookresearch/fastText',
-    license='MIT',
+    url="https://github.com/facebookresearch/fastText",
+    license="MIT",
     classifiers=[
-        'Development Status :: 3 - Alpha',
-        'Intended Audience :: Developers',
-        'Intended Audience :: Science/Research',
-        'License :: OSI Approved :: MIT License',
-        'Programming Language :: Python :: 2.7',
-        'Programming Language :: Python :: 3.4',
-        'Programming Language :: Python :: 3.5',
-        'Programming Language :: Python :: 3.6',
-        'Topic :: Software Development',
-        'Topic :: Scientific/Engineering',
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: POSIX',
-        'Operating System :: Unix',
-        'Operating System :: MacOS',
+        "Development Status :: 3 - Alpha",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: MIT License",
+        "Programming Language :: Python :: 2.7",
+        "Programming Language :: Python :: 3.4",
+        "Programming Language :: Python :: 3.5",
+        "Programming Language :: Python :: 3.6",
+        "Topic :: Software Development",
+        "Topic :: Scientific/Engineering",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: POSIX",
+        "Operating System :: Unix",
+        "Operating System :: MacOS",
     ],
-    install_requires=['pybind11>=2.2', "setuptools >= 0.7.0", "numpy"],
-    cmdclass={'build_ext': BuildExt},
+    install_requires=["pybind11>=2.2", "setuptools >= 0.7.0", "numpy"],
+    cmdclass={"build_ext": BuildExt},
     packages=[
-        str('fasttext'),
-        str('fasttext.util'),
-        str('fasttext.tests'),
+        str("fasttext"),
+        str("fasttext.util"),
+        str("fasttext.tests"),
     ],
-    package_dir={str(''): str('python/fasttext_module')},
+    package_dir={str(""): str("python/fasttext_module")},
     zip_safe=False,
 )