1
0
Эх сурвалжийг харах

Add a -minCountLabel option

Summary:
In order to reproduce the results from [2] on YFCC100M, we add an option to
remove infrequent labels.

Reviewed By: EdouardGrave

Differential Revision: D4031684

fbshipit-source-id: c3724706dc0ae35e7d9cc6d08a52cdd4b9d4bccc
Armand Joulin 9 жил өмнө
parent
commit
2211639001
6 өөрчлөгдсөн 18 нэмэгдсэн , 6 устгасан
  1. 5 0
      README.md
  2. 4 0
      src/args.cc
  3. 1 0
      src/args.h
  4. 6 4
      src/dictionary.cc
  5. 1 1
      src/dictionary.h
  6. 1 1
      src/fasttext.cc

+ 5 - 0
README.md

@@ -133,6 +133,7 @@ The following arguments are optional:
   -ws                 size of the context window [5]
   -epoch              number of epochs [5]
  -minCount           minimal number of word occurrences [1]
+  -minCountLabel      minimal number of label occurrences [0]
   -neg                number of negatives sampled [5]
   -wordNgrams         max length of word ngram [1]
   -loss               loss function {ns, hs, softmax} [ns]
@@ -180,6 +181,10 @@ Please cite [1](#enriching-word-vectors-with-subword-information) if using this
 
 (\* These authors contributed equally.)
 
+## Resources
+
+You can find the preprocessed YFCC100M data used in [2] at https://research.facebook.com/research/fasttext/
+
 ## Join the fastText community
 
 * Facebook page: https://www.facebook.com/groups/1174547215919768

+ 4 - 0
src/args.cc

@@ -22,6 +22,7 @@ Args::Args() {
   ws = 5;
   epoch = 5;
   minCount = 5;
+  minCountLabel = 0;
   neg = 5;
   wordNgrams = 1;
   loss = loss_name::ns;
@@ -78,6 +79,8 @@ void Args::parseArgs(int argc, char** argv) {
       epoch = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-minCount") == 0) {
       minCount = atoi(argv[ai + 1]);
+    } else if (strcmp(argv[ai], "-minCountLabel") == 0) {
+      minCountLabel = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-neg") == 0) {
       neg = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-wordNgrams") == 0) {
@@ -143,6 +146,7 @@ void Args::printHelp() {
     << "  -ws                 size of the context window [" << ws << "]\n"
     << "  -epoch              number of epochs [" << epoch << "]\n"
     << "  -minCount           minimal number of word occurences [" << minCount << "]\n"
+    << "  -minCountLabel      minimal number of label occurences [" << minCountLabel << "]\n"
     << "  -neg                number of negatives sampled [" << neg << "]\n"
     << "  -wordNgrams         max length of word ngram [" << wordNgrams << "]\n"
     << "  -loss               loss function {ns, hs, softmax} [ns]\n"

+ 1 - 0
src/args.h

@@ -31,6 +31,7 @@ class Args {
     int ws;
     int epoch;
     int minCount;
+    int minCountLabel;
     int neg;
     int wordNgrams;
     loss_name loss;

+ 6 - 4
src/dictionary.cc

@@ -180,10 +180,11 @@ void Dictionary::readFromFile(std::istream& in) {
       std::cout << "\rRead " << ntokens_  / 1000000 << "M words" << std::flush;
     }
     if (size_ > 0.75 * MAX_VOCAB_SIZE) {
-      threshold(minThreshold++);
+      minThreshold++;
+      threshold(minThreshold, minThreshold);
     }
   }
-  threshold(args_->minCount);
+  threshold(args_->minCount, args_->minCountLabel);
   initTableDiscard();
   initNgrams();
   if (args_->verbose > 0) {
@@ -197,13 +198,14 @@ void Dictionary::readFromFile(std::istream& in) {
   }
 }
 
-void Dictionary::threshold(int64_t t) {
+void Dictionary::threshold(int64_t t, int64_t tl) {
   sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
       if (e1.type != e2.type) return e1.type < e2.type;
       return e1.count > e2.count;
     });
   words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
-        return e.type == entry_type::word && e.count < t;
+        return (e.type == entry_type::word && e.count < t) ||
+               (e.type == entry_type::label && e.count < tl);
       }), words_.end());
   words_.shrink_to_fit();
   size_ = 0;

+ 1 - 1
src/dictionary.h

@@ -77,7 +77,7 @@ class Dictionary {
     void addNgrams(std::vector<int32_t>&, int32_t) const;
     int32_t getLine(std::istream&, std::vector<int32_t>&,
                     std::vector<int32_t>&, std::minstd_rand&) const;
-    void threshold(int64_t);
+    void threshold(int64_t, int64_t);
 };
 
 }

+ 1 - 1
src/fasttext.cc

@@ -309,7 +309,7 @@ void FastText::loadVectors(std::string filename) {
   }
   in.close();
 
-  dict_->threshold(1);
+  dict_->threshold(1, 0);
   input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
   input_->uniform(1.0 / args_->dim);