|
|
@@ -142,7 +142,7 @@ bool Model::comparePairs(const std::pair<real, int32_t> &l,
|
|
|
return l.first > r.first;
|
|
|
}
|
|
|
|
|
|
-void Model::predict(const std::vector<int32_t>& input, int32_t k,
|
|
|
+void Model::predict(const std::vector<int32_t>& input, int32_t k, real threshold,
|
|
|
std::vector<std::pair<real, int32_t>>& heap,
|
|
|
Vector& hidden, Vector& output) const {
|
|
|
if (k <= 0) {
|
|
|
@@ -154,22 +154,31 @@ void Model::predict(const std::vector<int32_t>& input, int32_t k,
|
|
|
heap.reserve(k + 1);
|
|
|
computeHidden(input, hidden);
|
|
|
if (args_->loss == loss_name::hs) {
|
|
|
- dfs(k, 2 * osz_ - 2, 0.0, heap, hidden);
|
|
|
+ dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, hidden);
|
|
|
} else {
|
|
|
- findKBest(k, heap, hidden, output);
|
|
|
+ findKBest(k, threshold, heap, hidden, output);
|
|
|
}
|
|
|
std::sort_heap(heap.begin(), heap.end(), comparePairs);
|
|
|
}
|
|
|
|
|
|
-void Model::predict(const std::vector<int32_t>& input, int32_t k,
|
|
|
- std::vector<std::pair<real, int32_t>>& heap) {
|
|
|
- predict(input, k, heap, hidden_, output_);
|
|
|
+void Model::predict(
|
|
|
+ const std::vector<int32_t>& input,
|
|
|
+ int32_t k,
|
|
|
+ real threshold,
|
|
|
+ std::vector<std::pair<real, int32_t>>& heap
|
|
|
+) {
|
|
|
+ predict(input, k, threshold, heap, hidden_, output_);
|
|
|
}
|
|
|
|
|
|
-void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
|
|
|
- Vector& hidden, Vector& output) const {
|
|
|
+void Model::findKBest(
|
|
|
+ int32_t k,
|
|
|
+ real threshold,
|
|
|
+ std::vector<std::pair<real, int32_t>>& heap,
|
|
|
+ Vector& hidden, Vector& output
|
|
|
+) const {
|
|
|
computeOutputSoftmax(hidden, output);
|
|
|
for (int32_t i = 0; i < osz_; i++) {
|
|
|
+ if (output[i] < threshold) continue;
|
|
|
if (heap.size() == k && std_log(output[i]) < heap.front().first) {
|
|
|
continue;
|
|
|
}
|
|
|
@@ -182,9 +191,10 @@ void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void Model::dfs(int32_t k, int32_t node, real score,
|
|
|
+void Model::dfs(int32_t k, real threshold, int32_t node, real score,
|
|
|
std::vector<std::pair<real, int32_t>>& heap,
|
|
|
Vector& hidden) const {
|
|
|
+ if (score < std_log(threshold)) return;
|
|
|
if (heap.size() == k && score < heap.front().first) {
|
|
|
return;
|
|
|
}
|
|
|
@@ -207,8 +217,8 @@ void Model::dfs(int32_t k, int32_t node, real score,
|
|
|
}
|
|
|
f = 1. / (1 + std::exp(-f));
|
|
|
|
|
|
- dfs(k, tree[node].left, score + std_log(1.0 - f), heap, hidden);
|
|
|
- dfs(k, tree[node].right, score + std_log(f), heap, hidden);
|
|
|
+ dfs(k, threshold, tree[node].left, score + std_log(1.0 - f), heap, hidden);
|
|
|
+ dfs(k, threshold, tree[node].right, score + std_log(f), heap, hidden);
|
|
|
}
|
|
|
|
|
|
void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
|