
autotune

Summary: This commit implements automatic hyperparameter optimization (autotune). When the `-autotune-validation` argument is provided, fastText searches for the hyperparameters that give the best f1-score on the given validation file.

Reviewed By: EdouardGrave

Differential Revision: D17050215

fbshipit-source-id: 13333181ee1162147f94085cac4ebabdd9b80d67
Onur Çelebi, 6 years ago
Parent
Commit
d0bf803d4c

+ 2 - 0
CMakeLists.txt

@@ -19,6 +19,7 @@ set(CMAKE_CXX_FLAGS " -pthread -std=c++11 -funroll-loops -O3 -march=native")
 
 set(HEADER_FILES
     src/args.h
+    src/autotune.h
     src/densematrix.h
     src/dictionary.h
     src/fasttext.h
@@ -34,6 +35,7 @@ set(HEADER_FILES
 
 set(SOURCE_FILES
     src/args.cc
+    src/autotune.cc
     src/densematrix.cc
     src/dictionary.cc
     src/fasttext.cc

+ 4 - 1
Makefile

@@ -8,7 +8,7 @@
 
 CXX = c++
 CXXFLAGS = -pthread -std=c++11 -march=native
-OBJS = args.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
+OBJS = args.o autotune.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o utils.o meter.o fasttext.o
 INCLUDES = -I.
 
 opt: CXXFLAGS += -O3 -funroll-loops -DNDEBUG
@@ -23,6 +23,9 @@ debug: fasttext
 args.o: src/args.cc src/args.h
 	$(CXX) $(CXXFLAGS) -c src/args.cc
 
+autotune.o: src/autotune.cc src/autotune.h
+	$(CXX) $(CXXFLAGS) -c src/autotune.cc
+
 matrix.o: src/matrix.cc src/matrix.h
 	$(CXX) $(CXXFLAGS) -c src/matrix.cc
 

+ 76 - 0
docs/autotune.md

@@ -0,0 +1,76 @@
+---
+id: autotune
+title: Automatic hyperparameter optimization
+---
+
+As we saw in [the tutorial](/docs/en/supervised-tutorial.html#more-epochs-and-larger-learning-rate), finding the best hyperparameters is crucial for building efficient models. However, searching for the best hyperparameters manually is difficult. Parameters are interdependent, and the effect of each parameter varies from one dataset to another.
+
+FastText's autotune feature allows you to automatically find the best hyperparameters for your dataset.
+
+# How to use it
+
+In order to activate hyperparameter optimization, we must provide a validation file with the `-autotune-validation` argument.
+
+For example, using the same data as our [tutorial example](/docs/en/supervised-tutorial.html#our-first-classifier), autotune can be used in the following way:
+
+```sh
+>> ./fasttext supervised -input cooking.train -output model_cooking -autotune-validation cooking.valid
+```
+
+Then, fastText will search for the hyperparameters that give the best f1-score on the `cooking.valid` file:
+```sh
+Progress: 100.0% Trials:   27 Best score:  0.406763 ETA:   0h 0m 0s
+```
+
+Now we can test the obtained model with:
+```sh
+>> ./fasttext test model_cooking.bin data/cooking.valid
+N       3000
+P@1     0.666
+R@1     0.288
+```
+
+By default, the search will take 5 minutes. You can set the timeout in seconds with the `-autotune-duration` argument. For example, if you want to set the limit to 10 minutes:
+
+```sh
+>> ./fasttext supervised -input cooking.train -output model_cooking -autotune-validation cooking.valid -autotune-duration 600
+```
+
+While autotuning, fastText displays the best f1-score found so far. If we decide to stop the tuning before the time limit, we can send a `SIGINT` signal (via `CTRL-C`, for example). FastText will then finish the current training and retrain with the best parameters found so far.
+
+
+
+# Constrain model size
+
+As you may know, fastText can compress the model with [quantization](/docs/en/cheatsheet.html#quantization). However, this compression task comes with its own [hyperparameters](/docs/en/options.html) (`-cutoff`, `-retrain`, `-qnorm`, `-qout`, `-dsub`) that affect both the accuracy and the size of the final model.
+
+Fortunately, autotune can also find the hyperparameters for this compression task while targeting the desired model size. To this end, we can set the `-autotune-modelsize` argument:
+
+```sh
+>> ./fasttext supervised -input cooking.train -output model_cooking -autotune-validation cooking.valid -autotune-modelsize 2M
+```
+
+This will produce a `.ftz` file with the best accuracy that fits within the desired size:
+```sh
+>> ls -la model_cooking.ftz
+-rw-r--r--. 1 celebio users 1990862 Aug 25 05:39 model_cooking.ftz
+>> ./fasttext test model_cooking.ftz data/cooking.valid
+N       3000
+P@1     0.57
+R@1     0.246
+```
+
+
+# How to set the optimization metric?
+
+By default, autotune will evaluate on the validation file you provide, exactly as `./fasttext test model_cooking.bin cooking.valid` does, and try to obtain the highest possible [f1-score](https://en.wikipedia.org/wiki/F1_score).
+
+But if we want to optimize the score of a specific label, say `__label__baking`, we can set the `-autotune-metric` argument:
+
+```sh
+>> ./fasttext supervised -input cooking.train -output model_cooking -autotune-validation cooking.valid -autotune-metric f1:__label__baking
+```
+
+This is equivalent to manually optimizing the f1-score we would get when testing with `./fasttext test-label model_cooking.bin cooking.valid | grep __label__baking` on the command line.
+
+Sometimes, you may be interested in predicting more than one label. For example, if you were tuning hyperparameters manually to get the best score when predicting two labels, you would test with `./fasttext test model_cooking.bin cooking.valid 2`. You can tell autotune to optimize the parameters while testing two labels in the same way, with the `-autotune-predictions` argument.
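The search this document describes is essentially a budgeted random search around the best configuration found so far. A minimal Python sketch of that idea (the helper names `perturb` and `evaluate` are hypothetical; the real loop, driven by a wall-clock budget rather than a trial count, lives in `src/autotune.cc`):

```python
import random

def autotune_loop(evaluate, perturb, initial_params, trials=30, seed=0):
    """Budgeted random search around the best configuration so far.

    Hypothetical helpers: perturb(params, rng) returns a mutated copy
    of the parameters (cf. AutotuneStrategy::ask) and evaluate(params)
    returns the f1-score on the validation file.
    """
    rng = random.Random(seed)
    best_params, best_score = dict(initial_params), None
    for trial in range(trials):
        # The first trial evaluates the starting configuration
        # unchanged, mirroring AutotuneStrategy::ask.
        candidate = dict(initial_params) if trial == 0 else perturb(best_params, rng)
        score = evaluate(candidate)
        if best_score is None or score > best_score:
            best_params, best_score = candidate, score
    return best_params, best_score
```

Each trial mutates the best parameters found so far, trains, evaluates, and keeps the candidate only if it scores higher, so the returned score is monotonically non-decreasing in the number of trials.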

+ 23 - 0
docs/cheatsheet.md

@@ -64,3 +64,26 @@ All other commands such as test also work with this model
 ```bash
 $ ./fasttext test model.ftz test.txt
 ```
+
+## Autotune
+
+Activate hyperparameter optimization with `-autotune-validation` argument:
+
+```bash
+$ ./fasttext supervised -input train.txt -output model -autotune-validation valid.txt
+```
+
+Set timeout (in seconds):
+```bash
+$ ./fasttext supervised -input train.txt -output model -autotune-validation valid.txt -autotune-duration 600
+```
+
+Constrain the final model size:
+```bash
+$ ./fasttext supervised -input train.txt -output model -autotune-validation valid.txt -autotune-modelsize 2M
+```
+
+
+
+
+

+ 10 - 0
docs/options.md

@@ -48,3 +48,13 @@ The following arguments are mandatory:
 
 Defaults may vary by mode. (Word-representation modes `skipgram` and `cbow` use a default `-minCount` of 5.)
 
+
+Hyperparameter optimization (autotune) is activated when you provide a validation file with the `-autotune-validation` argument.
+```text
+The following arguments are for autotune:
+  -autotune-validation            validation file to be used for evaluation
+  -autotune-metric                metric objective {f1, f1:labelname} [f1]
+  -autotune-predictions           number of predictions used for evaluation  [1]
+  -autotune-duration              maximum duration in seconds [300]
+  -autotune-modelsize             constraint model file size [] (empty = do not quantize)
+```

+ 24 - 7
python/fasttext_module/fasttext/FastText.py

@@ -325,12 +325,17 @@ def _parse_loss_string(string):
         raise ValueError("Unrecognized loss name")
 
 
-def _build_args(args):
+def _build_args(args, manually_set_args):
     args["model"] = _parse_model_string(args["model"])
     args["loss"] = _parse_loss_string(args["loss"])
+    if type(args["autotuneModelSize"]) == int:
+        args["autotuneModelSize"] = str(args["autotuneModelSize"])
+
     a = fasttext.args()
     for (k, v) in args.items():
         setattr(a, k, v)
+        if k in manually_set_args:
+            a.setManual(k)
     a.output = ""  # User should use save_model
     a.saveOutput = 0  # Never use this
     if a.wordNgrams <= 1 and a.maxn == 0:
@@ -370,6 +375,12 @@ unsupervised_default = {
     'label' : "__label__",
     'verbose' : 2,
     'pretrainedVectors' : "",
+    'seed' : 0,
+    'autotuneValidationFile' : "",
+    'autotuneMetric' : "f1",
+    'autotunePredictions' : 1,
+    'autotuneDuration' : 60 * 5,  # 5 minutes
+    'autotuneModelSize' : ""
 }
 
 
@@ -383,6 +394,7 @@ def read_args(arg_list, arg_dict, arg_names, default_values):
     }
 
     ret = {}
+    manually_set_args = set()
     for (arg_name, arg_value) in chain(zip(arg_names, arg_list), arg_dict.items()):
         if arg_name in param_map:
             arg_name = param_map[arg_name]
@@ -391,12 +403,13 @@ def read_args(arg_list, arg_dict, arg_names, default_values):
         if arg_name in ret:
             raise TypeError("multiple values for argument '%s'" % arg_name)
         ret[arg_name] = arg_value
+        manually_set_args.add(arg_name)
 
     for (arg_name, arg_value) in default_values.items():
         if arg_name not in ret:
             ret[arg_name] = arg_value
 
-    return ret
+    return (ret, manually_set_args)
 
 
 def train_supervised(*kargs, **kwargs):
@@ -424,9 +437,12 @@ def train_supervised(*kargs, **kwargs):
 
     arg_names = ['input', 'lr', 'dim', 'ws', 'epoch', 'minCount',
         'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
-        'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
-    params = read_args(kargs, kwargs, arg_names, supervised_default)
-    a = _build_args(params)
+        'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors',
+        'seed', 'autotuneValidationFile', 'autotuneMetric',
+        'autotunePredictions', 'autotuneDuration', 'autotuneModelSize']
+    args, manually_set_args = read_args(kargs, kwargs, arg_names,
+                                        supervised_default)
+    a = _build_args(args, manually_set_args)
     ft = _FastText(args=a)
     fasttext.train(ft.f, a)
     return ft
@@ -449,8 +465,9 @@ def train_unsupervised(*kargs, **kwargs):
     arg_names = ['input', 'model', 'lr', 'dim', 'ws', 'epoch', 'minCount',
         'minCountLabel', 'minn', 'maxn', 'neg', 'wordNgrams', 'loss', 'bucket',
         'thread', 'lrUpdateRate', 't', 'label', 'verbose', 'pretrainedVectors']
-    params = read_args(kargs, kwargs, arg_names, unsupervised_default)
-    a = _build_args(params)
+    args, manually_set_args = read_args(kargs, kwargs, arg_names,
+                                        unsupervised_default)
+    a = _build_args(args, manually_set_args)
     ft = _FastText(args=a)
     fasttext.train(ft.f, a)
     return ft
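The change above threads a `manually_set_args` set through `read_args` so that parameters the caller fixed explicitly can be reported to `Args::setManual`, and autotune will leave them untouched. A simplified standalone sketch of the updated helper (it omits the legacy-name mapping done via `param_map` in the real code, and rejects unknown names outright):

```python
from itertools import chain

def read_args(arg_list, arg_dict, arg_names, default_values):
    """Collect positional and keyword arguments, record which ones the
    caller set explicitly, and fill in defaults for the rest."""
    ret, manually_set_args = {}, set()
    for arg_name, arg_value in chain(zip(arg_names, arg_list), arg_dict.items()):
        if arg_name not in arg_names:
            raise TypeError("unexpected keyword argument '%s'" % arg_name)
        if arg_name in ret:
            raise TypeError("multiple values for argument '%s'" % arg_name)
        ret[arg_name] = arg_value
        manually_set_args.add(arg_name)
    # Defaults only apply to arguments the caller did not set.
    for arg_name, arg_value in default_values.items():
        ret.setdefault(arg_name, arg_value)
    return ret, manually_set_args
```

Tracking the manually-set names separately from the merged dictionary is what lets `_build_args` distinguish "user asked for `epoch=25`" from "`epoch=25` is just the default", which matters once autotune starts varying the non-manual parameters.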

+ 28 - 2
python/fasttext_module/fasttext/pybind/fasttext_pybind.cc

@@ -7,6 +7,7 @@
  */
 
 #include <args.h>
+#include <autotune.h>
 #include <densematrix.h>
 #include <fasttext.h>
 #include <pybind11/pybind11.h>
@@ -93,12 +94,24 @@ PYBIND11_MODULE(fasttext_pybind, m) {
       .def_readwrite("verbose", &fasttext::Args::verbose)
       .def_readwrite("pretrainedVectors", &fasttext::Args::pretrainedVectors)
       .def_readwrite("saveOutput", &fasttext::Args::saveOutput)
+      .def_readwrite("seed", &fasttext::Args::seed)
 
       .def_readwrite("qout", &fasttext::Args::qout)
       .def_readwrite("retrain", &fasttext::Args::retrain)
       .def_readwrite("qnorm", &fasttext::Args::qnorm)
       .def_readwrite("cutoff", &fasttext::Args::cutoff)
-      .def_readwrite("dsub", &fasttext::Args::dsub);
+      .def_readwrite("dsub", &fasttext::Args::dsub)
+
+      .def_readwrite(
+          "autotuneValidationFile", &fasttext::Args::autotuneValidationFile)
+      .def_readwrite("autotuneMetric", &fasttext::Args::autotuneMetric)
+      .def_readwrite(
+          "autotunePredictions", &fasttext::Args::autotunePredictions)
+      .def_readwrite("autotuneDuration", &fasttext::Args::autotuneDuration)
+      .def_readwrite("autotuneModelSize", &fasttext::Args::autotuneModelSize)
+      .def("setManual", [](fasttext::Args& m, const std::string& argName) {
+        m.setManual(argName);
+      });
 
   py::enum_<fasttext::model_name>(m, "model_name")
       .value("cbow", fasttext::model_name::cbow)
@@ -113,9 +126,22 @@ PYBIND11_MODULE(fasttext_pybind, m) {
       .value("ova", fasttext::loss_name::ova)
       .export_values();
 
+  py::enum_<fasttext::metric_name>(m, "metric_name")
+      .value("f1score", fasttext::metric_name::f1score)
+      .value("labelf1score", fasttext::metric_name::labelf1score)
+      .export_values();
+
   m.def(
       "train",
-      [](fasttext::FastText& ft, fasttext::Args& a) { ft.train(a); },
+      [](fasttext::FastText& ft, fasttext::Args& a) {
+        if (a.hasAutotune()) {
+          fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(
+              &ft, [](fasttext::FastText*) {}));
+          autotune.train(a);
+        } else {
+          ft.train(a);
+        }
+      },
       py::call_guard<py::gil_scoped_release>());
 
   py::class_<fasttext::Vector>(m, "Vector", py::buffer_protocol())

+ 124 - 2
src/args.cc

@@ -10,8 +10,10 @@
 
 #include <stdlib.h>
 
+#include <cassert>
 #include <iostream>
 #include <stdexcept>
+#include <unordered_map>
 
 namespace fasttext {
 
@@ -36,12 +38,19 @@ Args::Args() {
   verbose = 2;
   pretrainedVectors = "";
   saveOutput = false;
+  seed = 0;
 
   qout = false;
   retrain = false;
   qnorm = false;
   cutoff = 0;
   dsub = 2;
+
+  autotuneValidationFile = "";
+  autotuneMetric = "f1";
+  autotunePredictions = 1;
+  autotuneDuration = 60 * 5; // 5 minutes
+  autotuneModelSize = "";
 }
 
 std::string Args::lossToString(loss_name ln) const {
@@ -78,6 +87,16 @@ std::string Args::modelToString(model_name mn) const {
   return "Unknown model name!"; // should never happen
 }
 
+std::string Args::metricToString(metric_name mn) const {
+  switch (mn) {
+    case metric_name::f1score:
+      return "f1score";
+    case metric_name::labelf1score:
+      return "labelf1score";
+  }
+  return "Unknown metric name!"; // should never happen
+}
+
 void Args::parseArgs(const std::vector<std::string>& args) {
   std::string command(args[1]);
   if (command == "supervised") {
@@ -97,6 +116,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
       exit(EXIT_FAILURE);
     }
     try {
+      setManual(args[ai].substr(1));
+
       if (args[ai] == "-h") {
         std::cerr << "Here is the help! Usage:" << std::endl;
         printHelp();
@@ -157,6 +178,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
       } else if (args[ai] == "-saveOutput") {
         saveOutput = true;
         ai--;
+      } else if (args[ai] == "-seed") {
+        seed = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-qnorm") {
         qnorm = true;
         ai--;
@@ -170,6 +193,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
         cutoff = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-dsub") {
         dsub = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-validation") {
+        autotuneValidationFile = std::string(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-metric") {
+        autotuneMetric = std::string(args.at(ai + 1));
+        getAutotuneMetric(); // throws exception if not able to parse
+        getAutotuneMetricLabel(); // throws exception if not able to parse
+      } else if (args[ai] == "-autotune-predictions") {
+        autotunePredictions = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-duration") {
+        autotuneDuration = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-modelsize") {
+        autotuneModelSize = std::string(args.at(ai + 1));
       } else {
         std::cerr << "Unknown argument: " << args[ai] << std::endl;
         printHelp();
@@ -195,6 +230,7 @@ void Args::printHelp() {
   printBasicHelp();
   printDictionaryHelp();
   printTrainingHelp();
+  printAutotuneHelp();
   printQuantizationHelp();
 }
 
@@ -235,11 +271,27 @@ void Args::printTrainingHelp() {
       << "  -neg                number of negatives sampled [" << neg << "]\n"
       << "  -loss               loss function {ns, hs, softmax, one-vs-all} ["
       << lossToString(loss) << "]\n"
-      << "  -thread             number of threads (set to 1 to ensure reproducible results) [" << thread << "]\n"
+      << "  -thread             number of threads (set to 1 to ensure reproducible results) ["
+      << thread << "]\n"
       << "  -pretrainedVectors  pretrained word vectors for supervised learning ["
       << pretrainedVectors << "]\n"
       << "  -saveOutput         whether output params should be saved ["
-      << boolToString(saveOutput) << "]\n";
+      << boolToString(saveOutput) << "]\n"
+      << "  -seed               random generator seed  [" << seed << "]\n";
+}
+
+void Args::printAutotuneHelp() {
+  std::cerr
+      << "\nThe following arguments are for autotune:\n"
+      << "  -autotune-validation            validation file to be used for evaluation\n"
+      << "  -autotune-metric                metric objective {f1, f1:labelname} ["
+      << autotuneMetric << "]\n"
+      << "  -autotune-predictions           number of predictions used for evaluation  ["
+      << autotunePredictions << "]\n"
+      << "  -autotune-duration              maximum duration in seconds ["
+      << autotuneDuration << "]\n"
+      << "  -autotune-modelsize             constraint model file size ["
+      << autotuneModelSize << "] (empty = do not quantize)\n";
 }
 
 void Args::printQuantizationHelp() {
@@ -317,4 +369,74 @@ void Args::dump(std::ostream& out) const {
       << " " << t << std::endl;
 }
 
+bool Args::hasAutotune() const {
+  return !autotuneValidationFile.empty();
+}
+
+bool Args::isManual(const std::string& argName) const {
+  return (manualArgs_.count(argName) != 0);
+}
+
+void Args::setManual(const std::string& argName) {
+  manualArgs_.emplace(argName);
+}
+
+metric_name Args::getAutotuneMetric() const {
+  if (autotuneMetric.substr(0, 3) == "f1:") {
+    return metric_name::labelf1score;
+  } else if (autotuneMetric == "f1") {
+    return metric_name::f1score;
+  }
+  throw std::runtime_error("Unknown metric : " + autotuneMetric);
+}
+
+std::string Args::getAutotuneMetricLabel() const {
+  if (getAutotuneMetric() == metric_name::labelf1score) {
+    std::string label = autotuneMetric.substr(3);
+    if (label.empty()) {
+      throw std::runtime_error("Empty metric label : " + autotuneMetric);
+    }
+    return label;
+  }
+  return std::string();
+}
+
+int64_t Args::getAutotuneModelSize() const {
+  std::string modelSize = autotuneModelSize;
+  if (modelSize.empty()) {
+    return Args::kUnlimitedModelSize;
+  }
+  std::unordered_map<char, int> units = {
+      {'k', 1000},
+      {'K', 1000},
+      {'m', 1000000},
+      {'M', 1000000},
+      {'g', 1000000000},
+      {'G', 1000000000},
+  };
+  uint64_t multiplier = 1;
+  char lastCharacter = modelSize.back();
+  if (units.count(lastCharacter)) {
+    multiplier = units[lastCharacter];
+    modelSize = modelSize.substr(0, modelSize.size() - 1);
+  }
+  uint64_t size = 0;
+  size_t nonNumericCharacter = 0;
+  bool parseError = false;
+  try {
+    size = std::stol(modelSize, &nonNumericCharacter);
+  } catch (std::invalid_argument&) {
+    parseError = true;
+  }
+  if (!parseError && nonNumericCharacter != modelSize.size()) {
+    parseError = true;
+  }
+  if (parseError) {
+    throw std::invalid_argument(
+        "Unable to parse model size " + autotuneModelSize);
+  }
+
+  return size * multiplier;
+}
+
 } // namespace fasttext
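For reference, `Args::getAutotuneModelSize` above accepts a number with an optional `k`/`K`/`m`/`M`/`g`/`G` suffix (e.g. `2M`) and returns a byte count, with the empty string meaning "do not quantize". A rough Python equivalent, slightly stricter than the C++ `std::stol`-based parsing (no leading sign or whitespace):

```python
def parse_model_size(model_size: str) -> int:
    """Sketch of Args::getAutotuneModelSize: parse a size such as
    "2M" into a byte count; an empty string maps to the
    kUnlimitedModelSize sentinel (-1)."""
    if not model_size:
        return -1  # kUnlimitedModelSize
    units = {"k": 1_000, "m": 1_000_000, "g": 1_000_000_000}
    multiplier = 1
    if model_size[-1].lower() in units:
        multiplier = units[model_size[-1].lower()]
        model_size = model_size[:-1]
    if not model_size.isdigit():
        raise ValueError("Unable to parse model size " + model_size)
    return int(model_size) * multiplier
```

As in the C++ code, trailing garbage such as `2.5M` is rejected rather than silently truncated.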

+ 21 - 1
src/args.h

@@ -11,18 +11,21 @@
 #include <istream>
 #include <ostream>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 namespace fasttext {
 
 enum class model_name : int { cbow = 1, sg, sup };
 enum class loss_name : int { hs = 1, ns, softmax, ova };
+enum class metric_name : int { f1score = 1, labelf1score };
 
 class Args {
  protected:
-  std::string lossToString(loss_name) const;
   std::string boolToString(bool) const;
   std::string modelToString(model_name) const;
+  std::string metricToString(metric_name) const;
+  std::unordered_set<std::string> manualArgs_;
 
  public:
   Args();
@@ -48,6 +51,7 @@ class Args {
   int verbose;
   std::string pretrainedVectors;
   bool saveOutput;
+  int seed;
 
   bool qout;
   bool retrain;
@@ -55,14 +59,30 @@ class Args {
   size_t cutoff;
   size_t dsub;
 
+  std::string autotuneValidationFile;
+  std::string autotuneMetric;
+  int autotunePredictions;
+  int autotuneDuration;
+  std::string autotuneModelSize;
+
   void parseArgs(const std::vector<std::string>& args);
   void printHelp();
   void printBasicHelp();
   void printDictionaryHelp();
   void printTrainingHelp();
+  void printAutotuneHelp();
   void printQuantizationHelp();
   void save(std::ostream&);
   void load(std::istream&);
   void dump(std::ostream&) const;
+  bool hasAutotune() const;
+  bool isManual(const std::string& argName) const;
+  void setManual(const std::string& argName);
+  std::string lossToString(loss_name) const;
+  metric_name getAutotuneMetric() const;
+  std::string getAutotuneMetricLabel() const;
+  int64_t getAutotuneModelSize() const;
+
+  static constexpr double kUnlimitedModelSize = -1.0;
 };
 } // namespace fasttext
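The `updateArgGauss` helper defined in `src/autotune.cc` below perturbs each non-manual hyperparameter with Gaussian noise whose standard deviation anneals as the time budget is consumed, applied either additively or as a power-of-two factor, then clipped to a valid range. A Python sketch of the same formula (the integer truncation performed by the C++ template cast is omitted):

```python
import random

def update_arg_gauss(val, lo, hi, start_sigma, end_sigma, t, linear, rng):
    """Sketch of updateArgGauss: anneal the stddev from start_sigma to
    end_sigma as the elapsed fraction t of the budget moves from 0.25
    to 0.75, draw a Gaussian coefficient, apply it additively (linear)
    or as a power-of-two factor, and clip to [lo, hi]."""
    stddev = start_sigma - ((start_sigma - end_sigma) / 0.5) * min(0.5, max(t - 0.25, 0.0))
    coeff = rng.gauss(0.0, stddev)
    new_val = val + coeff if linear else val * 2.0 ** coeff
    return min(max(new_val, lo), hi)
```

For example, `AutotuneStrategy::ask` resamples the learning rate with the equivalent of `update_arg_gauss(lr, 0.01, 5.0, 1.9, 1.0, t, False, rng)`: large multiplicative jumps early in the budget, progressively smaller ones near the end.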

+ 458 - 0
src/autotune.cc

@@ -0,0 +1,458 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "autotune.h"
+
+#include <algorithm>
+#include <cassert>
+#include <csignal>
+#include <functional>
+#include <iomanip>
+#include <iostream>
+#include <numeric>
+#include <random>
+#include <thread>
+
+#define LOG_VAL(name, val)                        \
+  if (autotuneArgs.verbose > 2) {                 \
+    std::cout << #name " = " << val << std::endl; \
+  }
+#define LOG_VAL_NAN(name, val)                      \
+  if (autotuneArgs.verbose > 2) {                   \
+    if (std::isnan(val)) {                          \
+      std::cout << #name " = NaN" << std::endl;     \
+    } else {                                        \
+      std::cout << #name " = " << val << std::endl; \
+    }                                               \
+  }
+
+namespace {
+
+std::function<void()> interruptSignalHandler;
+
+void signalHandler(int signal) {
+  if (signal == SIGINT) {
+    interruptSignalHandler();
+  }
+}
+
+class ElapsedTimeMarker {
+  std::chrono::steady_clock::time_point start_;
+
+ public:
+  ElapsedTimeMarker() {
+    start_ = std::chrono::steady_clock::now();
+  }
+  double getElapsed() {
+    return fasttext::utils::getDuration(
+        start_, std::chrono::steady_clock::now());
+  }
+};
+
+} // namespace
+
+namespace fasttext {
+
+template <typename T>
+T getArgGauss(
+    T val,
+    std::minstd_rand& rng,
+    double startSigma,
+    double endSigma,
+    double t,
+    bool linear) {
+  T returnValue;
+  const double stddev = startSigma -
+      ((startSigma - endSigma) / 0.5) *
+          std::min(0.5, std::max((t - 0.25), 0.0));
+
+  std::normal_distribution<double> normal(0.0, stddev);
+
+  const double coeff = normal(rng);
+  double updateCoeff = 0.0;
+
+  if (linear) {
+    updateCoeff = coeff;
+    returnValue = static_cast<T>(updateCoeff + val);
+  } else {
+    updateCoeff = std::pow(2.0, coeff);
+    returnValue = static_cast<T>(updateCoeff * val);
+  }
+
+  return returnValue;
+}
+
+template <typename T>
+T updateArgGauss(
+    T val,
+    T min,
+    T max,
+    double startSigma,
+    double endSigma,
+    double t,
+    bool linear,
+    std::minstd_rand& rng) {
+  T retVal = getArgGauss(val, rng, startSigma, endSigma, t, linear);
+  if (retVal > max) {
+    retVal = max;
+  }
+  if (retVal < min) {
+    retVal = min;
+  }
+  return retVal;
+}
+
+AutotuneStrategy::AutotuneStrategy(
+    const Args& originalArgs,
+    std::minstd_rand::result_type seed)
+    : bestArgs_(),
+      maxDuration_(originalArgs.autotuneDuration),
+      rng_(seed),
+      trials_(0),
+      bestMinnIndex_(0),
+      bestDsubExponent_(1),
+      bestNonzeroBucket_(2000000) {
+  minnChoices_ = {0, 2, 3};
+  updateBest(originalArgs);
+}
+
+Args AutotuneStrategy::ask(double elapsed) {
+  const double t = std::min(1.0, elapsed / maxDuration_);
+  trials_++;
+
+  if (trials_ == 1) {
+    return bestArgs_;
+  }
+
+  Args args = bestArgs_;
+
+  if (!args.isManual("epoch")) {
+    args.epoch = updateArgGauss(args.epoch, 1, 100, 2.8, 2.5, t, false, rng_);
+  }
+  if (!args.isManual("lr")) {
+    args.lr = updateArgGauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, rng_);
+  };
+  if (!args.isManual("dim")) {
+    args.dim = updateArgGauss(args.dim, 1, 1000, 1.4, 0.3, t, false, rng_);
+  }
+  if (!args.isManual("wordNgrams")) {
+    args.wordNgrams =
+        updateArgGauss(args.wordNgrams, 1, 5, 4.3, 2.4, t, true, rng_);
+  }
+  if (!args.isManual("dsub")) {
+    int dsubExponent =
+        updateArgGauss(bestDsubExponent_, 1, 4, 2.0, 1.0, t, true, rng_);
+    args.dsub = (1 << dsubExponent);
+  }
+  if (!args.isManual("minn")) {
+    int minnIndex = updateArgGauss(
+        bestMinnIndex_,
+        0,
+        static_cast<int>(minnChoices_.size() - 1),
+        4.0,
+        1.4,
+        t,
+        true,
+        rng_);
+    args.minn = minnChoices_[minnIndex];
+  }
+  if (!args.isManual("maxn")) {
+    if (args.minn == 0) {
+      args.maxn = 0;
+    } else {
+      args.maxn = args.minn + 3;
+    }
+  }
+  if (!args.isManual("bucket")) {
+    if (args.wordNgrams <= 1 && args.maxn == 0) {
+      args.bucket = 0;
+    } else {
+      int nonZeroBucket = updateArgGauss(
+          bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
+      args.bucket = nonZeroBucket;
+    }
+  }
+  if (!args.isManual("loss")) {
+    args.loss = loss_name::softmax;
+  }
+
+  return args;
+}
+
+int AutotuneStrategy::getIndex(int val, const std::vector<int>& choices) {
+  auto found = std::find(choices.begin(), choices.end(), val);
+  int ind = 0;
+  if (found != choices.end()) {
+    ind = std::distance(choices.begin(), found);
+  }
+  return ind;
+}
+
+void AutotuneStrategy::updateBest(const Args& args) {
+  bestArgs_ = args;
+  bestMinnIndex_ = getIndex(args.minn, minnChoices_);
+  bestDsubExponent_ = log2(args.dsub);
+  if (args.bucket != 0) {
+    bestNonzeroBucket_ = args.bucket;
+  }
+}
+
+Autotune::Autotune(const std::shared_ptr<FastText>& fastText)
+    : fastText_(fastText),
+      elapsed_(0.),
+      bestScore_(0.),
+      trials_(0),
+      sizeConstraintFailed_(0),
+      continueTraining_(false),
+      strategy_(),
+      timer_() {}
+
+void Autotune::printInfo(double maxDuration) {
+  double progress = elapsed_ * 100 / maxDuration;
+  progress = std::min(progress, 100.0);
+
+  std::cerr << "\r";
+  std::cerr << std::fixed;
+  std::cerr << "Progress: ";
+  std::cerr << std::setprecision(1) << std::setw(5) << progress << "%";
+  std::cerr << " Trials: " << std::setw(4) << trials_;
+  std::cerr << " Best score: " << std::setw(9) << std::setprecision(6);
+  if (bestScore_ == Autotune::kUnknownBestScore) {
+    std::cerr << "unknown";
+  } else {
+    std::cerr << bestScore_;
+  }
+  std::cerr << " ETA: "
+            << utils::ClockPrint(std::max(maxDuration - elapsed_, 0.0));
+  std::cerr << std::flush;
+}
+
+void Autotune::timer(
+    const std::chrono::steady_clock::time_point& start,
+    double maxDuration) {
+  elapsed_ = 0.0;
+  while (keepTraining(maxDuration)) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(500));
+    elapsed_ = utils::getDuration(start, std::chrono::steady_clock::now());
+    printInfo(maxDuration);
+  }
+  abort();
+}
+
+bool Autotune::keepTraining(double maxDuration) const {
+  return continueTraining_ && elapsed_ < maxDuration;
+}
+
+void Autotune::abort() {
+  if (continueTraining_) {
+    continueTraining_ = false;
+    fastText_->abort();
+  }
+}
+
+void Autotune::startTimer(const Args& args) {
+  std::chrono::steady_clock::time_point start =
+      std::chrono::steady_clock::now();
+  timer_ = std::thread([=]() { timer(start, args.autotuneDuration); });
+  bestScore_ = Autotune::kUnknownBestScore;
+  trials_ = 0;
+  continueTraining_ = true;
+
+  auto previousSignalHandler = std::signal(SIGINT, signalHandler);
+  interruptSignalHandler = [&]() {
+    std::signal(SIGINT, previousSignalHandler);
+    std::cerr << std::endl << "Aborting autotune..." << std::endl;
+    abort();
+  };
+}
+
+double Autotune::getMetricScore(
+    Meter& meter,
+    const metric_name& metricName,
+    const std::string& metricLabel) const {
+  double score = 0.0;
+  if (metricName == metric_name::f1score) {
+    score = meter.f1Score();
+  } else if (metricName == metric_name::labelf1score) {
+    int32_t labelId = fastText_->getDictionary()->getId(metricLabel);
+    if (labelId == -1) {
+      throw std::runtime_error("Unknown autotune metric label");
+    }
+    labelId = labelId - fastText_->getDictionary()->nwords();
+    score = meter.f1Score(labelId);
+  } else {
+    throw std::runtime_error("Unknown metric");
+  }
+  return score;
+}
+
+void Autotune::printArgs(const Args& args, const Args& autotuneArgs) {
+  LOG_VAL(epoch, args.epoch)
+  LOG_VAL(lr, args.lr)
+  LOG_VAL(dim, args.dim)
+  LOG_VAL(minCount, args.minCount)
+  LOG_VAL(wordNgrams, args.wordNgrams)
+  LOG_VAL(minn, args.minn)
+  LOG_VAL(maxn, args.maxn)
+  LOG_VAL(bucket, args.bucket)
+  LOG_VAL(dsub, args.dsub)
+  LOG_VAL(loss, args.lossToString(args.loss))
+}
+
+int Autotune::getCutoffForFileSize(
+    bool qout,
+    bool qnorm,
+    int dsub,
+    int64_t fileSize) const {
+  int64_t outModelSize = 0;
+  const int64_t outM = fastText_->getOutputMatrix()->size(0);
+  const int64_t outN = fastText_->getOutputMatrix()->size(1);
+  if (qout) {
+    const int64_t outputPqSize = 16 + 4 * (outN * (1 << 8));
+    outModelSize =
+        21 + (outM * ((outN + 2 - 1) / 2)) + outputPqSize + (qnorm ? outM : 0);
+  } else {
+    outModelSize = 16 + 4 * (outM * outN);
+  }
+  const int64_t dim = fastText_->getInputMatrix()->size(1);
+
+  int target = (fileSize - (107) - 4 * (1 << 8) * dim - outModelSize);
+  int cutoff = target / ((dim + dsub - 1) / dsub + (qnorm ? 1 : 0) + 10);
+
+  return std::max(cutoff, kCutoffLimit);
+}
+
+bool Autotune::quantize(Args& args, const Args& autotuneArgs) {
+  if (autotuneArgs.getAutotuneModelSize() == Args::kUnlimitedModelSize) {
+    return true;
+  }
+  auto outputSize = fastText_->getOutputMatrix()->size(0);
+
+  args.qnorm = true;
+  args.qout = (outputSize >= kCutoffLimit);
+  args.retrain = true;
+  args.cutoff = getCutoffForFileSize(
+      args.qout, args.qnorm, args.dsub, autotuneArgs.getAutotuneModelSize());
+  LOG_VAL(cutoff, args.cutoff);
+  if (args.cutoff == kCutoffLimit) {
+    return false;
+  }
+  fastText_->quantize(args);
+
+  return true;
+}
+
+void Autotune::printSkippedArgs(const Args& autotuneArgs) {
+  std::unordered_set<std::string> argsToCheck = {"epoch",
+                                                 "lr",
+                                                 "dim",
+                                                 "wordNgrams",
+                                                 "loss",
+                                                 "bucket",
+                                                 "minn",
+                                                 "maxn",
+                                                 "dsub"};
+  for (const auto& arg : argsToCheck) {
+    if (autotuneArgs.isManual(arg)) {
+      std::cerr << "Warning : " << arg
+                << " is manually set to a specific value. "
+                << "It will not be automatically optimized." << std::endl;
+    }
+  }
+}
+
+void Autotune::train(const Args& autotuneArgs) {
+  std::ifstream validationFileStream(autotuneArgs.autotuneValidationFile);
+  if (!validationFileStream.is_open()) {
+    throw std::invalid_argument("Validation file cannot be opened!");
+  }
+  printSkippedArgs(autotuneArgs);
+
+  bool sizeConstraintWarning = false;
+  int verbose = autotuneArgs.verbose;
+  Args bestTrainArgs(autotuneArgs);
+  Args trainArgs(autotuneArgs);
+  trainArgs.verbose = 0;
+  strategy_ = std::unique_ptr<AutotuneStrategy>(
+      new AutotuneStrategy(trainArgs, autotuneArgs.seed));
+  startTimer(autotuneArgs);
+
+  while (keepTraining(autotuneArgs.autotuneDuration)) {
+    trials_++;
+
+    trainArgs = strategy_->ask(elapsed_);
+    ElapsedTimeMarker elapsedTimeMarker;
+    double currentScore = std::numeric_limits<double>::quiet_NaN();
+    try {
+      fastText_->train(trainArgs);
+      bool sizeConstraintOK = quantize(trainArgs, autotuneArgs);
+      if (sizeConstraintOK) {
+        Meter meter;
+        fastText_->test(
+            validationFileStream, autotuneArgs.autotunePredictions, 0.0, meter);
+
+        currentScore = getMetricScore(
+            meter,
+            autotuneArgs.getAutotuneMetric(),
+            autotuneArgs.getAutotuneMetricLabel());
+
+        if (bestScore_ == Autotune::kUnknownBestScore ||
+            (currentScore > bestScore_)) {
+          bestTrainArgs = trainArgs;
+          bestScore_ = currentScore;
+          strategy_->updateBest(bestTrainArgs);
+        }
+      } else {
+        sizeConstraintFailed_++;
+        if (!sizeConstraintWarning && trials_ > 10 &&
+            sizeConstraintFailed_ > (trials_ / 2)) {
+          sizeConstraintWarning = true;
+          std::cerr
+              << std::endl
+              << "Warning: requested model size is probably too small. You may want to increase `autotune-modelsize`."
+              << std::endl;
+        }
+      }
+    } catch (DenseMatrix::EncounteredNaNError&) {
+      // ignore diverging loss and go on
+    } catch (TimeoutError&) {
+      break;
+    } catch (FastText::AbortError&) {
+      break;
+    }
+    LOG_VAL(Trial, trials_)
+    printArgs(trainArgs, autotuneArgs);
+    LOG_VAL_NAN(currentScore, currentScore)
+    LOG_VAL(train took, elapsedTimeMarker.getElapsed())
+  }
+  if (timer_.joinable()) {
+    timer_.join();
+  }
+
+  if (bestScore_ == Autotune::kUnknownBestScore) {
+    std::string errorMessage;
+    if (sizeConstraintWarning) {
+      errorMessage =
+          "Couldn't fulfil model size constraint: please increase `autotune-modelsize`.";
+    } else {
+      errorMessage =
+          "Didn't have enough time to train once: please increase `autotune-duration`.";
+    }
+    throw std::runtime_error(errorMessage);
+  } else {
+    std::cerr << std::endl;
+    std::cerr << "Training again with best arguments" << std::endl;
+    bestTrainArgs.verbose = verbose;
+    LOG_VAL(Best selected args, 0)
+    printArgs(bestTrainArgs, autotuneArgs);
+    fastText_->train(bestTrainArgs);
+    quantize(bestTrainArgs, autotuneArgs);
+  }
+}
+
+} // namespace fasttext
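Stripped of quantization, logging, and error handling, the loop above is a time-budgeted random search: ask the strategy for a candidate, score it on validation data, keep the best. A generic sketch under those assumptions (`randomSearch`, `Candidate`, and the scorer are hypothetical stand-ins for `AutotuneStrategy::ask`, `Args`, and the validation F1; `Candidate` must be default-constructible and copyable):

```cpp
#include <cassert>
#include <chrono>
#include <cmath>
#include <functional>

// Time-budgeted random search: keep drawing candidates from `ask` and
// remember the best-scoring one until `budgetSeconds` elapses.
template <typename Candidate>
Candidate randomSearch(std::function<Candidate()> ask,
                       std::function<double(const Candidate&)> score,
                       double budgetSeconds) {
  const auto start = std::chrono::steady_clock::now();
  double bestScore = -1.0;  // mirrors kUnknownBestScore in the diff
  Candidate best{};
  while (std::chrono::duration<double>(std::chrono::steady_clock::now() -
                                       start).count() < budgetSeconds) {
    Candidate trial = ask();
    double s = score(trial);
    if (bestScore < 0 || s > bestScore) {  // first trial always wins
      bestScore = s;
      best = trial;
    }
  }
  return best;
}
```

The real loop additionally retrains with the best arguments at the end, which is why `bestTrainArgs` is kept rather than the trained model itself.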

+ 90 - 0
src/autotune.h

@@ -0,0 +1,90 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <istream>
+#include <memory>
+#include <random>
+#include <thread>
+#include <vector>
+
+#include "args.h"
+#include "fasttext.h"
+
+namespace fasttext {
+
+class AutotuneStrategy {
+ private:
+  Args bestArgs_;
+  int maxDuration_;
+  std::minstd_rand rng_;
+  int trials_;
+  int bestMinnIndex_;
+  int bestDsubExponent_;
+  int bestNonzeroBucket_;
+  std::vector<int> minnChoices_;
+  int getIndex(int val, const std::vector<int>& choices);
+
+ public:
+  explicit AutotuneStrategy(
+      const Args& args,
+      std::minstd_rand::result_type seed);
+  Args ask(double elapsed);
+  void updateBest(const Args& args);
+};
+
+class Autotune {
+ protected:
+  std::shared_ptr<FastText> fastText_;
+  double elapsed_;
+  double bestScore_;
+  int32_t trials_;
+  int32_t sizeConstraintFailed_;
+  std::atomic<bool> continueTraining_;
+  std::unique_ptr<AutotuneStrategy> strategy_;
+  std::thread timer_;
+
+  bool keepTraining(double maxDuration) const;
+  void printInfo(double maxDuration);
+  void timer(
+      const std::chrono::steady_clock::time_point& start,
+      double maxDuration);
+  void abort();
+  void startTimer(const Args& args);
+  double getMetricScore(
+      Meter& meter,
+      const metric_name& metricName,
+      const std::string& metricLabel) const;
+  void printArgs(const Args& args, const Args& autotuneArgs);
+  void printSkippedArgs(const Args& autotuneArgs);
+  bool quantize(Args& args, const Args& autotuneArgs);
+  int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
+      const;
+
+  class TimeoutError : public std::runtime_error {
+   public:
+    TimeoutError() : std::runtime_error("Autotune timed out.") {}
+  };
+
+  static constexpr double kUnknownBestScore = -1.0;
+  static constexpr int kCutoffLimit = 256;
+
+ public:
+  Autotune() = delete;
+  explicit Autotune(const std::shared_ptr<FastText>& fastText);
+  Autotune(const Autotune&) = delete;
+  Autotune(Autotune&&) = delete;
+  Autotune& operator=(const Autotune&) = delete;
+  Autotune& operator=(Autotune&&) = delete;
+  ~Autotune() noexcept = default;
+
+  void train(const Args& args);
+};
+
+} // namespace fasttext

+ 17 - 4
src/densematrix.cc

@@ -11,8 +11,8 @@
 #include <exception>
 #include <random>
 #include <stdexcept>
+#include <thread>
 #include <utility>
-
 #include "utils.h"
 #include "vector.h"
 
@@ -29,14 +29,27 @@ void DenseMatrix::zero() {
   std::fill(data_.begin(), data_.end(), 0.0);
 }
 
-void DenseMatrix::uniform(real a) {
-  std::minstd_rand rng(1);
+void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
+  std::minstd_rand rng(block + seed);
   std::uniform_real_distribution<> uniform(-a, a);
-  for (int64_t i = 0; i < (m_ * n_); i++) {
+  int64_t blockSize = (m_ * n_) / 10;
+  for (int64_t i = blockSize * block;
+       i < (m_ * n_) && i < blockSize * (block + 1);
+       i++) {
     data_[i] = uniform(rng);
   }
 }
 
+void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
+  std::vector<std::thread> threads;
+  for (unsigned int i = 0; i < thread; i++) {
+    threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
+  }
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
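The initialization above is made reproducible by giving every block its own generator seeded with `block + seed`, so the result is independent of thread scheduling. A simplified, self-contained sketch of that pattern (here the block count is tied to the thread count and the final block absorbs the remainder, unlike the fixed division by 10 in the diff; assumes `threads >= 1`):

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <thread>
#include <vector>

// Fill a buffer with uniform values in [-a, a] using `threads` workers.
// Each worker owns an RNG seeded with (block + seed), so the output is
// deterministic regardless of how the threads are scheduled.
std::vector<float> parallelUniform(size_t n, float a, unsigned int threads,
                                   int32_t seed) {
  std::vector<float> data(n, 0.0f);
  size_t blockSize = n / threads;
  auto worker = [&](unsigned int block) {
    std::minstd_rand rng(block + seed);
    std::uniform_real_distribution<float> uniform(-a, a);
    // The last block extends to n so no trailing elements are skipped.
    size_t end = (block + 1 == threads) ? n : blockSize * (block + 1);
    for (size_t i = blockSize * block; i < end; i++) {
      data[i] = uniform(rng);
    }
  };
  std::vector<std::thread> pool;
  for (unsigned int b = 0; b < threads; b++) {
    pool.emplace_back(worker, b);
  }
  for (auto& t : pool) {
    t.join();
  }
  return data;
}
```

The workers write to disjoint ranges, so no synchronization beyond the final `join` is needed.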
 void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
   if (ie == -1) {
     ie = m_;

+ 2 - 1
src/densematrix.h

@@ -24,6 +24,7 @@ class Vector;
 class DenseMatrix : public Matrix {
  protected:
   std::vector<real> data_;
+  void uniformThread(real, int, int32_t);
 
  public:
   DenseMatrix();
@@ -56,7 +57,7 @@ class DenseMatrix : public Matrix {
     return n_;
   }
   void zero();
-  void uniform(real);
+  void uniform(real, unsigned int, int32_t);
 
   void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
   void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);

+ 31 - 14
src/fasttext.cc

@@ -106,6 +106,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
 }
 
 void FastText::saveVectors(const std::string& filename) {
+  if (!input_ || !output_) {
+    throw std::runtime_error("Model never trained");
+  }
   std::ofstream ofs(filename);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
@@ -170,6 +173,9 @@ void FastText::saveModel(const std::string& filename) {
   if (!ofs.is_open()) {
     throw std::invalid_argument(filename + " cannot be opened for saving!");
   }
+  if (!input_ || !output_) {
+    throw std::runtime_error("Model never trained");
+  }
   signModel(ofs);
   args_->save(ofs);
   dict_->save(ofs);
@@ -241,10 +247,7 @@ void FastText::loadModel(std::istream& in) {
 }
 
 void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
-  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
-  double t =
-      std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
-          .count();
+  double t = utils::getDuration(start_, std::chrono::steady_clock::now());
   double lr = args_->lr * (1.0 - progress);
   double wst = 0;
 
@@ -255,17 +258,14 @@ void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
     eta = t * (100 - progress) / progress;
     wst = double(tokenCount_) / t / args_->thread;
   }
-  int32_t etah = eta / 3600;
-  int32_t etam = (eta % 3600) / 60;
 
   log_stream << std::fixed;
   log_stream << "Progress: ";
   log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
   log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
   log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
-  log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
-  log_stream << " ETA: " << std::setw(3) << etah;
-  log_stream << "h" << std::setw(2) << etam << "m";
+  log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
+  log_stream << " ETA: " << utils::ClockPrint(eta);
   log_stream << std::flush;
 }
 
@@ -278,6 +278,9 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
   std::iota(idx.begin(), idx.end(), 0);
   auto eosid = dict_->getId(Dictionary::EOS);
   std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
+    if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
+      return false;
+    }
     return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
   });
   idx.erase(idx.begin() + cutoff, idx.end());
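The new guard fixes a comparator that violated strict weak ordering: when `i1` and `i2` were both `eosid`, the old lambda returned `true` for `comp(x, x)`, which is undefined behavior with `std::sort`. A minimal illustration of the corrected comparator on toy data (hypothetical values; the guard is generalized to any `i1 == i2`, which is equivalent here since only the `eosid` self-comparison misbehaved):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Sort indices by descending norm, but always keep `eosid` first.
// A valid comparator must return false for comp(x, x).
std::vector<size_t> sortWithPinnedFirst(const std::vector<float>& norms,
                                        size_t eosid) {
  std::vector<size_t> idx(norms.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
    if (i1 == i2) {  // satisfy strict weak ordering
      return false;
    }
    return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
  });
  return idx;
}
```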
@@ -399,6 +402,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
   std::vector<int32_t> line;
   std::vector<int32_t> labels;
   Predictions predictions;
+  Model::State state(args_->dim, dict_->nlabels(), 0);
+  in.clear();
+  in.seekg(0, std::ios_base::beg);
 
   while (in.peek() != EOF) {
     line.clear();
@@ -596,7 +602,7 @@ void FastText::trainThread(int32_t threadId) {
   std::ifstream ifs(args_->input);
   utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
 
-  Model::State state(args_->dim, output_->size(0), threadId);
+  Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
 
   const int64_t ntokens = dict_->ntokens();
   int64_t localTokenCount = 0;
@@ -623,7 +629,7 @@ void FastText::trainThread(int32_t threadId) {
         }
       }
     }
-  } catch (DenseMatrix::EncounteredNaNError &) {
+  } catch (DenseMatrix::EncounteredNaNError&) {
     trainException_ = std::current_exception();
   }
   if (threadId == 0)
@@ -662,7 +668,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
   dict_->init();
   std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
       dict_->nwords() + args_->bucket, args_->dim);
-  input->uniform(1.0 / args_->dim);
+  input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
 
   for (size_t i = 0; i < n; i++) {
     int32_t idx = dict_->getId(words[i]);
@@ -679,7 +685,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
 std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
   std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
       dict_->nwords() + args_->bucket, args_->dim);
-  input->uniform(1.0 / args_->dim);
+  input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
 
   return input;
 }
@@ -715,12 +721,21 @@ void FastText::train(const Args& args) {
     input_ = createRandomMatrix();
   }
   output_ = createTrainOutputMatrix();
+  quant_ = false;
   auto loss = createLoss(output_);
   bool normalizeGradient = (args_->model == model_name::sup);
   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
   startThreads();
 }
 
+void FastText::abort() {
+  try {
+    throw AbortError();
+  } catch (AbortError&) {
+    trainException_ = std::current_exception();
+  }
+}
+
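`abort()` uses a slightly unusual idiom: it throws and immediately catches an `AbortError` purely to obtain a `std::exception_ptr` via `std::current_exception()`, which `startThreads` later rethrows on the controlling thread. A minimal sketch of that capture-and-rethrow pattern (hypothetical helper using a plain `std::runtime_error`):

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <string>

// Capture an in-flight exception as an exception_ptr, then rethrow it
// later, as the training threads do with trainException_.
std::string captureAndRethrow() {
  std::exception_ptr stored;
  try {
    throw std::runtime_error("Aborted.");
  } catch (std::runtime_error&) {
    stored = std::current_exception();  // remember it for later
  }
  try {
    std::rethrow_exception(stored);  // surfaces the original exception
  } catch (std::runtime_error& e) {
    return e.what();
  }
}
```

Resetting `trainException_` to `nullptr` before rethrowing, as the diff does in `startThreads`, matters because autotune catches the exception and trains again with the same `FastText` object.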
 void FastText::startThreads() {
   start_ = std::chrono::steady_clock::now();
   tokenCount_ = 0;
@@ -744,7 +759,9 @@ void FastText::startThreads() {
     threads[i].join();
   }
   if (trainException_) {
-    std::rethrow_exception(trainException_);
+    std::exception_ptr exception = trainException_;
+    trainException_ = nullptr;
+    std::rethrow_exception(exception);
   }
   if (args_->verbose > 0) {
     std::cerr << "\r";

+ 6 - 0
src/fasttext.h

@@ -143,9 +143,15 @@ class FastText {
 
   void train(const Args& args);
 
+  void abort();
+
   int getDimension() const;
 
   bool isQuant() const;
 
+  class AbortError : public std::runtime_error {
+   public:
+    AbortError() : std::runtime_error("Aborted.") {}
+  };
 };
 } // namespace fasttext

+ 19 - 6
src/main.cc

@@ -11,6 +11,7 @@
 #include <queue>
 #include <stdexcept>
 #include "args.h"
+#include "autotune.h"
 #include "fasttext.h"
 
 using namespace fasttext;
@@ -351,19 +352,31 @@ void analogies(const std::vector<std::string> args) {
 void train(const std::vector<std::string> args) {
   Args a = Args();
   a.parseArgs(args);
-  FastText fasttext;
-  std::string outputFileName(a.output + ".bin");
+  std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
+  std::string outputFileName;
+
+  if (a.hasAutotune() &&
+      a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
+    outputFileName = a.output + ".ftz";
+  } else {
+    outputFileName = a.output + ".bin";
+  }
   std::ofstream ofs(outputFileName);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
         outputFileName + " cannot be opened for saving.");
   }
   ofs.close();
-  fasttext.train(a);
-  fasttext.saveModel(outputFileName);
-  fasttext.saveVectors(a.output + ".vec");
+  if (a.hasAutotune()) {
+    Autotune autotune(fasttext);
+    autotune.train(a);
+  } else {
+    fasttext->train(a);
+  }
+  fasttext->saveModel(outputFileName);
+  fasttext->saveVectors(a.output + ".vec");
   if (a.saveOutput) {
-    fasttext.saveOutput(a.output + ".output");
+    fasttext->saveOutput(a.output + ".output");
   }
 }
 

+ 13 - 0
src/meter.cc

@@ -26,10 +26,14 @@ void Meter::log(
   for (const auto& prediction : predictions) {
     labelMetrics_[prediction.second].predicted++;
 
+    real score = std::exp(prediction.first);
+    real gold = 0.0;
     if (utils::contains(labels, prediction.second)) {
       labelMetrics_[prediction.second].predictedGold++;
       metrics_.predictedGold++;
+      gold = 1.0;
     }
+    labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
   }
 
   for (const auto& label : labels) {
@@ -57,6 +61,15 @@ double Meter::recall() const {
   return metrics_.recall();
 }
 
+double Meter::f1Score() const {
+  const double precision = this->precision();
+  const double recall = this->recall();
+  if (precision + recall != 0) {
+    return 2 * precision * recall / (precision + recall);
+  }
+  return std::numeric_limits<double>::quiet_NaN();
+}
+
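The overall F1 added here is the harmonic mean of the micro-averaged precision and recall, returning NaN when both are zero. The same arithmetic as a free function (hypothetical signature, for illustration):

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// F1 = 2PR / (P + R); NaN when precision and recall are both zero,
// matching the Meter::f1Score behavior above.
double f1Score(double precision, double recall) {
  if (precision + recall != 0) {
    return 2 * precision * recall / (precision + recall);
  }
  return std::numeric_limits<double>::quiet_NaN();
}
```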
 void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
   out << "N"
       << "\t" << nexamples_ << std::endl;

+ 2 - 0
src/meter.h

@@ -22,6 +22,7 @@ class Meter {
     uint64_t gold;
     uint64_t predicted;
     uint64_t predictedGold;
+    mutable std::vector<std::pair<real, real>> scoreVsTrue;
 
     Metrics() : gold(0), predicted(0), predictedGold(0) {}
 
@@ -55,6 +56,7 @@ class Meter {
   double f1Score(int32_t);
   double precision() const;
   double recall() const;
+  double f1Score() const;
   uint64_t nexamples() const {
     return nexamples_;
   }

+ 21 - 0
src/utils.cc

@@ -8,6 +8,7 @@
 
 #include "utils.h"
 
+#include <iomanip>
 #include <ios>
 
 namespace fasttext {
@@ -23,6 +24,26 @@ void seek(std::ifstream& ifs, int64_t pos) {
   ifs.clear();
   ifs.seekg(std::streampos(pos));
 }
+
+double getDuration(
+    const std::chrono::steady_clock::time_point& start,
+    const std::chrono::steady_clock::time_point& end) {
+  return std::chrono::duration_cast<std::chrono::duration<double>>(end - start)
+      .count();
+}
+
+ClockPrint::ClockPrint(int32_t duration) : duration_(duration) {}
+
+std::ostream& operator<<(std::ostream& out, const ClockPrint& me) {
+  int32_t etah = me.duration_ / 3600;
+  int32_t etam = (me.duration_ % 3600) / 60;
+  int32_t etas = (me.duration_ % 3600) % 60;
+
+  out << std::setw(3) << etah << "h" << std::setw(2) << etam << "m";
+  out << std::setw(2) << etas << "s";
+  return out;
+}
+
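One subtlety of `ClockPrint`: `std::setw` pads with the stream's fill character, which defaults to a space, so durations render like `  1h 2m 5s` rather than zero-padded `001h02m05s`. A standalone replica of the formatting (hypothetical function name, same field widths):

```cpp
#include <cassert>
#include <cstdint>
#include <iomanip>
#include <sstream>
#include <string>

// Render a duration in seconds as "HHHh MMm SSs" with space padding,
// mirroring utils::ClockPrint's operator<< above.
std::string formatClock(int32_t duration) {
  int32_t h = duration / 3600;
  int32_t m = (duration % 3600) / 60;
  int32_t s = (duration % 3600) % 60;
  std::ostringstream out;
  out << std::setw(3) << h << "h" << std::setw(2) << m << "m";
  out << std::setw(2) << s << "s";
  return out.str();
}
```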
 } // namespace utils
 
 } // namespace fasttext

+ 15 - 0
src/utils.h

@@ -11,7 +11,9 @@
 #include "real.h"
 
 #include <algorithm>
+#include <chrono>
 #include <fstream>
+#include <ostream>
 #include <vector>
 
 #if defined(__clang__) || defined(__GNUC__)
@@ -38,6 +40,19 @@ bool contains(const std::vector<T>& container, const T& value) {
       container.end();
 }
 
+double getDuration(
+    const std::chrono::steady_clock::time_point& start,
+    const std::chrono::steady_clock::time_point& end);
+
+class ClockPrint {
+ public:
+  explicit ClockPrint(int32_t duration);
+  friend std::ostream& operator<<(std::ostream& out, const ClockPrint& me);
+
+ private:
+  int32_t duration_;
+};
+
 } // namespace utils
 
 } // namespace fasttext

+ 1 - 1
website/sidebars.json

@@ -2,7 +2,7 @@
   "docs": {
     "Introduction": ["support", "cheatsheet", "options"],
     "Tutorials": ["supervised-tutorial", "unsupervised-tutorial"],
-    "Help": ["python-module", "faqs", "api", "references"]
+    "Help": ["autotune", "python-module", "faqs", "api", "references"]
   },
   "download": {
     "Resources": [