|
#ifndef LM_VOCAB_H |
|
#define LM_VOCAB_H |
|
|
|
#include "lm/enumerate_vocab.hh" |
|
#include "lm/lm_exception.hh" |
|
#include "lm/virtual_interface.hh" |
|
#include "util/fake_ofstream.hh" |
|
#include "util/murmur_hash.hh" |
|
#include "util/pool.hh" |
|
#include "util/probing_hash_table.hh" |
|
#include "util/sorted_uniform.hh" |
|
#include "util/string_piece.hh" |
|
|
|
#include <limits> |
|
#include <string> |
|
#include <vector> |
|
|
|
namespace lm { |
|
struct ProbBackoff; |
|
class EnumerateVocab; |
|
|
|
namespace ngram { |
|
struct Config; |
|
|
|
namespace detail { |
|
uint64_t HashForVocab(const char *str, std::size_t len); |
|
inline uint64_t HashForVocab(const StringPiece &str) { |
|
return HashForVocab(str.data(), str.length()); |
|
} |
|
struct ProbingVocabularyHeader; |
|
} |
|
|
|
class WriteWordsWrapper : public EnumerateVocab { |
|
public: |
|
WriteWordsWrapper(EnumerateVocab *inner); |
|
|
|
~WriteWordsWrapper(); |
|
|
|
void Add(WordIndex index, const StringPiece &str); |
|
|
|
const std::string &Buffer() const { return buffer_; } |
|
|
|
private: |
|
EnumerateVocab *inner_; |
|
|
|
std::string buffer_; |
|
}; |
|
|
|
|
|
class SortedVocabulary : public base::Vocabulary { |
|
public: |
|
SortedVocabulary(); |
|
|
|
WordIndex Index(const StringPiece &str) const { |
|
const uint64_t *found; |
|
if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( |
|
util::IdentityAccessor<uint64_t>(), |
|
begin_ - 1, 0, |
|
end_, std::numeric_limits<uint64_t>::max(), |
|
detail::HashForVocab(str), found)) { |
|
return found - begin_ + 1; |
|
} else { |
|
return 0; |
|
} |
|
} |
|
|
|
|
|
static uint64_t Size(uint64_t entries, const Config &config); |
|
|
|
|
|
WordIndex Bound() const { return bound_; } |
|
|
|
|
|
void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); |
|
|
|
void Relocate(void *new_start); |
|
|
|
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); |
|
|
|
WordIndex Insert(const StringPiece &str); |
|
|
|
|
|
void FinishedLoading(ProbBackoff *reorder_vocab); |
|
|
|
|
|
std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } |
|
|
|
bool SawUnk() const { return saw_unk_; } |
|
|
|
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); |
|
|
|
private: |
|
uint64_t *begin_, *end_; |
|
|
|
WordIndex bound_; |
|
|
|
bool saw_unk_; |
|
|
|
EnumerateVocab *enumerate_; |
|
|
|
|
|
util::Pool string_backing_; |
|
|
|
std::vector<StringPiece> strings_to_enumerate_; |
|
}; |
|
|
|
#pragma pack(push) |
|
#pragma pack(4) |
|
struct ProbingVocabularyEntry { |
|
uint64_t key; |
|
WordIndex value; |
|
|
|
typedef uint64_t Key; |
|
uint64_t GetKey() const { return key; } |
|
void SetKey(uint64_t to) { key = to; } |
|
|
|
static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { |
|
ProbingVocabularyEntry ret; |
|
ret.key = key; |
|
ret.value = value; |
|
return ret; |
|
} |
|
}; |
|
#pragma pack(pop) |
|
|
|
|
|
class ProbingVocabulary : public base::Vocabulary { |
|
public: |
|
ProbingVocabulary(); |
|
|
|
WordIndex Index(const StringPiece &str) const { |
|
Lookup::ConstIterator i; |
|
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; |
|
} |
|
|
|
static uint64_t Size(uint64_t entries, float probing_multiplier); |
|
|
|
static uint64_t Size(uint64_t entries, const Config &config); |
|
|
|
|
|
WordIndex Bound() const { return bound_; } |
|
|
|
|
|
void SetupMemory(void *start, std::size_t allocated); |
|
void SetupMemory(void *start, std::size_t allocated, std::size_t , const Config &) { |
|
SetupMemory(start, allocated); |
|
} |
|
|
|
void Relocate(void *new_start); |
|
|
|
void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); |
|
|
|
WordIndex Insert(const StringPiece &str); |
|
|
|
template <class Weights> void FinishedLoading(Weights * ) { |
|
FinishedLoading(); |
|
} |
|
void FinishedLoading(); |
|
|
|
std::size_t UnkCountChangePadding() const { return 0; } |
|
|
|
bool SawUnk() const { return saw_unk_; } |
|
|
|
void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); |
|
|
|
private: |
|
typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup; |
|
|
|
Lookup lookup_; |
|
|
|
WordIndex bound_; |
|
|
|
bool saw_unk_; |
|
|
|
EnumerateVocab *enumerate_; |
|
|
|
detail::ProbingVocabularyHeader *header_; |
|
}; |
|
|
|
// Handlers invoked by CheckSpecials when a special word is absent from the
// vocabulary. Both are defined out of line; their exception specifications
// list SpecialWordMissingException as the only exception they may emit.
// NOTE(review): whether they throw or merely warn presumably depends on
// `config` -- confirm against the definitions.

// Called when the unknown-word token was not seen.
void MissingUnknown(const Config &config)
    throw(SpecialWordMissingException);

// Called when a sentence boundary marker is missing; `str` names the marker.
void MissingSentenceMarker(const Config &config, const char *str)
    throw(SpecialWordMissingException);
|
|
|
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { |
|
if (!vocab.SawUnk()) MissingUnknown(config); |
|
if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>"); |
|
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); |
|
} |
|
|
|
class WriteUniqueWords { |
|
public: |
|
explicit WriteUniqueWords(int fd) : word_list_(fd) {} |
|
|
|
void operator()(const StringPiece &word) { |
|
word_list_ << word << '\0'; |
|
} |
|
|
|
private: |
|
util::FakeOFStream word_list_; |
|
}; |
|
|
|
class NoOpUniqueWords { |
|
public: |
|
NoOpUniqueWords() {} |
|
void operator()(const StringPiece &word) {} |
|
}; |
|
|
|
template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { |
|
public: |
|
static std::size_t MemUsage(WordIndex content) { |
|
return Lookup::MemUsage(content > 2 ? content : 2); |
|
} |
|
|
|
|
|
template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) |
|
: lookup_(initial_size), new_word_(new_word_construct) { |
|
FindOrInsert("<unk>"); |
|
FindOrInsert("<s>"); |
|
FindOrInsert("</s>"); |
|
} |
|
|
|
WordIndex Index(const StringPiece &str) const { |
|
Lookup::ConstIterator i; |
|
return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; |
|
} |
|
|
|
WordIndex FindOrInsert(const StringPiece &word) { |
|
ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); |
|
Lookup::MutableIterator it; |
|
if (!lookup_.FindOrInsert(entry, it)) { |
|
new_word_(word); |
|
UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); |
|
} |
|
return it->value; |
|
} |
|
|
|
WordIndex Size() const { return lookup_.Size(); } |
|
|
|
private: |
|
typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; |
|
|
|
Lookup lookup_; |
|
|
|
NewWordAction new_word_; |
|
}; |
|
|
|
} |
|
} |
|
|
|
#endif |
|
|