#ifndef LM_FILTER_PHRASE_H #define LM_FILTER_PHRASE_H #include "../../util/murmur_hash.hh" #include "../../util/string_piece.hh" #include "../../util/tokenize_piece.hh" #include #include #include #define LM_FILTER_PHRASE_METHOD(caps, lower) \ bool Find##caps(Hash key, const std::vector *&out) const {\ Table::const_iterator i(table_.find(key));\ if (i==table_.end()) return false; \ out = &i->second.lower; \ return true; \ } namespace lm { namespace phrase { typedef uint64_t Hash; class Substrings { private: /* This is the value in a hash table where the key is a string. It indicates * four sets of sentences: * substring is sentences with a phrase containing the key as a substring. * left is sentencess with a phrase that begins with the key (left aligned). * right is sentences with a phrase that ends with the key (right aligned). * phrase is sentences where the key is a phrase. * Each set is encoded as a vector of sentence ids in increasing order. */ struct SentenceRelation { std::vector substring, left, right, phrase; }; /* Most of the CPU is hash table lookups, so let's not complicate it with * vector equality comparisons. If a collision happens, the SentenceRelation * structure will contain the union of sentence ids over the colliding strings. * In that case, the filter will be slightly more permissive. * The key here is the same as boost's hash of std::vector. */ typedef boost::unordered_map Table; public: Substrings() {} /* If the string isn't a substring of any phrase, return NULL. Otherwise, * return a pointer to std::vector listing sentences with * matching phrases. This set may be empty for Left, Right, or Phrase. * Example: const std::vector *FindSubstring(Hash key) */ LM_FILTER_PHRASE_METHOD(Substring, substring) LM_FILTER_PHRASE_METHOD(Left, left) LM_FILTER_PHRASE_METHOD(Right, right) LM_FILTER_PHRASE_METHOD(Phrase, phrase) #pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization // sentence_id must be non-decreasing. Iterators are over words in the phrase. template void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { // Iterate over all substrings. for (Iterator start = begin; start != end; ++start) { Hash hash = 0; SentenceRelation *relation; for (Iterator finish = start; finish != end; ++finish) { hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish); // Now hash is of [start, finish]. relation = &table_[hash]; AppendSentence(relation->substring, sentence_id); if (start == begin) AppendSentence(relation->left, sentence_id); } AppendSentence(relation->right, sentence_id); if (start == begin) AppendSentence(relation->phrase, sentence_id); } } private: void AppendSentence(std::vector &vec, unsigned int sentence_id) { if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id); } Table table_; }; // Read a file with one sentence per line containing tab-delimited phrases of // space-separated words. unsigned int ReadMultiple(std::istream &in, Substrings &out); namespace detail { extern const StringPiece kEndSentence; template void MakeHashes(Iterator i, const Iterator &end, std::vector &hashes) { hashes.clear(); if (i == end) return; // TODO: check strict phrase boundaries after and before . For now, just skip tags. if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) { ++i; } for (; i != end && (*i != kEndSentence); ++i) { hashes.push_back(util::MurmurHashNative(i->data(), i->size())); } } class Vertex; class Arc; class ConditionCommon { protected: ConditionCommon(const Substrings &substrings); ConditionCommon(const ConditionCommon &from); ~ConditionCommon(); detail::Vertex &MakeGraph(); // Temporaries in PassNGram and Evaluate to avoid reallocation. std::vector hashes_; private: std::vector vertices_; std::vector arcs_; const Substrings &substrings_; }; } // namespace detail class Union : public detail::ConditionCommon { public: explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template bool PassNGram(const Iterator &begin, const Iterator &end) { detail::MakeHashes(begin, end, hashes_); return hashes_.empty() || Evaluate(); } private: bool Evaluate(); }; class Multiple : public detail::ConditionCommon { public: explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {} template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { detail::MakeHashes(begin, end, hashes_); if (hashes_.empty()) { output.AddNGram(line); } else { Evaluate(line, output); } } template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); } void Flush() const {} private: template void Evaluate(const StringPiece &line, Output &output); }; } // namespace phrase } // namespace lm #endif // LM_FILTER_PHRASE_H