|
|
@@ -257,7 +257,9 @@ void FastText::loadModel(std::istream& in) {
|
|
|
output_->load(in);
|
|
|
|
|
|
auto loss = createLoss(output_);
|
|
|
- model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
|
|
|
+ bool normalizeGradient = (args_->model == model_name::sup);
|
|
|
+ model_ = std::make_shared<Model>(
|
|
|
+ input_, output_, loss, args_->dim, normalizeGradient, 0);
|
|
|
}
|
|
|
|
|
|
void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
|
|
|
@@ -316,6 +318,8 @@ void FastText::quantize(const Args& qargs) {
|
|
|
std::dynamic_pointer_cast<DenseMatrix>(input_);
|
|
|
std::shared_ptr<DenseMatrix> output =
|
|
|
std::dynamic_pointer_cast<DenseMatrix>(output_);
|
|
|
+ bool normalizeGradient = (args_->model == model_name::sup);
|
|
|
+
|
|
|
if (qargs.cutoff > 0 && qargs.cutoff < input->size(0)) {
|
|
|
auto idx = selectEmbeddings(qargs.cutoff);
|
|
|
dict_->prune(idx);
|
|
|
@@ -333,7 +337,8 @@ void FastText::quantize(const Args& qargs) {
|
|
|
args_->thread = qargs.thread;
|
|
|
args_->verbose = qargs.verbose;
|
|
|
auto loss = createLoss(output_);
|
|
|
- model_ = std::make_shared<Model>(input, output, args_, loss, 0);
|
|
|
+ model_ = std::make_shared<Model>(
|
|
|
+ input, output, loss, args_->dim, normalizeGradient, 0);
|
|
|
startThreads();
|
|
|
}
|
|
|
}
|
|
|
@@ -348,7 +353,8 @@ void FastText::quantize(const Args& qargs) {
|
|
|
|
|
|
quant_ = true;
|
|
|
auto loss = createLoss(output_);
|
|
|
- model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
|
|
|
+ model_ = std::make_shared<Model>(
|
|
|
+ input_, output_, loss, args_->dim, normalizeGradient, 0);
|
|
|
}
|
|
|
|
|
|
void FastText::supervised(
|
|
|
@@ -438,6 +444,9 @@ void FastText::predict(
|
|
|
}
|
|
|
Vector hidden(args_->dim);
|
|
|
Vector output(dict_->nlabels());
|
|
|
+ if (args_->model != model_name::sup) {
|
|
|
+ throw std::invalid_argument("Model needs to be supervised for prediction!");
|
|
|
+ }
|
|
|
model_->predict(words, k, threshold, predictions, hidden, output);
|
|
|
}
|
|
|
|
|
|
@@ -764,7 +773,9 @@ void FastText::train(const Args& args) {
|
|
|
}
|
|
|
output_ = createTrainOutputMatrix();
|
|
|
auto loss = createLoss(output_);
|
|
|
- model_ = std::make_shared<Model>(input_, output_, args_, loss, 0);
|
|
|
+ bool normalizeGradient = (args_->model == model_name::sup);
|
|
|
+ model_ = std::make_shared<Model>(
|
|
|
+ input_, output_, loss, args_->dim, normalizeGradient, 0);
|
|
|
startThreads();
|
|
|
}
|
|
|
|