File size: 5,778 Bytes
1ce325b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#ifndef LM_FILTER_PHRASE_H
#define LM_FILTER_PHRASE_H

#include "../../util/murmur_hash.hh"
#include "../../util/string_piece.hh"
#include "../../util/tokenize_piece.hh"

#include <boost/unordered_map.hpp>

#include <iosfwd>
#include <vector>

/* Expands to a const lookup method Find<caps>(key, out) for a class that has a
 * Table typedef, a table_ member and a Hash type in scope.  When key is present
 * the method points out at the <lower> member of the mapped value and returns
 * true; when it is absent it returns false and leaves out untouched. */
#define LM_FILTER_PHRASE_METHOD(caps, lower) \
bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
  Table::const_iterator found(table_.find(key));\
  if (found != table_.end()) {\
    out = &found->second.lower;\
    return true;\
  }\
  return false;\
}

namespace lm {
namespace phrase {

// 64-bit hash type.  Used both for the per-word hashes produced by
// detail::MakeHashes and for the keys of Substrings' table (chained hashes of
// word sequences).
typedef uint64_t Hash;

class Substrings {
  private:
    /* This is the value in a hash table where the key is a string.  It indicates
     * four sets of sentences:
     * substring is sentences with a phrase containing the key as a substring.
     * left is sentencess with a phrase that begins with the key (left aligned).
     * right is sentences with a phrase that ends with the key (right aligned).
     * phrase is sentences where the key is a phrase.
     * Each set is encoded as a vector of sentence ids in increasing order.
     */
    struct SentenceRelation {
      std::vector<unsigned int> substring, left, right, phrase;
    };
    /* Most of the CPU is hash table lookups, so let's not complicate it with
     * vector equality comparisons.  If a collision happens, the SentenceRelation
     * structure will contain the union of sentence ids over the colliding strings.
     * In that case, the filter will be slightly more permissive.
     * The key here is the same as boost's hash of std::vector<std::string>.
     */
    typedef boost::unordered_map<Hash, SentenceRelation> Table;

  public:
    Substrings() {}

    /* If the string isn't a substring of any phrase, return NULL.  Otherwise,
     * return a pointer to std::vector<unsigned int> listing sentences with
     * matching phrases.  This set may be empty for Left, Right, or Phrase.
     * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
     */
    LM_FILTER_PHRASE_METHOD(Substring, substring)
    LM_FILTER_PHRASE_METHOD(Left, left)
    LM_FILTER_PHRASE_METHOD(Right, right)
    LM_FILTER_PHRASE_METHOD(Phrase, phrase)

#pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization
    // sentence_id must be non-decreasing.  Iterators are over words in the phrase.
    template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
      // Iterate over all substrings.
      for (Iterator start = begin; start != end; ++start) {
        Hash hash = 0;
        SentenceRelation *relation;
        for (Iterator finish = start; finish != end; ++finish) {
          hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
          // Now hash is of [start, finish].
          relation = &table_[hash];
          AppendSentence(relation->substring, sentence_id);
          if (start == begin) AppendSentence(relation->left, sentence_id);
        }
        AppendSentence(relation->right, sentence_id);
        if (start == begin) AppendSentence(relation->phrase, sentence_id);
      }
    }

  private:
    void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
      if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
    }

    Table table_;
};

// Read a file with one sentence per line containing tab-delimited phrases of
// space-separated words, recording the phrases in out (presumably via
// Substrings::AddPhrase with the line index as sentence id -- confirm in the
// definition).
// NOTE(review): the return value is presumably the number of sentences read;
// verify against the implementation.
unsigned int ReadMultiple(std::istream &in, Substrings &out);

namespace detail {
extern const StringPiece kEndSentence;

/* Hash the words of one n-gram for lookup against Substrings.
 * Clears hashes, then appends util::MurmurHashNative(data, size) of each word
 * in [i, end).  A leading word that looks like a tag (first character '<' and
 * last character '>', e.g. <s>) is skipped, and hashing stops before the first
 * word equal to kEndSentence.  hashes is left empty if nothing remains.
 */
template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
  hashes.clear();
  if (i == end) return;
  // TODO: check strict phrase boundaries after <s> and before </s>.  For now, just skip a leading tag.
  // An empty token has no first or last character; indexing it (and computing
  // size() - 1) would read out of bounds, so only inspect non-empty tokens.
  if (i->size() != 0 && (i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
    ++i;
  }
  for (; i != end && (*i != kEndSentence); ++i) {
    hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
  }
}

class Vertex;
class Arc;

/* State shared by the Union and Multiple conditions: a reference to the
 * Substrings they were built from plus scratch storage reused across calls.
 * Vertex and Arc are only forward-declared here, so the constructors and the
 * destructor (which instantiate the vectors' special members) are declared
 * here and must be defined out of line, where those types are complete.
 * The reference member makes copy assignment unavailable; copy construction
 * is provided explicitly.
 */
class ConditionCommon {
  protected:
    // explicit: a Substrings must not silently convert into a condition.
    explicit ConditionCommon(const Substrings &substrings);
    ConditionCommon(const ConditionCommon &from);

    // Protected and non-virtual: objects cannot be deleted through a
    // ConditionCommon pointer from outside the hierarchy.
    ~ConditionCommon();

    // NOTE(review): defined out of line; presumably builds a graph in
    // vertices_/arcs_ from hashes_ and returns its start vertex -- confirm.
    detail::Vertex &MakeGraph();

    // Temporaries in PassNGram and Evaluate to avoid reallocation.
    std::vector<Hash> hashes_;

  private:
    std::vector<detail::Vertex> vertices_;
    std::vector<detail::Arc> arcs_;

    const Substrings &substrings_;
};

} // namespace detail

class Union : public detail::ConditionCommon {
  public:
    explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}

    /* Decide whether the n-gram spanned by [begin, end) survives filtering.
     * If detail::MakeHashes leaves nothing to look up, the n-gram passes
     * unconditionally; otherwise Evaluate() makes the decision.
     */
    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
      detail::MakeHashes(begin, end, hashes_);
      if (hashes_.empty()) {
        return true;
      }
      return Evaluate();
    }

  private:
    // Defined out of line; presumably consults the word hashes left in hashes_.
    bool Evaluate();
};

class Multiple : public detail::ConditionCommon {
  public:
    explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}

    template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
      detail::MakeHashes(begin, end, hashes_);
      if (hashes_.empty()) {
        output.AddNGram(line);
      } else {
        Evaluate(line, output);
      }
    }

    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
      AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
    }

    void Flush() const {}

  private:
    template <class Output> void Evaluate(const StringPiece &line, Output &output);
};

} // namespace phrase
} // namespace lm
#endif // LM_FILTER_PHRASE_H