File size: 5,778 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
#ifndef LM_FILTER_PHRASE_H
#define LM_FILTER_PHRASE_H
#include "../../util/murmur_hash.hh"
#include "../../util/string_piece.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/unordered_map.hpp>
#include <iosfwd>
#include <vector>
// Generates a const lookup method named Find<caps> inside a class that provides
// a Table typedef and a table_ member.  The generated method searches table_
// for key.  On a hit it points out at the `lower` member of the mapped value
// and returns true; on a miss it returns false and leaves out untouched.
// (Comments cannot go inside the definition: the line continuations would
// swallow the following lines into the comment.)
#define LM_FILTER_PHRASE_METHOD(caps, lower) \
bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
  const Table::const_iterator found = table_.find(key);\
  if (found != table_.end()) {\
    out = &(found->second.lower);\
    return true;\
  }\
  return false;\
}
namespace lm {
namespace phrase {
// 64-bit hash value.  Used as the key type of Substrings' table and as the
// element type of the per-n-gram hash vectors filled by detail::MakeHashes.
typedef uint64_t Hash;
// Table from the hash of a word sequence to the sentences whose phrases
// contain that sequence, split by how the sequence is aligned in the phrase.
class Substrings {
  private:
    /* This is the value in a hash table where the key is the hash of a word
     * sequence.  It indicates four sets of sentences:
     *   substring is sentences with a phrase containing the key as a substring.
     *   left is sentences with a phrase that begins with the key (left aligned).
     *   right is sentences with a phrase that ends with the key (right aligned).
     *   phrase is sentences where the key is a phrase.
     * Each set is encoded as a vector of sentence ids in increasing order.
     */
    struct SentenceRelation {
      std::vector<unsigned int> substring, left, right, phrase;
    };

    /* Most of the CPU is hash table lookups, so let's not complicate it with
     * vector equality comparisons.  If a collision happens, the SentenceRelation
     * structure will contain the union of sentence ids over the colliding strings.
     * In that case, the filter will be slightly more permissive.
     * The key is the running hash built in AddPhrase by repeatedly passing the
     * previous hash and the next element through util::MurmurHashNative.
     */
    typedef boost::unordered_map<Hash, SentenceRelation> Table;

  public:
    Substrings() {}

    /* Lookup methods, one per set in SentenceRelation.  If the key is not in
     * the table (it is not a substring of any phrase), return false and leave
     * out untouched.  Otherwise point out at the std::vector<unsigned int>
     * listing sentences with matching phrases and return true.  This set may
     * be empty for Left, Right, or Phrase.
     * Example: bool FindSubstring(Hash key, const std::vector<unsigned int> *&out) const
     */
    LM_FILTER_PHRASE_METHOD(Substring, substring)
    LM_FILTER_PHRASE_METHOD(Left, left)
    LM_FILTER_PHRASE_METHOD(Right, right)
    LM_FILTER_PHRASE_METHOD(Phrase, phrase)

    // Record every contiguous sub-sequence of the phrase [begin, end) as
    // belonging to sentence_id.
    // sentence_id must be non-decreasing.  Iterators are over words in the phrase.
    template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
      // Iterate over all substrings.
      for (Iterator start = begin; start != end; ++start) {
        Hash hash = 0;
        // start != end here, so the inner loop runs at least once and always
        // assigns relation before it is dereferenced.  Initialising it anyway
        // keeps compilers quiet without a file-wide warning suppression that
        // would leak into every translation unit including this header.
        SentenceRelation *relation = 0;
        for (Iterator finish = start; finish != end; ++finish) {
          hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
          // Now hash is of [start, finish].
          relation = &table_[hash];
          AppendSentence(relation->substring, sentence_id);
          if (start == begin) AppendSentence(relation->left, sentence_id);
        }
        // relation now refers to [start, end): it reaches the end of the
        // phrase, and when start == begin it is the whole phrase.
        AppendSentence(relation->right, sentence_id);
        if (start == begin) AppendSentence(relation->phrase, sentence_id);
      }
    }

