autotune.h 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
  8. #pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <istream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "args.h"
#include "fasttext.h"
  16. namespace fasttext {
  17. class AutotuneStrategy {
  18. private:
  19. Args bestArgs_;
  20. int maxDuration_;
  21. std::minstd_rand rng_;
  22. int trials_;
  23. int bestMinnIndex_;
  24. int bestDsubExponent_;
  25. int bestNonzeroBucket_;
  26. int originalBucket_;
  27. std::vector<int> minnChoices_;
  28. int getIndex(int val, const std::vector<int>& choices);
  29. public:
  30. explicit AutotuneStrategy(
  31. const Args& args,
  32. std::minstd_rand::result_type seed);
  33. Args ask(double elapsed);
  34. void updateBest(const Args& args);
  35. };
  36. class Autotune {
  37. protected:
  38. std::shared_ptr<FastText> fastText_;
  39. double elapsed_;
  40. double bestScore_;
  41. int32_t trials_;
  42. int32_t sizeConstraintFailed_;
  43. std::atomic<bool> continueTraining_;
  44. std::unique_ptr<AutotuneStrategy> strategy_;
  45. std::thread timer_;
  46. bool keepTraining(double maxDuration) const;
  47. void printInfo(double maxDuration);
  48. void timer(
  49. const std::chrono::steady_clock::time_point& start,
  50. double maxDuration);
  51. void abort();
  52. void startTimer(const Args& args);
  53. double getMetricScore(
  54. Meter& meter,
  55. const metric_name& metricName,
  56. const double metricValue,
  57. const std::string& metricLabel) const;
  58. void printArgs(const Args& args, const Args& autotuneArgs);
  59. void printSkippedArgs(const Args& autotuneArgs);
  60. bool quantize(Args& args, const Args& autotuneArgs);
  61. int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
  62. const;
  63. class TimeoutError : public std::runtime_error {
  64. public:
  65. TimeoutError() : std::runtime_error("Autotune timed out.") {}
  66. };
  67. public:
  68. Autotune() = delete;
  69. explicit Autotune(const std::shared_ptr<FastText>& fastText);
  70. Autotune(const Autotune&) = delete;
  71. Autotune(Autotune&&) = delete;
  72. Autotune& operator=(const Autotune&) = delete;
  73. Autotune& operator=(Autotune&&) = delete;
  74. ~Autotune() noexcept = default;
  75. void train(const Args& args);
  76. };
  77. } // namespace fasttext