Explorar o código

getLine calculate hash only once

Summary: This diff removes some of the duplicated calles to hash within getLine.

Reviewed By: EdouardGrave

Differential Revision: D5042887

fbshipit-source-id: 3864c7cfab9375d374ddf439fa4ab8e4ab9dfda1
Christian Puhrsch %!s(int64=8) %!d(string=hai) anos
pai
achega
e39dfbe598
Modificáronse 2 ficheiros con 18 adicións e 7 borrados
  1. 16 7
      src/dictionary.cc
  2. 2 0
      src/dictionary.h

+ 16 - 7
src/dictionary.cc

@@ -28,11 +28,15 @@ Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
   ntokens_(0) {}
 
 int32_t Dictionary::find(const std::string& w) const {
-  int32_t h = hash(w) % MAX_VOCAB_SIZE;
-  while (word2int_[h] != -1 && words_[word2int_[h]].word != w) {
-    h = (h + 1) % MAX_VOCAB_SIZE;
+  return find(w, hash(w));
+}
+
+int32_t Dictionary::find(const std::string& w, uint32_t h) const {
+  int32_t id = h % MAX_VOCAB_SIZE;
+  while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
+    id = (id + 1) % MAX_VOCAB_SIZE;
   }
-  return h;
+  return id;
 }
 
 void Dictionary::add(const std::string& w) {
@@ -102,6 +106,11 @@ bool Dictionary::discard(int32_t id, real rand) const {
   return rand > pdiscard_[id];
 }
 
+int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
+  int32_t id = find(w, h);
+  return word2int_[id];
+}
+
 int32_t Dictionary::getId(const std::string& w) const {
   int32_t h = find(w);
   return word2int_[h];
@@ -308,11 +317,11 @@ int32_t Dictionary::getLine(std::istream& in,
   int32_t ntokens = 0;
   std::string token;
   while (readWord(in, token)) {
-    int32_t h = find(token);
-    int32_t wid = word2int_[h];
+    uint32_t h = hash(token);
+    int32_t wid = getId(token, h);
     if (wid < 0) {
       entry_type type = getType(token);
-      if (type == entry_type::word) word_hashes.push_back(hash(token));
+      if (type == entry_type::word) word_hashes.push_back(h);
       continue;
     }
     entry_type type = getType(wid);

+ 2 - 0
src/dictionary.h

@@ -39,6 +39,7 @@ class Dictionary {
     static const int32_t MAX_LINE_SIZE = 1024;
 
     int32_t find(const std::string&) const;
+    int32_t find(const std::string&, uint32_t h) const;
     void initTableDiscard();
     void initNgrams();
 
@@ -70,6 +71,7 @@ class Dictionary {
     int32_t nlabels() const;
     int64_t ntokens() const;
     int32_t getId(const std::string&) const;
+    int32_t getId(const std::string&, uint32_t h) const;
     entry_type getType(int32_t) const;
     entry_type getType(const std::string&) const;
     bool discard(int32_t, real) const;