  private:
    // Append sentence_id unless it is already the last element.  Because ids
    // arrive in non-decreasing order this keeps each vector free of duplicates.
    void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
      if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
    }

    Table table_;
};
// Read a file with one sentence per line containing tab-delimited phrases of
// space-separated words.
// NOTE(review): the definition is not in this header.  Presumably each phrase
// is added to out and the return value is the number of sentences (lines)
// read -- confirm against the implementation.
unsigned int ReadMultiple(std::istream &in, Substrings &out);
namespace detail {
// Token that terminates hashing: MakeHashes stops at the first word equal to it.
// Declared here and defined elsewhere.
// NOTE(review): presumably the "</s>" end-of-sentence tag -- confirm at the definition.
extern const StringPiece kEndSentence;
// Fill hashes with one hash per word of the n-gram [i, end), discarding
// whatever hashes held before.
//   - If the first token starts with '<' and ends with '>' it is treated as a
//     tag and skipped.
//   - Hashing stops at end or at the first token equal to kEndSentence,
//     whichever comes first.
// The iterator's value type must provide data() and size() and be comparable
// with kEndSentence.
template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
  hashes.clear();
  if (i == end) return;
  // TODO: check strict phrase boundaries after <s> and before </s>.  For now, just skip tags.
  // An empty token cannot be a tag.  Check for it first: with size() == 0 the
  // index size() - 1 would wrap around and read far out of bounds.
  if ((i->size() != 0) && (i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
    ++i;
  }
  for (; i != end && (*i != kEndSentence); ++i) {
    hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
  }
}
// Forward declarations; the definitions are not in this header.
// ConditionCommon holds vectors of both and returns a Vertex reference from MakeGraph.
class Vertex;
class Arc;
// Base class holding the state shared by the Union and Multiple conditions.
// Every member is protected or private, so it is only usable through a derived
// class.  The constructors, destructor and MakeGraph are declared here and
// defined elsewhere.
class ConditionCommon {
protected:
// Construct from the phrase table to filter against.
// NOTE(review): presumably bound to substrings_, in which case substrings must
// outlive this object -- confirm at the definition.
ConditionCommon(const Substrings &substrings);
// User-declared copy constructor.
// NOTE(review): why the implicit copy is not used is not visible here -- check
// the definition.
ConditionCommon(const ConditionCommon &from);
// Protected and non-virtual: code outside the hierarchy cannot destroy an
// object through a ConditionCommon pointer.
~ConditionCommon();
// NOTE(review): presumably builds the Vertex/Arc graph for the current hashes_
// and returns its entry vertex -- confirm at the definition.
detail::Vertex &MakeGraph();
// Temporaries in PassNGram and Evaluate to avoid reallocation.
std::vector<Hash> hashes_;
private:
// Storage for graph vertices and arcs.
// NOTE(review): presumably populated by MakeGraph -- confirm at the definition.
std::vector<detail::Vertex> vertices_;
std::vector<detail::Arc> arcs_;
// The phrase table, held by reference (not owned).
const Substrings &substrings_;
};
} // namespace detail
// Condition that gives a single yes/no answer per n-gram.
class Union : public detail::ConditionCommon {
  public:
    explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}

    // Hash the words of [begin, end) into hashes_ and report whether the
    // n-gram passes.  An n-gram that yields no hashes passes unconditionally;
    // otherwise the decision is delegated to Evaluate().
    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
      detail::MakeHashes(begin, end, hashes_);
      if (hashes_.empty()) {
        return true;
      }
      return Evaluate();
    }

  private:
    // Defined elsewhere; called by PassNGram once hashes_ has been filled.
    bool Evaluate();
};
class Multiple : public detail::ConditionCommon {
public:
explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
detail::MakeHashes(begin, end, hashes_);
if (hashes_.empty()) {
output.AddNGram(line);
} else {
Evaluate(line, output);
}
}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
}
void Flush() const {}
private:
template <class Output> void Evaluate(const StringPiece &line, Output &output);
};
} // namespace phrase
} // namespace lm
#endif // LM_FILTER_PHRASE_H
|