@@ -12,9 +12,9 @@
#include <assert.h>

#include <iostream>
+#include <fstream>
#include <algorithm>
#include <iterator>
-#include <unordered_map>

namespace fasttext {

@@ -22,17 +22,9 @@ const std::string Dictionary::EOS = "</s>";
const std::string Dictionary::BOW = "<";
const std::string Dictionary::EOW = ">";

-Dictionary::Dictionary(std::shared_ptr<Args> args) {
-  args_ = args;
-  size_ = 0;
-  nwords_ = 0;
-  nlabels_ = 0;
-  ntokens_ = 0;
-  word2int_.resize(MAX_VOCAB_SIZE);
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
-}
+Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
+  word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
+  ntokens_(0), quant_(false) {}

int32_t Dictionary::find(const std::string& w) const {
  int32_t h = hash(w) % MAX_VOCAB_SIZE;
@@ -246,9 +238,7 @@ void Dictionary::threshold(int64_t t, int64_t tl) {
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
+  std::fill(word2int_.begin(), word2int_.end(), -1);
  for (auto it = words_.begin(); it != words_.end(); ++it) {
    int32_t h = find(it->word);
    word2int_[h] = size_++;
@@ -273,43 +263,88 @@ std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
  return counts;
}

-void Dictionary::addNgrams(std::vector<int32_t>& line, int32_t n) const {
-  int32_t line_size = line.size();
-  for (int32_t i = 0; i < line_size; i++) {
-    uint64_t h = line[i];
-    for (int32_t j = i + 1; j < line_size && j < i + n; j++) {
-      h = h * 116049371 + line[j];
-      line.push_back(nwords_ + (h % args_->bucket));
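+// Compute word n-gram ids from the token hashes in `hashes` and append them
+// to `line`; when the dictionary is quantized, bucket ids are remapped
+// through quantidx_ and n-grams that were pruned away are skipped.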
+void Dictionary::addNgrams(std::vector<int32_t>& line,
+                           const std::vector<int32_t>& hashes,
+                           int32_t n) const {
+  for (int32_t i = 0; i < hashes.size(); i++) {
+    uint64_t h = hashes[i];
+    for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
+      h = h * 116049371 + hashes[j];
+      int64_t id = h % args_->bucket;
+      if (quantidx_.size() != 0) {
+        if (quantidx_.find(id) != quantidx_.end()) {
+          id = quantidx_.at(id);
+        } else {continue;}
+      }
+      line.push_back(nwords_ + id);
    }
  }
}

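+// Read one line of raw tokens from `in`, rewinding the stream once EOF has
+// been reached; stops at EOS or, for non-supervised models, after
+// MAX_LINE_SIZE tokens.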
+int32_t Dictionary::getLine(std::istream& in,
+                            std::vector<std::string>& tokens) const {
+  if (in.eof()) {
+    in.clear();
+    in.seekg(std::streampos(0));
+  }
+  tokens.clear();
+  std::string token;
+  while (readWord(in, token)) {
+    tokens.push_back(token);
+    if (token == EOS) break;
+    if (tokens.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break;
+  }
+  return tokens.size();
+}
+
int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
+                           std::vector<int32_t>& word_hashes,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1);
-  std::string token;
-  int32_t ntokens = 0;
+  std::vector<std::string> tokens;
+  getLine(in, tokens);
  words.clear();
  labels.clear();
-  if (in.eof()) {
-    in.clear();
-    in.seekg(std::streampos(0));
-  }
-  while (readWord(in, token)) {
-    int32_t wid = getId(token);
-    if (wid < 0) continue;
+  word_hashes.clear();
+  int32_t ntokens = 0;
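+  // Out-of-vocabulary tokens contribute only their hash (for n-grams);
+  // in-vocabulary words that survive discarding contribute both id and hash.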
+  for (auto it = tokens.cbegin(); it != tokens.cend(); ++it) {
+    int32_t h = find(*it);
+    int32_t wid = word2int_[h];
+    if (wid < 0) {
+      word_hashes.push_back(hash(*it));
+      continue;
+    }
    entry_type type = getType(wid);
    ntokens++;
    if (type == entry_type::word && !discard(wid, uniform(rng))) {
      words.push_back(wid);
+      word_hashes.push_back(hash(*it));
    }
    if (type == entry_type::label) {
      labels.push_back(wid - nwords_);
    }
-    if (words.size() > MAX_LINE_SIZE && args_->model != model_name::sup) break;
-    if (token == EOS) break;
+  }
+  return ntokens;
+}
+
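+// Wrapper keeping the old three-vector signature: for supervised models,
+// append word n-gram ids to `words`, using the quantized mapping when
+// quant_ is set.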
+int32_t Dictionary::getLine(std::istream& in,
+                            std::vector<int32_t>& words,
+                            std::vector<int32_t>& labels,
+                            std::minstd_rand& rng) const {
+  std::vector<int32_t> word_hashes;
+  int32_t ntokens = getLine(in, words, word_hashes, labels, rng);
+  if (args_->model == model_name::sup) {
+    if (quant_) {
+      addNgrams(words, word_hashes, args_->wordNgrams);
+    } else {
+      std::vector<int32_t> ngrams;
+      addNgrams(ngrams, words, args_->wordNgrams);
+      words.insert(words.end(), ngrams.begin(), ngrams.end());
+    }
  }
  return ntokens;
}

@@ -332,13 +367,19 @@ void Dictionary::save(std::ostream& out) const {
    out.write((char*) &(e.count), sizeof(int64_t));
    out.write((char*) &(e.type), sizeof(entry_type));
  }
+  if (quant_) {
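+    // persist quantidx_: entry count, then (bucket id, compact id) pairs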
+    auto ss = quantidx_.size();
+    out.write((char*) &(ss), sizeof(ss));
+    for (auto it = quantidx_.begin(); it != quantidx_.end(); it++) {
+      out.write((char*)&(it->first), sizeof(int32_t));
+      out.write((char*)&(it->second), sizeof(int32_t));
+    }
+  }
}

void Dictionary::load(std::istream& in) {
  words_.clear();
-  for (int32_t i = 0; i < MAX_VOCAB_SIZE; i++) {
-    word2int_[i] = -1;
-  }
+  std::fill(word2int_.begin(), word2int_.end(), -1);
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
@@ -354,8 +395,99 @@ void Dictionary::load(std::istream& in) {
    words_.push_back(e);
    word2int_[find(e.word)] = i;
  }
+  if (quant_) {
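+    // restore the quantidx_ table written by save()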
+    std::size_t size;
+    in.read((char*) &size, sizeof(std::size_t));
+    for (std::size_t i = 0; i < size; i++) {
+      int32_t k, v;
+      in.read((char*)&k, sizeof(int32_t));
+      in.read((char*)&v, sizeof(int32_t));
+      quantidx_[k] = v;
+    }
+  }
  initTableDiscard();
  initNgrams();
}

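+// Shrink the dictionary to the words and n-gram buckets listed in `idx`
+// (labels are always kept); rebuilds word2int_ and rewrites `idx` with the
+// surviving ids.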
+void Dictionary::prune(std::vector<int32_t>& idx) {
+  std::vector<int32_t> words, ngrams;
+  for (auto it = idx.cbegin(); it != idx.cend(); ++it) {
+    if (*it < nwords_) {words.push_back(*it);}
+    else {ngrams.push_back(*it);}
+  }
+  std::sort(words.begin(), words.end());
+  idx = words;
+
+  if (ngrams.size() != 0) {
+    convertNgrams(ngrams);
+    idx.insert(idx.end(), ngrams.begin(), ngrams.end());
+  }
+
+  std::fill(word2int_.begin(), word2int_.end(), -1);
+
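+  // compact words_ in place: keep every label and every surviving word,
+  // assigning new consecutive ids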
+  int32_t j = 0;
+  for (int32_t i = 0; i < words_.size(); i++) {
+    if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) {
+      words_[j] = words_[i];
+      word2int_[find(words_[j].word)] = j;
+      j++;
+    }
+  }
+  nwords_ = words.size();
+  size_ = nwords_ + nlabels_;
+  words_.erase(words_.begin() + size_, words_.end());
+}
+
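+// Re-read the training corpus to find, for each surviving n-gram bucket, the
+// token-hash-based id it most often corresponds to, and assign compact
+// consecutive ids in quantidx_; buckets mapping to an already-used id are
+// dropped.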
+void Dictionary::convertNgrams(std::vector<int32_t>& ngramidx) {
+  std::ifstream in(args_->input);
+  if (!in.is_open()) {
+    std::cerr << "Input file cannot be opened!" << std::endl;
+    exit(EXIT_FAILURE);
+  }
+
+  std::unordered_map<int32_t, std::unordered_map<int32_t, int32_t>> convertMap;
+  for (auto it = ngramidx.cbegin(); it != ngramidx.cend(); ++it) {
+    convertMap[*it] = std::unordered_map<int32_t, int32_t>();
+  }
+  std::vector<int32_t> word_hashes, words, labels, oldhashes, newhashes;
+  std::minstd_rand rng;
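+  // one pass over the corpus: for every old (word-id based) n-gram id,
+  // count the co-occurring new (token-hash based) n-gram ids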
+  while (in.peek() != EOF) {
+    getLine(in, words, word_hashes, labels, rng);
+    if (words.empty()) {continue;}
+    oldhashes.clear(); newhashes.clear();
+    addNgrams(oldhashes, words, args_->wordNgrams);
+    addNgrams(newhashes, word_hashes, args_->wordNgrams);
+    assert(newhashes.size() == oldhashes.size());
+    for (int32_t i = 0; i < oldhashes.size(); i++) {
+      auto oh = oldhashes[i];
+      if (convertMap.find(oh) == convertMap.end()) {continue;}
+      convertMap[oh][newhashes[i]]++;
+    }
+  }
+  in.close();
+
+  quantidx_.clear();
+  std::vector<int32_t> remaining_indices;
+  int32_t size = 0;
+  for (auto it = ngramidx.begin(); it != ngramidx.end(); ++it) {
+    const auto& cm = convertMap[*it];
+    // guard: an n-gram never observed in the corpus has no candidate id;
+    // skip it so newhash is never read uninitialized
+    if (cm.empty()) {continue;}
+    int32_t newhash; int32_t count = -1;
+    for (auto nit = cm.cbegin(); nit != cm.cend(); ++nit) {
+      if (count < nit->second) {
+        newhash = nit->first;
+        count = nit->second;
+      }
+    }
+    newhash -= nwords_;
+    if (quantidx_.find(newhash) == quantidx_.end()) {
+      quantidx_[newhash] = size;
+      size++;
+      remaining_indices.push_back(*it);
+    }
+  }
+  ngramidx = remaining_indices;
+}
+
}