#include "corpus_count.hh" #include "payload.hh" #include "../common/ngram.hh" #include "../lm_exception.hh" #include "../vocab.hh" #include "../word_index.hh" #include "../../util/file_stream.hh" #include "../../util/file.hh" #include "../../util/file_piece.hh" #include "../../util/murmur_hash.hh" #include "../../util/probing_hash_table.hh" #include "../../util/scoped.hh" #include "../../util/stream/chain.hh" #include "../../util/tokenize_piece.hh" #include #include namespace lm { namespace builder { namespace { class DedupeHash : public std::unary_function { public: explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} std::size_t operator()(const WordIndex *start) const { return util::MurmurHashNative(start, size_); } private: const std::size_t size_; }; class DedupeEquals : public std::binary_function { public: explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {} bool operator()(const WordIndex *first, const WordIndex *second) const { return !memcmp(first, second, size_); } private: const std::size_t size_; }; struct DedupeEntry { typedef WordIndex *Key; Key GetKey() const { return key; } void SetKey(WordIndex *to) { key = to; } Key key; static DedupeEntry Construct(WordIndex *at) { DedupeEntry ret; ret.key = at; return ret; } }; // TODO: don't have this here, should be with probing hash table defaults? const float kProbingMultiplier = 1.5; typedef util::ProbingHashTable Dedupe; class Writer { public: Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) : block_(position), gram_(block_->Get(), order), dedupe_invalid_(order, std::numeric_limits::max()), dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), buffer_(new WordIndex[order - 1]), block_size_(position.GetChain().BlockSize()) { dedupe_.Clear(); assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); if (order == 1) { // Add special words. AdjustCounts is responsible if order != 1. AddUnigramWord(kUNK); AddUnigramWord(kBOS); } } ~Writer() { block_->SetValidSize(reinterpret_cast(gram_.begin()) - static_cast(block_->Get())); (++block_).Poison(); } // Write context with a bunch of void StartSentence() { for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) { *i = kBOS; } } void Append(WordIndex word) { *(gram_.end() - 1) = word; Dedupe::MutableIterator at; bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); if (found) { // Already present. NGram already(at->key, gram_.Order()); ++(already.Value().count); // Shift left by one. memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); return; } // Complete the write. gram_.Value().count = 1; // Prepare the next n-gram. if (reinterpret_cast(gram_.begin()) + gram_.TotalSize() != static_cast(block_->Get()) + block_size_) { NGram last(gram_); gram_.NextInMemory(); std::copy(last.begin() + 1, last.end(), gram_.begin()); return; } // Block end. Need to store the context in a temporary buffer. 
      std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
      dedupe_.Clear();
      block_->SetValidSize(block_size_);
      gram_.ReBase((++block_)->Get());
      std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
    }

  private:
    void AddUnigramWord(WordIndex index) {
      *gram_.begin() = index;
      gram_.Value().count = 0;
      gram_.NextInMemory();
      if (gram_.Base() == static_cast<uint8_t *>(block_->Get()) + block_size_) {
        block_->SetValidSize(block_size_);
        gram_.ReBase((++block_)->Get());
      }
    }

    util::stream::Link block_;

    NGram<BuildingPayload> gram_;

    // This is the memory behind the invalid value in dedupe_.
    std::vector<WordIndex> dedupe_invalid_;
    // Hash table combiner implementation.
    Dedupe dedupe_;

    // Small buffer to hold existing ngrams when shifting across a block boundary.
    boost::scoped_array<WordIndex> buffer_;

    const std::size_t block_size_;
};

} // namespace

float CorpusCount::DedupeMultiplier(std::size_t order) {
  return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram<BuildingPayload>::TotalSize(order));
}

std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
  return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}

CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab,
                         uint64_t &token_count, WordIndex &type_count,
                         std::vector<bool> &prune_words, const std::string &prune_vocab_filename,
                         std::size_t entries_per_block, WarningAction disallowed_symbol)
  : from_(from), vocab_write_(vocab_write), dynamic_vocab_(dynamic_vocab),
    token_count_(token_count), type_count_(type_count),
    prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename),
    dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
    dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
    disallowed_symbol_action_(disallowed_symbol) {
}

namespace {
void ComplainDisallowed(StringPiece word, WarningAction &action) {
  switch (action) {
    case SILENT:
      return;
    case COMPLAIN:
      std::cerr << "Warning: " << word << " appears in the input.  All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
      action = SILENT;
      return;
    case THROW_UP:
      UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus.  I plan to support models containing <unk> in the future.  Pass --skip_symbols to convert these symbols to whitespace.");
  }
}

// Vocab ids are given in a precompiled hash table.
class VocabGiven {
  public:
    explicit VocabGiven(int fd) {
      util::MapRead(util::POPULATE_OR_READ, fd, 0, util::CheckOverflow(util::SizeOrThrow(fd)), table_backing_);
      // Leave space for header with size.
      table_ = Table(static_cast<char *>(table_backing_.get()) + sizeof(uint64_t), table_backing_.size() - sizeof(uint64_t));
      bos_ = FindOrInsert("<s>");
      eos_ = FindOrInsert("</s>");
    }

    WordIndex FindOrInsert(const StringPiece &word) const {
      Table::ConstIterator it;
      if (table_.Find(util::MurmurHash64A(word.data(), word.size()), it)) {
        return it->value;
      } else {
        return 0; // <unk>.
      }
    }
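    // With a precompiled vocabulary, lookups never insert: a miss falls
    // through to id 0 (<unk>) above, so Index is just a synonym for
    // FindOrInsert.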
    WordIndex Index(const StringPiece &word) const {
      return FindOrInsert(word);
    }

    WordIndex Size() const {
      return *static_cast<const uint64_t *>(table_backing_.get());
    }

    bool IsSpecial(WordIndex word) const {
      return word == 0 || word == bos_ || word == eos_;
    }

  private:
    util::scoped_memory table_backing_;

    typedef util::ProbingHashTable<ngram::ProbingVocabularyEntry, util::IdentityHash> Table;
    Table table_;

    WordIndex bos_, eos_;
};
} // namespace

void CorpusCount::Run(const util::stream::ChainPosition &position) {
  if (dynamic_vocab_) {
    ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
    RunWithVocab(position, vocab);
  } else {
    VocabGiven vocab(vocab_write_);
    RunWithVocab(position, vocab);
  }
}

template <class Vocab> void CorpusCount::RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab) {
  token_count_ = 0;
  type_count_ = 0;
  const WordIndex end_sentence = vocab.FindOrInsert("</s>");
  Writer writer(NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
  uint64_t count = 0;
  bool delimiters[256];
  util::BoolCharacter::Build("\0\t\n\r ", delimiters);
  StringPiece w;
  while (true) {
    writer.StartSentence();
    while (from_.ReadWordSameLine(w, delimiters)) {
      WordIndex word = vocab.FindOrInsert(w);
      if (UTIL_UNLIKELY(vocab.IsSpecial(word))) {
        ComplainDisallowed(w, disallowed_symbol_action_);
        continue;
      }
      writer.Append(word);
      ++count;
    }
    if (!from_.ReadLineOrEOF(w)) break;
    writer.Append(end_sentence);
  }
  token_count_ = count;
  type_count_ = vocab.Size();

  // Create the list of unigrams that are supposed to be pruned.
  if (!prune_vocab_filename_.empty()) {
    try {
      util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str());

      prune_words_.resize(vocab.Size(), true);
      try {
        while (true) {
          StringPiece word(prune_vocab_file.ReadDelimited(delimiters));
          prune_words_[vocab.Index(word)] = false;
        }
      } catch (const util::EndOfFileException &e) {}

      // Never prune <unk>, <s>, </s>.
      prune_words_[kUNK] = false;
      prune_words_[kBOS] = false;
      prune_words_[kEOS] = false;
    } catch (const util::Exception &e) {
      std::cerr << e.what() << std::endl;
      abort();
    }
  }
}

} // namespace builder
} // namespace lm