#ifndef LM_FILTER_VOCAB_H #define LM_FILTER_VOCAB_H // Vocabulary-based filters for language models. #include "../../util/multi_intersection.hh" #include "../../util/string_piece.hh" #include "../../util/string_piece_hash.hh" #include "../../util/tokenize_piece.hh" #include #include #include #include #include #include namespace lm { namespace vocab { void ReadSingle(std::istream &in, boost::unordered_set &out); // Read one sentence vocabulary per line. Return the number of sentences. unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out); /* Is this a special tag like or ? This actually includes anything * surrounded with < and >, which most tokenizers separate for real words, so * this should not catch real words as it looks at a single token. */ inline bool IsTag(const StringPiece &value) { // The parser should never give an empty string. assert(!value.empty()); return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>'); } class Single { public: typedef boost::unordered_set Words; explicit Single(const Words &vocab) : vocab_(vocab) {} template bool PassNGram(const Iterator &begin, const Iterator &end) { for (Iterator i = begin; i != end; ++i) { if (IsTag(*i)) continue; if (FindStringPiece(vocab_, *i) == vocab_.end()) return false; } return true; } private: const Words &vocab_; }; class Union { public: typedef boost::unordered_map > Words; explicit Union(const Words &vocabs) : vocabs_(vocabs) {} template bool PassNGram(const Iterator &begin, const Iterator &end) { sets_.clear(); for (Iterator i(begin); i != end; ++i) { if (IsTag(*i)) continue; Words::const_iterator found(FindStringPiece(vocabs_, *i)); if (vocabs_.end() == found) return false; sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); } return (sets_.empty() || util::FirstIntersection(sets_)); } private: const Words &vocabs_; std::vector > sets_; }; class Multiple { public: typedef boost::unordered_map > Words; Multiple(const Words &vocabs) : vocabs_(vocabs) {} private: // Callback from AllIntersection that does AddNGram. template class Callback { public: Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {} void operator()(unsigned int index) { out_.SingleAddNGram(index, line_); } private: Output &out_; const StringPiece &line_; }; public: template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { sets_.clear(); for (Iterator i(begin); i != end; ++i) { if (IsTag(*i)) continue; Words::const_iterator found(FindStringPiece(vocabs_, *i)); if (vocabs_.end() == found) return; sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); } if (sets_.empty()) { output.AddNGram(line); return; } Callback cb(output, line); util::AllIntersection(sets_, cb); } template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); } void Flush() const {} private: const Words &vocabs_; std::vector > sets_; }; } // namespace vocab } // namespace lm #endif // LM_FILTER_VOCAB_H