| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- /**
- * Copyright (c) 2016-present, Facebook, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
- */
- #pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <istream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "args.h"
#include "fasttext.h"
- namespace fasttext {
- class AutotuneStrategy {
- private:
- Args bestArgs_;
- int maxDuration_;
- std::minstd_rand rng_;
- int trials_;
- int bestMinnIndex_;
- int bestDsubExponent_;
- int bestNonzeroBucket_;
- int originalBucket_;
- std::vector<int> minnChoices_;
- int getIndex(int val, const std::vector<int>& choices);
- public:
- explicit AutotuneStrategy(
- const Args& args,
- std::minstd_rand::result_type seed);
- Args ask(double elapsed);
- void updateBest(const Args& args);
- };
- class Autotune {
- protected:
- std::shared_ptr<FastText> fastText_;
- double elapsed_;
- double bestScore_;
- int32_t trials_;
- int32_t sizeConstraintFailed_;
- std::atomic<bool> continueTraining_;
- std::unique_ptr<AutotuneStrategy> strategy_;
- std::thread timer_;
- bool keepTraining(double maxDuration) const;
- void printInfo(double maxDuration);
- void timer(
- const std::chrono::steady_clock::time_point& start,
- double maxDuration);
- void abort();
- void startTimer(const Args& args);
- double getMetricScore(
- Meter& meter,
- const metric_name& metricName,
- const double metricValue,
- const std::string& metricLabel) const;
- void printArgs(const Args& args, const Args& autotuneArgs);
- void printSkippedArgs(const Args& autotuneArgs);
- bool quantize(Args& args, const Args& autotuneArgs);
- int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
- const;
- class TimeoutError : public std::runtime_error {
- public:
- TimeoutError() : std::runtime_error("Autotune timed out.") {}
- };
- public:
- Autotune() = delete;
- explicit Autotune(const std::shared_ptr<FastText>& fastText);
- Autotune(const Autotune&) = delete;
- Autotune(Autotune&&) = delete;
- Autotune& operator=(const Autotune&) = delete;
- Autotune& operator=(Autotune&&) = delete;
- ~Autotune() noexcept = default;
- void train(const Args& args);
- };
- } // namespace fasttext
|