|
#include "initial_probabilities.hh"

#include "discount.hh"
#include "hash_gamma.hh"
#include "payload.hh"
#include "../common/special.hh"
#include "../common/ngram_stream.hh"
#include "../../util/murmur_hash.hh"
#include "../../util/file.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/io.hh"
#include "../../util/stream/stream.hh"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
|
|
|
namespace lm { namespace builder { |
|
|
|
namespace { |
|
// Per-context record written by AddRight (one per distinct n-gram context)
// and read back by MergeRight.
// NOTE: field order is load-bearing. OnlyGamma reinterprets a block of these
// as a flat float array and keeps only the even-indexed floats (gamma).
struct BufferEntry {

  // Backoff weight for the context:
  // (sum of discounts over kept n-grams + pruned mass) / denominator.
  float gamma;

  // Sum of UnmarkedCount() over every n-gram sharing this context.
  float denominator;

};
|
|
|
// BufferEntry extended with a hash of the context words. The gamma chain uses
// this entry size instead of BufferEntry whenever pruning is active for that
// order (see InitialProbabilities).
struct HashBufferEntry : public BufferEntry {

  // util::MurmurHashNative of the context word ids, filled in by AddRight.
  uint64_t hash_value;

};
|
|
|
|
|
|
|
|
|
class PruneNGramStream { |
|
public: |
|
PruneNGramStream(const util::stream::ChainPosition &position, const SpecialVocab &specials) : |
|
current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())), |
|
dest_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())), |
|
currentCount_(0), |
|
block_(position), |
|
specials_(specials) |
|
{ |
|
StartBlock(); |
|
} |
|
|
|
NGram<BuildingPayload> &operator*() { return current_; } |
|
NGram<BuildingPayload> *operator->() { return ¤t_; } |
|
|
|
operator bool() const { |
|
return block_; |
|
} |
|
|
|
PruneNGramStream &operator++() { |
|
assert(block_); |
|
if(UTIL_UNLIKELY(current_.Order() == 1 && specials_.IsSpecial(*current_.begin()))) |
|
dest_.NextInMemory(); |
|
else if(currentCount_ > 0) { |
|
if(dest_.Base() < current_.Base()) { |
|
memcpy(dest_.Base(), current_.Base(), current_.TotalSize()); |
|
} |
|
dest_.NextInMemory(); |
|
} |
|
|
|
current_.NextInMemory(); |
|
|
|
uint8_t *block_base = static_cast<uint8_t*>(block_->Get()); |
|
if (current_.Base() == block_base + block_->ValidSize()) { |
|
block_->SetValidSize(dest_.Base() - block_base); |
|
++block_; |
|
StartBlock(); |
|
if (block_) { |
|
currentCount_ = current_.Value().CutoffCount(); |
|
} |
|
} else { |
|
currentCount_ = current_.Value().CutoffCount(); |
|
} |
|
|
|
return *this; |
|
} |
|
|
|
private: |
|
void StartBlock() { |
|
for (; ; ++block_) { |
|
if (!block_) return; |
|
if (block_->ValidSize()) break; |
|
} |
|
current_.ReBase(block_->Get()); |
|
currentCount_ = current_.Value().CutoffCount(); |
|
|
|
dest_.ReBase(block_->Get()); |
|
} |
|
|
|
NGram<BuildingPayload> current_; |
|
NGram<BuildingPayload> dest_; |
|
|
|
uint64_t currentCount_; |
|
|
|
util::stream::Link block_; |
|
|
|
const SpecialVocab specials_; |
|
}; |
|
|
|
|
|
class OnlyGamma { |
|
public: |
|
explicit OnlyGamma(bool pruning) : pruning_(pruning) {} |
|
|
|
void Run(const util::stream::ChainPosition &position) { |
|
for (util::stream::Link block_it(position); block_it; ++block_it) { |
|
if(pruning_) { |
|
const HashBufferEntry *in = static_cast<const HashBufferEntry*>(block_it->Get()); |
|
const HashBufferEntry *end = static_cast<const HashBufferEntry*>(block_it->ValidEnd()); |
|
|
|
|
|
|
|
HashGamma *out = static_cast<HashGamma*>(block_it->Get()); |
|
for (; in < end; out += 1, in += 1) { |
|
|
|
float gamma_buf = in->gamma; |
|
uint64_t hash_buf = in->hash_value; |
|
|
|
out->gamma = gamma_buf; |
|
out->hash_value = hash_buf; |
|
} |
|
block_it->SetValidSize((block_it->ValidSize() * sizeof(HashGamma)) / sizeof(HashBufferEntry)); |
|
} |
|
else { |
|
float *out = static_cast<float*>(block_it->Get()); |
|
const float *in = out; |
|
const float *end = static_cast<const float*>(block_it->ValidEnd()); |
|
for (out += 1, in += 2; in < end; out += 1, in += 2) { |
|
*out = *in; |
|
} |
|
block_it->SetValidSize(block_it->ValidSize() / 2); |
|
} |
|
} |
|
} |
|
|
|
private: |
|
bool pruning_; |
|
}; |
|
|
|
class AddRight { |
|
public: |
|
AddRight(const Discount &discount, const util::stream::ChainPosition &input, bool pruning) |
|
: discount_(discount), input_(input), pruning_(pruning) {} |
|
|
|
void Run(const util::stream::ChainPosition &output) { |
|
NGramStream<BuildingPayload> in(input_); |
|
util::stream::Stream out(output); |
|
|
|
std::vector<WordIndex> previous(in->Order() - 1); |
|
|
|
void *const previous_raw = previous.empty() ? NULL : static_cast<void*>(&previous[0]); |
|
const std::size_t size = sizeof(WordIndex) * previous.size(); |
|
|
|
for(; in; ++out) { |
|
memcpy(previous_raw, in->begin(), size); |
|
uint64_t denominator = 0; |
|
uint64_t normalizer = 0; |
|
|
|
uint64_t counts[4]; |
|
memset(counts, 0, sizeof(counts)); |
|
do { |
|
denominator += in->Value().UnmarkedCount(); |
|
|
|
|
|
|
|
normalizer += in->Value().UnmarkedCount() - in->Value().CutoffCount(); |
|
|
|
|
|
|
|
|
|
if(in->Value().CutoffCount() > 0) |
|
++counts[std::min(in->Value().CutoffCount(), static_cast<uint64_t>(3))]; |
|
|
|
} while (++in && !memcmp(previous_raw, in->begin(), size)); |
|
|
|
BufferEntry &entry = *reinterpret_cast<BufferEntry*>(out.Get()); |
|
entry.denominator = static_cast<float>(denominator); |
|
entry.gamma = 0.0; |
|
for (unsigned i = 1; i <= 3; ++i) { |
|
entry.gamma += discount_.Get(i) * static_cast<float>(counts[i]); |
|
} |
|
|
|
|
|
entry.gamma += normalizer; |
|
|
|
entry.gamma /= entry.denominator; |
|
|
|
if(pruning_) { |
|
|
|
|
|
static_cast<HashBufferEntry*>(&entry)->hash_value = util::MurmurHashNative(previous_raw, size); |
|
} |
|
} |
|
out.Poison(); |
|
} |
|
|
|
private: |
|
const Discount &discount_; |
|
const util::stream::ChainPosition input_; |
|
bool pruning_; |
|
}; |
|
|
|
class MergeRight { |
|
public: |
|
MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount, const SpecialVocab &specials) |
|
: interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount), specials_(specials) {} |
|
|
|
|
|
|
|
void Run(const util::stream::ChainPosition &primary) { |
|
util::stream::Stream summed(from_adder_); |
|
|
|
PruneNGramStream grams(primary, specials_); |
|
|
|
|
|
if (grams->Order() == 1) { |
|
BufferEntry sums(*static_cast<const BufferEntry*>(summed.Get())); |
|
|
|
assert(*grams->begin() == kUNK); |
|
float gamma_assign; |
|
if (interpolate_unigrams_) { |
|
|
|
gamma_assign = sums.gamma; |
|
grams->Value().uninterp.prob = 0.0; |
|
} else { |
|
|
|
gamma_assign = 0.0; |
|
grams->Value().uninterp.prob = sums.gamma; |
|
} |
|
grams->Value().uninterp.gamma = gamma_assign; |
|
|
|
for (++grams; *grams->begin() != specials_.BOS(); ++grams) { |
|
grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator; |
|
grams->Value().uninterp.gamma = gamma_assign; |
|
} |
|
|
|
|
|
|
|
|
|
assert(*grams->begin() == specials_.BOS()); |
|
grams->Value().uninterp.prob = 1.0; |
|
grams->Value().uninterp.gamma = 0.0; |
|
|
|
while (++grams) { |
|
grams->Value().uninterp.prob = discount_.Apply(grams->Value().count) / sums.denominator; |
|
grams->Value().uninterp.gamma = gamma_assign; |
|
} |
|
++summed; |
|
return; |
|
} |
|
|
|
std::vector<WordIndex> previous(grams->Order() - 1); |
|
const std::size_t size = sizeof(WordIndex) * previous.size(); |
|
for (; grams; ++summed) { |
|
memcpy(&previous[0], grams->begin(), size); |
|
const BufferEntry &sums = *static_cast<const BufferEntry*>(summed.Get()); |
|
|
|
do { |
|
BuildingPayload &pay = grams->Value(); |
|
pay.uninterp.prob = discount_.Apply(grams->Value().UnmarkedCount()) / sums.denominator; |
|
pay.uninterp.gamma = sums.gamma; |
|
} while (++grams && !memcmp(&previous[0], grams->begin(), size)); |
|
} |
|
} |
|
|
|
private: |
|
bool interpolate_unigrams_; |
|
util::stream::ChainPosition from_adder_; |
|
Discount discount_; |
|
const SpecialVocab specials_; |
|
}; |
|
|
|
} |
|
|
|
void InitialProbabilities( |
|
const InitialProbabilitiesConfig &config, |
|
const std::vector<Discount> &discounts, |
|
util::stream::Chains &primary, |
|
util::stream::Chains &second_in, |
|
util::stream::Chains &gamma_out, |
|
const std::vector<uint64_t> &prune_thresholds, |
|
bool prune_vocab, |
|
const SpecialVocab &specials) { |
|
for (size_t i = 0; i < primary.size(); ++i) { |
|
util::stream::ChainConfig gamma_config = config.adder_out; |
|
if(prune_vocab || prune_thresholds[i] > 0) |
|
gamma_config.entry_size = sizeof(HashBufferEntry); |
|
else |
|
gamma_config.entry_size = sizeof(BufferEntry); |
|
|
|
util::stream::ChainPosition second(second_in[i].Add()); |
|
second_in[i] >> util::stream::kRecycle; |
|
gamma_out.push_back(gamma_config); |
|
gamma_out[i] >> AddRight(discounts[i], second, prune_vocab || prune_thresholds[i] > 0); |
|
|
|
primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i], specials); |
|
|
|
|
|
if (i) gamma_out[i] >> OnlyGamma(prune_vocab || prune_thresholds[i] > 0); |
|
} |
|
} |
|
|
|
}} |
|
|