|
|
@@ -114,7 +114,8 @@ AutotuneStrategy::AutotuneStrategy(
|
|
|
trials_(0),
|
|
|
bestMinnIndex_(0),
|
|
|
bestDsubExponent_(1),
|
|
|
- bestNonzeroBucket_(2000000) {
|
|
|
+ bestNonzeroBucket_(2000000),
|
|
|
+ originalBucket_(originalArgs.bucket) {
|
|
|
minnChoices_ = {0, 2, 3};
|
|
|
updateBest(originalArgs);
|
|
|
}
|
|
|
@@ -167,13 +168,14 @@ Args AutotuneStrategy::ask(double elapsed) {
|
|
|
}
|
|
|
}
|
|
|
if (!args.isManual("bucket")) {
|
|
|
- if (args.wordNgrams <= 1 && args.maxn == 0) {
|
|
|
- args.bucket = 0;
|
|
|
- } else {
|
|
|
- int nonZeroBucket = updateArgGauss(
|
|
|
- bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
|
|
|
- args.bucket = nonZeroBucket;
|
|
|
- }
|
|
|
+ int nonZeroBucket = updateArgGauss(
|
|
|
+ bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
|
|
|
+ args.bucket = nonZeroBucket;
|
|
|
+ } else {
|
|
|
+ args.bucket = originalBucket_;
|
|
|
+ }
|
|
|
+ if (args.wordNgrams <= 1 && args.maxn == 0) {
|
|
|
+ args.bucket = 0;
|
|
|
}
|
|
|
if (!args.isManual("loss")) {
|
|
|
args.loss = loss_name::softmax;
|