Ver código fonte

add model inspection command

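Summary: add a model inspection command (`fasttext dump <model> <option>`) that prints a model's arguments, dictionary, or input/output matrices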

Reviewed By: EdouardGrave

Differential Revision: D5999484

fbshipit-source-id: 249749339697945afcf7cd481f3c3066de822745
Changhan Wang 8 anos atrás
pai
commit
8b10430ec3
9 arquivos alterados com 136 adições e 14 exclusões
  1. 28 0
      src/args.cc
  2. 15 13
      src/args.h
  3. 11 0
      src/dictionary.cc
  4. 1 0
      src/dictionary.h
  5. 24 0
      src/fasttext.cc
  6. 5 0
      src/fasttext.h
  7. 36 1
      src/main.cc
  8. 14 0
      src/matrix.cc
  9. 2 0
      src/matrix.h

+ 28 - 0
src/args.cc

@@ -65,6 +65,18 @@ std::string Args::boolToString(bool b) {
   }
 }
 
+// Maps a model_name enumerator to its short text label ("cbow", "sg" or
+// "sup"). Any other value yields the "Unknown model name!" placeholder.
+std::string Args::modelToString(model_name mn) {
+  // Start from the fallback text; it is only returned if no case matches,
+  // which should never happen.
+  std::string label = "Unknown model name!";
+  switch (mn) {
+    case model_name::cbow:
+      label = "cbow";
+      break;
+    case model_name::sg:
+      label = "sg";
+      break;
+    case model_name::sup:
+      label = "sup";
+      break;
+  }
+  return label;
+}
+
 void Args::parseArgs(const std::vector<std::string>& args) {
   std::string command(args[1]);
   if (command == "supervised") {
@@ -262,4 +274,20 @@ void Args::load(std::istream& in) {
   in.read((char*) &(t), sizeof(double));
 }
 
+// Writes a subset of the arguments to `out` as text, one "<name> <value>"
+// pair per line. `loss` and `model` are printed through lossToString() and
+// modelToString(); every other field is streamed directly.
+void Args::dump(std::ostream& out) {
+  out << "dim " << dim << std::endl;
+  out << "ws " << ws << std::endl;
+  out << "epoch " << epoch << std::endl;
+  out << "minCount " << minCount << std::endl;
+  out << "neg " << neg << std::endl;
+  out << "wordNgrams " << wordNgrams << std::endl;
+  out << "loss " << lossToString(loss) << std::endl;
+  out << "model " << modelToString(model) << std::endl;
+  out << "bucket " << bucket << std::endl;
+  out << "minn " << minn << std::endl;
+  out << "maxn " << maxn << std::endl;
+  out << "lrUpdateRate " << lrUpdateRate << std::endl;
+  out << "t " << t << std::endl;
+}
+
 }

+ 15 - 13
src/args.h

@@ -23,6 +23,7 @@ class Args {
   protected:
     std::string lossToString(loss_name);
     std::string boolToString(bool);
+    std::string modelToString(model_name);
 
   public:
     Args();
@@ -49,19 +50,20 @@ class Args {
     std::string pretrainedVectors;
     bool saveOutput;
 
-  bool qout;
-  bool retrain;
-  bool qnorm;
-  size_t cutoff;
-  size_t dsub;
+    bool qout;
+    bool retrain;
+    bool qnorm;
+    size_t cutoff;
+    size_t dsub;
 
-  void parseArgs(const std::vector<std::string>& args);
-  void printHelp();
-  void printBasicHelp();
-  void printDictionaryHelp();
-  void printTrainingHelp();
-  void printQuantizationHelp();
-  void save(std::ostream&);
-  void load(std::istream&);
+    void parseArgs(const std::vector<std::string>& args);
+    void printHelp();
+    void printBasicHelp();
+    void printDictionaryHelp();
+    void printTrainingHelp();
+    void printQuantizationHelp();
+    void save(std::ostream&);
+    void load(std::istream&);
+    void dump(std::ostream&);
 };
 }

+ 11 - 0
src/dictionary.cc

@@ -477,4 +477,15 @@ void Dictionary::prune(std::vector<int32_t>& idx) {
   initNgrams();
 }
 
+// Writes the dictionary to `out` as text: first a line holding
+// words_.size(), then one "<word> <count> <type>" line per entry, where
+// <type> is "label" for entry_type::label entries and "word" otherwise.
+void Dictionary::dump(std::ostream& out) const {
+  out << words_.size() << '\n';
+  // Bind by const reference: iterating by value would copy every entry just
+  // to print three of its fields.
+  for (const auto& it : words_) {
+    const char* entryType =
+        (it.type == entry_type::label) ? "label" : "word";
+    // '\n' rather than std::endl: no need to flush after every entry.
+    out << it.word << " " << it.count << " " << entryType << '\n';
+  }
+  // Flush once so everything has been pushed to the stream on return.
+  out.flush();
+}
+
 }

+ 1 - 0
src/dictionary.h

@@ -104,6 +104,7 @@ class Dictionary {
     void threshold(int64_t, int64_t);
     void prune(std::vector<int32_t>&);
     bool isPruned() { return pruneidx_size_ >= 0; }
+    void dump(std::ostream&) const;
 };
 
 }

+ 24 - 0
src/fasttext.cc

@@ -712,4 +712,28 @@ bool FastText::isQuant() const {
   return quant_;
 }
 
+// Prints the model's arguments to standard output by delegating to
+// args_->dump().
+void FastText::dumpArgs() const {
+  std::ostream& sink = std::cout;
+  args_->dump(sink);
+}
+
+// Prints the model's dictionary to standard output by delegating to
+// dict_->dump().
+void FastText::dumpDict() const {
+  std::ostream& sink = std::cout;
+  dict_->dump(sink);
+}
+
+// Prints input_ to standard output via its dump() method. When quant_ is
+// set, nothing is dumped and a notice is written to standard error instead.
+void FastText::dumpInput() const {
+  if (!quant_) {
+    input_->dump(std::cout);
+    return;
+  }
+  std::cerr << "Not supported for quantized models." << std::endl;
+}
+
+// Prints output_ to standard output via its dump() method. When quant_ is
+// set, nothing is dumped and a notice is written to standard error instead.
+void FastText::dumpOutput() const {
+  if (!quant_) {
+    output_->dump(std::cout);
+    return;
+  }
+  std::cerr << "Not supported for quantized models." << std::endl;
+}
+
 }

+ 5 - 0
src/fasttext.h

@@ -109,5 +109,10 @@ class FastText {
   void loadVectors(std::string);
   int getDimension() const;
   bool isQuant() const;
+
+  void dumpArgs() const;
+  void dumpDict() const;
+  void dumpInput() const;
+  void dumpOutput() const;
 };
 }

+ 36 - 1
src/main.cc

@@ -8,7 +8,6 @@
  */
 
 #include <iostream>
-
 #include "fasttext.h"
 #include "args.h"
 
@@ -30,6 +29,7 @@ void printUsage() {
     << "  print-ngrams            print ngrams given a trained model and word\n"
     << "  nn                      query for nearest neighbors\n"
     << "  analogies               query for analogies\n"
+    << "  dump                    dump arguments,dictionary,input/output vectors\n"
     << std::endl;
 }
 
@@ -111,6 +111,14 @@ void printAnalogiesUsage() {
     << std::endl;
 }
 
+// Prints the usage text of the `dump` command to standard output.
+void printDumpUsage() {
+  // Adjacent literals are concatenated by the compiler, so the text emitted
+  // is the same as streaming the three pieces one by one.
+  const char* usage =
+    "usage: fasttext dump <model> <option>\n\n"
+    "  <model>      model filename\n"
+    "  <option>     option from args,dict,input,output";
+  std::cout << usage << std::endl;
+}
+
 void test(const std::vector<std::string>& args) {
   if (args.size() < 4 || args.size() > 5) {
     printTestUsage();
@@ -256,6 +264,31 @@ void train(const std::vector<std::string> args) {
   }
 }
 
+// Implements `fasttext dump <model> <option>`: loads the model at args[2]
+// and prints the part selected by args[3] ("args", "dict", "input" or
+// "output"). Prints the usage text and exits with EXIT_FAILURE when fewer
+// than four arguments are given or the option is not one of those four.
+void dump(const std::vector<std::string>& args) {
+  if (args.size() < 4) {
+    printDumpUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  const std::string& modelPath = args[2];
+  const std::string& option = args[3];
+
+  // Reject an unknown option before touching the model file, so a typo
+  // fails immediately instead of after the model has been loaded.
+  const bool knownOption = option == "args" || option == "dict" ||
+      option == "input" || option == "output";
+  if (!knownOption) {
+    printDumpUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  FastText fasttext;
+  fasttext.loadModel(modelPath);
+  if (option == "args") {
+    fasttext.dumpArgs();
+  } else if (option == "dict") {
+    fasttext.dumpDict();
+  } else if (option == "input") {
+    fasttext.dumpInput();
+  } else {
+    // Only "output" can reach this branch after the validation above.
+    fasttext.dumpOutput();
+  }
+}
+
 int main(int argc, char** argv) {
   std::vector<std::string> args(argv, argv + argc);
   if (args.size() < 2) {
@@ -281,6 +314,8 @@ int main(int argc, char** argv) {
     analogies(args);
   } else if (command == "predict" || command == "predict-prob" ) {
     predict(args);
+  } else if (command == "dump") {
+    dump(args);
   } else {
     printUsage();
     exit(EXIT_FAILURE);

+ 14 - 0
src/matrix.cc

@@ -117,4 +117,18 @@ void Matrix::load(std::istream& in) {
   data_ = std::vector<real>(m_ * n_);
   in.read((char*)data_.data(), m_ * n_ * sizeof(real));
 }
+
+// Writes the matrix to `out` as text: a "<m_> <n_>" header line, then for
+// each i in [0, m_) one line holding at(i, 0) .. at(i, n_ - 1) separated by
+// single spaces.
+void Matrix::dump(std::ostream& out) const {
+  out << m_ << " " << n_ << '\n';
+  for (int64_t i = 0; i < m_; i++) {
+    for (int64_t j = 0; j < n_; j++) {
+      if (j > 0) {
+        out << " ";
+      }
+      out << at(i, j);
+    }
+    // '\n' rather than std::endl: no need to flush after every row.
+    out << '\n';
+  }
+  // Flush once so everything has been pushed to the stream on return.
+  out.flush();
+}
+
 }

+ 2 - 0
src/matrix.h

@@ -73,5 +73,7 @@ class Matrix {
 
   void save(std::ostream&);
   void load(std::istream&);
+
+  void dump(std::ostream&) const;
 };
 }