File size: 3,594 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
#ifndef LM_TRIE_H
#define LM_TRIE_H
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/bit_packing.hh"
#include <cstddef>
#include <stdint.h>
namespace lm {
namespace ngram {
struct Config;
namespace trie {
// Pair of 64-bit bounds filled in by the Find()/ReadEntry() lookups declared
// in this header.
// NOTE(review): presumably a half-open [begin, end) span of entries in the
// next-order table -- confirm against the implementation file.
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};
// TODO: if the number of unigrams is a concern, also bit pack these records.
//
// One record of the unigram array: the word's weights followed by a 64-bit
// "next" value.  Unigram::Find() copies the next field of a word's record and
// of the record immediately after it into a NodeRange.
struct UnigramValue {
  ProbBackoff weights;
  uint64_t next;

  // Read accessor for the stored next value.
  uint64_t Next() const {
    return next;
  }
};
// Nullable, non-owning handle to a ProbBackoff record.  Only the address of
// the record is kept, so the record must outlive the handle.
class UnigramPointer {
  public:
    // Empty handle: Found() reports false.
    UnigramPointer() : to_(NULL) {}

    // Handle referring to an existing record.
    explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}

    // True when the handle refers to a record.
    bool Found() const {
      return to_ != NULL;
    }

    // The accessors below dereference the handle without checking it; call
    // them only when Found() is true.
    float Prob() const {
      return to_->prob;
    }

    float Backoff() const {
      return to_->backoff;
    }

    // For unigrams the "rest" value is simply the probability.
    float Rest() const {
      return to_->prob;
    }

  private:
    const ProbBackoff *to_;
};
class Unigram {
public:
Unigram() {}
void Init(void *start) {
unigram_ = static_cast<UnigramValue*>(start);
}
static uint64_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
ProbBackoff &Unknown() { return unigram_[0].weights; }
UnigramValue *Raw() {
return unigram_;
}
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
next.end = (val+1)->next;
return UnigramPointer(val->weights);
}
private:
UnigramValue *unigram_;
};
// Common state for the bit-packed tables declared in this header.  Derived
// classes configure it through BaseInit(), which is defined outside this
// header.
class BitPacked {
  public:
    // Zero-initialize every member.  Previously the members were left
    // indeterminate, so calling the public InsertIndex() on an object that
    // had not yet been through BaseInit() read an uninitialized value.
    BitPacked()
      : word_bits_(0),
        total_bits_(0),
        word_mask_(0),
        base_(NULL),
        insert_index_(0),
        max_vocab_(0) {}

    // Current value of the insertion counter (0 for a fresh object).
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    // Storage-size computation shared by derived classes' Size() functions.
    // NOTE(review): presumably returns a byte count -- confirm in the .cc.
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Bind to caller-provided storage at `base` and set up the members below.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    // Bits at the start of each entry; BitPackedMiddle::ReadEntry skips this
    // many bits to reach the rest of the entry.
    uint8_t word_bits_;
    // Bits per entry; entry i starts at bit offset i * total_bits_.
    uint8_t total_bits_;
    // NOTE(review): presumably the mask selecting word_bits_ low bits -- confirm.
    uint64_t word_mask_;

    // Start of the packed storage (not owned here as far as this header shows).
    uint8_t *base_;

    uint64_t insert_index_, max_vocab_;
};
// Bit-packed table whose entries also carry a pointer into the next table.
// Reading that pointer is delegated to the Bhiksha policy type.  Only
// ReadEntry() is defined here; the remaining members are defined elsewhere.
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
  public:
    static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);

    // next_source need not be initialized.
    BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);

    util::BitAddress Insert(WordIndex word);

    void FinishedLoading(uint64_t next_end, const Config &config);

    util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;

    // Read the entry at index `pointer` directly, without searching.
    // Fills `range` via the Bhiksha policy and returns the bit address just
    // past the entry's leading word_bits_ bits.
    util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
      // Bit offset of the entry's data that follows the word field.
      const uint64_t after_word = pointer * total_bits_ + word_bits_;
      // The next-pointer sits quant_bits_ further into the entry.
      bhiksha_.ReadNext(base_, after_word + quant_bits_, pointer, total_bits_, range);
      return util::BitAddress(base_, after_word);
    }

  private:
    uint8_t quant_bits_;
    Bhiksha bhiksha_;

    const BitPacked *next_source_;
};
class BitPackedLongest : public BitPacked {
public:
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
}
BitPackedLongest() {}
void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
BaseInit(base, max_vocab, quant_bits);
}
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_TRIE_H
|