|
#ifndef LM_TRIE_H |
|
#define LM_TRIE_H |
|
|
|
#include "weights.hh" |
|
#include "word_index.hh" |
|
#include "../util/bit_packing.hh" |
|
|
|
#include <cstddef> |
|
|
|
#include <stdint.h> |
|
|
|
namespace lm { |
|
namespace ngram { |
|
struct Config; |
|
namespace trie { |
|
|
|
// A pair of 64-bit positions delimiting a run of entries.
// NOTE(review): Unigram::Find fills this from two consecutive `next` values,
// so it is presumably a half-open interval [begin, end) -- confirm with callers.
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};
|
|
|
|
|
struct UnigramValue { |
|
ProbBackoff weights; |
|
uint64_t next; |
|
uint64_t Next() const { return next; } |
|
}; |
|
|
|
class UnigramPointer { |
|
public: |
|
explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {} |
|
|
|
UnigramPointer() : to_(NULL) {} |
|
|
|
bool Found() const { return to_ != NULL; } |
|
|
|
float Prob() const { return to_->prob; } |
|
float Backoff() const { return to_->backoff; } |
|
float Rest() const { return Prob(); } |
|
|
|
private: |
|
const ProbBackoff *to_; |
|
}; |
|
|
|
class Unigram { |
|
public: |
|
Unigram() {} |
|
|
|
void Init(void *start) { |
|
unigram_ = static_cast<UnigramValue*>(start); |
|
} |
|
|
|
static uint64_t Size(uint64_t count) { |
|
|
|
return (count + 2) * sizeof(UnigramValue); |
|
} |
|
|
|
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; } |
|
|
|
ProbBackoff &Unknown() { return unigram_[0].weights; } |
|
|
|
UnigramValue *Raw() { |
|
return unigram_; |
|
} |
|
|
|
UnigramPointer Find(WordIndex word, NodeRange &next) const { |
|
UnigramValue *val = unigram_ + word; |
|
next.begin = val->next; |
|
next.end = (val+1)->next; |
|
return UnigramPointer(val->weights); |
|
} |
|
|
|
private: |
|
UnigramValue *unigram_; |
|
}; |
|
|
|
// Common state for the bit-packed arrays used by the derived classes in this
// header (BitPackedMiddle, BitPackedLongest).
class BitPacked {
  public:
    // Zero/NULL-initialise every member so that a default-constructed object
    // has a defined state (e.g. InsertIndex() == 0) until BaseInit() runs,
    // instead of holding indeterminate values.
    BitPacked()
      : word_bits_(0),
        total_bits_(0),
        word_mask_(0),
        base_(NULL),
        insert_index_(0),
        max_vocab_(0) {}

    // Current value of the insertion counter.
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    // Defined out of line.
    // NOTE(review): presumably the number of bytes needed for `entries`
    // entries given the vocabulary bound and `remaining_bits` extra bits per
    // entry -- confirm against the definition.
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Defined out of line.
    // NOTE(review): presumably binds `base` as storage and derives the bit
    // widths below -- confirm against the definition.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    // BitPackedMiddle::ReadEntry skips this many bits from the start of an
    // entry to get past its word field.
    uint8_t word_bits_;
    // Stride between consecutive entries, in bits (ReadEntry multiplies the
    // entry index by this).
    uint8_t total_bits_;
    // NOTE(review): presumably the mask selecting word_bits_ low bits -- confirm.
    uint64_t word_mask_;

    // Start of the packed storage.
    uint8_t *base_;

    uint64_t insert_index_, max_vocab_;
};
|
|
|
template <class Bhiksha> class BitPackedMiddle : public BitPacked { |
|
public: |
|
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); |
|
|
|
|
|
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); |
|
|
|
util::BitAddress Insert(WordIndex word); |
|
|
|
void FinishedLoading(uint64_t next_end, const Config &config); |
|
|
|
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; |
|
|
|
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { |
|
uint64_t addr = pointer * total_bits_; |
|
addr += word_bits_; |
|
bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range); |
|
return util::BitAddress(base_, addr); |
|
} |
|
|
|
private: |
|
uint8_t quant_bits_; |
|
Bhiksha bhiksha_; |
|
|
|
const BitPacked *next_source_; |
|
}; |
|
|
|
class BitPackedLongest : public BitPacked { |
|
public: |
|
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { |
|
return BaseSize(entries, max_vocab, quant_bits); |
|
} |
|
|
|
BitPackedLongest() {} |
|
|
|
void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) { |
|
BaseInit(base, max_vocab, quant_bits); |
|
} |
|
|
|
util::BitAddress Insert(WordIndex word); |
|
|
|
util::BitAddress Find(WordIndex word, const NodeRange &node) const; |
|
}; |
|
|
|
} |
|
} |
|
} |
|
|
|
#endif |
|
|