Browse source

Remove -sampling and -onlyWord arguments

Edouard Grave committed 9 years ago — commit bef5f3c73d (with parent)
6 changed files with 14 additions and 58 deletions
  1. .gitignore (+2 −0)
  2. src/args.cc (+1 −25)
  3. src/args.h (+0 −3)
  4. src/fasttext.cc (+1 −2)
  5. src/model.cc (+9 −27)
  6. src/model.h (+1 −1)

+ 2 - 0
.gitignore

@@ -1,5 +1,7 @@
 .*.swp
 *.o
+*.bin
+*.vec
 data
 classify
 fasttext

+ 1 - 25
src/args.cc

@@ -21,13 +21,11 @@ Args::Args() {
   minCount = 5;
   neg = 5;
   wordNgrams = 1;
-  sampling = sampling_name::sqrt;
   loss = loss_name::ns;
   model = model_name::sg;
   bucket = 2000000;
   minn = 3;
   maxn = 6;
-  onlyWord = 0;
   thread = 12;
   verbose = 1000;
   t = 1e-4;
@@ -74,20 +72,6 @@ void Args::parseArgs(int argc, char** argv) {
       neg = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-wordNgrams") == 0) {
       wordNgrams = atoi(argv[ai + 1]);
-    } else if (strcmp(argv[ai], "-sampling") == 0) {
-      if (strcmp(argv[ai + 1], "sqrt") == 0) {
-        sampling = sampling_name::sqrt;
-      } else if (strcmp(argv[ai + 1], "log") == 0) {
-        sampling = sampling_name::log;
-      } else if (strcmp(argv[ai + 1], "tf") == 0) {
-        sampling = sampling_name::tf;
-      } else if (strcmp(argv[ai + 1], "uni") == 0) {
-        sampling = sampling_name::uni;
-      } else {
-        std::cout << "Unknown sampling: " << argv[ai + 1] << std::endl;
-        printHelp();
-        exit(EXIT_FAILURE);
-      }
     } else if (strcmp(argv[ai], "-loss") == 0) {
       if (strcmp(argv[ai + 1], "hs") == 0) {
         loss = loss_name::hs;
@@ -106,8 +90,6 @@ void Args::parseArgs(int argc, char** argv) {
       minn = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-maxn") == 0) {
       maxn = atoi(argv[ai + 1]);
-    } else if (strcmp(argv[ai], "-onlyWord") == 0) {
-      onlyWord = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-thread") == 0) {
       thread = atoi(argv[ai + 1]);
     } else if (strcmp(argv[ai], "-verbose") == 0) {
@@ -144,12 +126,10 @@ void Args::printHelp() {
     << "  -minCount   minimal number of word occurences [" << minCount << "]\n"
     << "  -neg        number of negatives sampled [" << neg << "]\n"
     << "  -wordNgrams max length of word ngram [" << wordNgrams << "]\n"
-    << "  -sampling   sampling distribution {sqrt, log, tf, uni} [log]\n"
-    << "  -loss       loss function {ns, hs, softmax}   [ns]\n"
+    << "  -loss       loss function {ns, hs, softmax} [ns]\n"
     << "  -bucket     number of buckets [" << bucket << "]\n"
     << "  -minn       min length of char ngram [" << minn << "]\n"
     << "  -maxn       max length of char ngram [" << maxn << "]\n"
-    << "  -onlyWord   number of words with no ngrams [" << onlyWord << "]\n"
     << "  -thread     number of threads [" << thread << "]\n"
     << "  -verbose    how often to print to stdout [" << verbose << "]\n"
     << "  -t          sampling threshold [" << t << "]\n"
@@ -165,13 +145,11 @@ void Args::save(std::ofstream& ofs) {
     ofs.write((char*) &(minCount), sizeof(int));
     ofs.write((char*) &(neg), sizeof(int));
     ofs.write((char*) &(wordNgrams), sizeof(int));
-    ofs.write((char*) &(sampling), sizeof(sampling_name));
     ofs.write((char*) &(loss), sizeof(loss_name));
     ofs.write((char*) &(model), sizeof(model_name));
     ofs.write((char*) &(bucket), sizeof(int));
     ofs.write((char*) &(minn), sizeof(int));
     ofs.write((char*) &(maxn), sizeof(int));
-    ofs.write((char*) &(onlyWord), sizeof(int));
     ofs.write((char*) &(verbose), sizeof(int));
     ofs.write((char*) &(t), sizeof(double));
   }
@@ -185,13 +163,11 @@ void Args::load(std::ifstream& ifs) {
     ifs.read((char*) &(minCount), sizeof(int));
     ifs.read((char*) &(neg), sizeof(int));
     ifs.read((char*) &(wordNgrams), sizeof(int));
-    ifs.read((char*) &(sampling), sizeof(sampling_name));
     ifs.read((char*) &(loss), sizeof(loss_name));
     ifs.read((char*) &(model), sizeof(model_name));
     ifs.read((char*) &(bucket), sizeof(int));
     ifs.read((char*) &(minn), sizeof(int));
     ifs.read((char*) &(maxn), sizeof(int));
-    ifs.read((char*) &(onlyWord), sizeof(int));
     ifs.read((char*) &(verbose), sizeof(int));
     ifs.read((char*) &(t), sizeof(double));
   }

+ 0 - 3
src/args.h

@@ -13,7 +13,6 @@
 #include <string>
 
 enum class model_name : int {cbow=1, sg, sup};
-enum class sampling_name : int {sqrt=1, log, uni, tf};
 enum class loss_name : int {hs=1, ns, softmax};
 
 class Args {
@@ -29,13 +28,11 @@ class Args {
     int minCount;
     int neg;
     int wordNgrams;
-    sampling_name sampling;
     loss_name loss;
     model_name model;
     int bucket;
     int minn;
     int maxn;
-    int onlyWord;
     int thread;
     int verbose;
     double t;

+ 1 - 2
src/fasttext.cc

@@ -107,7 +107,7 @@ void printInfo(Model& model, real progress) {
   std::cout << "  words/sec/thread: " << std::setprecision(0) << wst;
   std::cout << "  lr: " << std::setprecision(6) << model.getLearningRate();
   std::cout << "  loss: " << std::setprecision(6) << loss;
-  std::cout << "  eta: " << etah << "h" << etam << "m";
+  std::cout << "  eta: " << etah << "h" << etam << "m ";
   std::cout << std::flush;
 }
 
@@ -370,7 +370,6 @@ void train(int argc, char** argv) {
 }
 
 int main(int argc, char** argv) {
-  std::locale::global(std::locale(""));
   utils::initTables();
   if (argc < 2) {
     printUsage();

+ 9 - 27
src/model.cc

@@ -24,7 +24,7 @@ Model::Model(Matrix& wi, Matrix& wo, int32_t hsz, real lr, int32_t seed)
   osz_ = wo.m_;
   hsz_ = hsz;
   lr_ = lr;
-  npos = 0;
+  negpos = 0;
 }
 
 void Model::setLearningRate(real lr) {
@@ -160,31 +160,13 @@ void Model::setTargetCounts(const std::vector<int64_t>& counts) {
 }
 
 void Model::initTableNegatives(const std::vector<int64_t>& counts) {
-  real N = 0.0;
-  for (int32_t i = 0; i < counts.size(); i++) {
-    if (args.sampling == sampling_name::log) {
-      N += log(counts[i]);
-    } else if (args.sampling == sampling_name::sqrt) {
-      N += sqrt(counts[i]);
-    } else if (args.sampling == sampling_name::tf) {
-      N += pow(counts[i], 0.75);
-    } else {
-      N += 1.0;
-    }
+  real z = 0.0;
+  for (size_t i = 0; i < counts.size(); i++) {
+    z += pow(counts[i], 0.5);
   }
-  for (int32_t i = 0; i < counts.size(); i++) {
-    real c = 0.0;
-    if (args.sampling == sampling_name::log) {
-      c = log(counts[i]);
-    } else if (args.sampling == sampling_name::sqrt) {
-      c = sqrt(counts[i]);
-    } else if (args.sampling == sampling_name::tf) {
-      c = pow(counts[i], 0.75);
-    } else {
-      c = 1.0;
-    }
-    int32_t n = (int32_t)ceil(c * ((real)NEGATIVE_TABLE_SIZE / N));
-    for (int32_t j = 0; j < n; j++) {
+  for (size_t i = 0; i < counts.size(); i++) {
+    real c = pow(counts[i], 0.5);
+    for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {
       negatives.push_back(i);
     }
   }
@@ -194,8 +176,8 @@ void Model::initTableNegatives(const std::vector<int64_t>& counts) {
 int32_t Model::getNegative(int32_t target) {
   int32_t negative;
   do {
-    negative = negatives[npos];
-    npos = (npos + 1) % negatives.size();
+    negative = negatives[negpos];
+    negpos = (negpos + 1) % negatives.size();
   } while (target == negative);
   return negative;
 }

+ 1 - 1
src/model.h

@@ -38,7 +38,7 @@ class Model {
     static real lr_;
 
     std::vector<int32_t> negatives;
-    size_t npos;
+    size_t negpos;
     std::vector< std::vector<int32_t> > paths;
     std::vector< std::vector<bool> > codes;
     std::vector<Node> tree;