#include "merge_probabilities.hh"

#include "../common/ngram_stream.hh"
#include "bounded_sequence_encoding.hh"
#include "interpolate_info.hh"

#include <algorithm>
#include <limits>
#include <numeric>

namespace lm {
namespace interpolate {

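/* Build the BoundedSequenceEncoding used to pack the per-model "from"
 * (origin-order) values into each output record.  Each model's value is
 * bounded by the smaller of its own maximum order and the order being merged.
 */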
BoundedSequenceEncoding MakeEncoder(const InterpolateInfo &info, uint8_t order) {
  util::FixedArray<uint8_t> max_orders(info.orders.size());
  for (std::size_t i = 0; i < info.orders.size(); ++i) {
    max_orders.push_back(std::min(order, info.orders[i]));
  }
  return BoundedSequenceEncoding(max_orders.begin(), max_orders.end());
}

namespace {

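/* Wraps the input streams and scratch buffers needed to merge ngrams of one
 * particular order.  For the ngram currently being emitted it stores, per
 * component model, the lambda-weighted probability and the order that
 * probability came from, so the recursion in HandleSuffix can reuse them as
 * fallback values at the next higher order.
 */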
class NGramHandler {
  public:
    NGramHandler(uint8_t order, const InterpolateInfo &ifo,
                 util::FixedArray<util::stream::ChainPositions> &models_by_order)
      : info(ifo),
        encoder(MakeEncoder(info, order)),
        out_record(order, encoder.EncodedLength()) {
      std::size_t count_has_order = 0;
      for (std::size_t i = 0; i < models_by_order.size(); ++i) {
        count_has_order += (models_by_order[i].size() >= order);
      }
      inputs_.Init(count_has_order);
      for (std::size_t i = 0; i < models_by_order.size(); ++i) {
        if (models_by_order[i].size() < order)
          continue;
        inputs_.push_back(models_by_order[i][order - 1]);
        if (inputs_.back()) {
          active_.resize(active_.size() + 1);
          active_.back().model = i;
          active_.back().stream = &inputs_.back();
        }
      }

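      // probs and from are per-model scratch buffers; FixedArray requires an
      // explicit Init plus one push_back per slot before they can be indexed.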
probs.Init(info.Models()); |
|
from.Init(info.Models()); |
|
for (std::size_t i = 0; i < info.Models(); ++i) { |
|
probs.push_back(0.0); |
|
from.push_back(0); |
|
} |
|
} |
|
|
|
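    // Pairs an input stream with the index of the component model it reads.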
    struct StreamIndex {
      NGramStream<ProbBackoff> *stream;
      NGramStream<ProbBackoff> &Stream() { return *stream; }
      std::size_t model;
    };

    std::size_t ActiveSize() const {
      return active_.size();
    }

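    // Access the idx-th still-active input stream for this order.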
    StreamIndex &operator[](std::size_t idx) {
      return active_[idx];
    }

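    // Retire a stream whose model has no more ngrams of this order.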
    void erase(std::size_t idx) {
      active_.erase(active_.begin() + idx);
    }

    const InterpolateInfo &info;
    BoundedSequenceEncoding encoder;
    PartialProbGamma out_record;
    util::FixedArray<float> probs;
    util::FixedArray<uint8_t> from;

  private:
    std::vector<StreamIndex> active_;
    NGramStreams<ProbBackoff> inputs_;
};

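/* Fixed-size collection holding one NGramHandler per ngram order. */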
class NGramHandlers : public util::FixedArray<NGramHandler> {
  public:
    explicit NGramHandlers(std::size_t num)
      : util::FixedArray<NGramHandler>(num) {
    }

    void push_back(
        std::size_t order, const InterpolateInfo &info,
        util::FixedArray<util::stream::ChainPositions> &models_by_order) {
      new (end()) NGramHandler(order, info, models_by_order);
      Constructed();
    }
};

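/* Recursively computes the interpolated probability and "from" values for
 * every ngram matching a given suffix.  The order being handled is the suffix
 * length plus one; an empty suffix (both pointers NULL) means unigrams, with
 * the <unk> values as the fallback.
 *
 * @param handlers          one handler per order
 * @param suffix_begin      start of the suffix words
 * @param suffix_end        end of the suffix words
 * @param fallback_probs    per-model probabilities to use when a model lacks
 *                          the ngram (i.e. the values computed for the suffix)
 * @param fallback_from     the order each fallback probability came from
 * @param combined_fallback the interpolated fallback probability
 * @param outputs           one output stream per order
 */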
void HandleSuffix(NGramHandlers &handlers, WordIndex *suffix_begin,
                  WordIndex *suffix_end,
                  const util::FixedArray<float> &fallback_probs,
                  const util::FixedArray<uint8_t> &fallback_from,
                  float combined_fallback,
                  util::stream::Streams &outputs) {
  uint8_t order = std::distance(suffix_begin, suffix_end) + 1;
  if (order > outputs.size()) return;

  util::stream::Stream &output = outputs[order - 1];
  NGramHandler &handler = handlers[order - 1];

  while (true) {
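    // Among the active streams whose ngrams match the suffix, find the one
    // with the smallest leading word.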
    WordIndex *minimum = NULL;
    for (std::size_t i = 0; i < handler.ActiveSize(); ++i) {
      if (!std::equal(suffix_begin, suffix_end, handler[i].Stream()->begin() + 1))
        continue;

      WordIndex *last = handler[i].Stream()->begin();
      if (!minimum || *last < *minimum) { minimum = handler[i].Stream()->begin(); }
    }

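    // No remaining ngram of this order matches the suffix, so this branch is done.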
    if (!minimum) return;

    handler.out_record.ReBase(output.Get());
    std::copy(minimum, minimum + order, handler.out_record.begin());

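    // Start every model at its backed-off fallback value; models that actually
    // contain this ngram overwrite their slot below.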
    std::copy(fallback_probs.begin(), fallback_probs.end(), handler.probs.begin());
    std::copy(fallback_from.begin(), fallback_from.end(), handler.from.begin());

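    // For each model that has this exact ngram, take its lambda-weighted
    // probability, record the current order as the origin, and advance (or
    // retire) that model's stream.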
    for (std::size_t i = 0; i < handler.ActiveSize();) {
      if (std::equal(handler.out_record.begin(), handler.out_record.end(),
                     handler[i].Stream()->begin())) {
        handler.probs[handler[i].model] =
            handler.info.lambdas[handler[i].model] * handler[i].Stream()->Value().prob;
        handler.from[handler[i].model] = order - 1;
        if (++handler[i].Stream()) {
          ++i;
        } else {
          handler.erase(i);
        }
      } else {
        ++i;
      }
    }
    handler.out_record.Prob() = std::accumulate(handler.probs.begin(), handler.probs.end(), 0.0);
    handler.out_record.LowerProb() = combined_fallback;
    handler.encoder.Encode(handler.from.begin(),
                           handler.out_record.FromBegin());

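    // Recurse with the ngram just written as the suffix for the next order,
    // then advance the output stream past the finished record.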
    HandleSuffix(handlers, handler.out_record.begin(), handler.out_record.end(),
                 handler.probs, handler.from, handler.out_record.Prob(), outputs);
    ++output;
  }
}

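/* Drives the whole merge.  The <unk> unigram, which every input stream is
 * expected to produce first, is written out and its interpolated probability
 * becomes the fallback for the unigram pass; the unigram results in turn
 * become the fallback for bigrams, and so on up the orders.
 */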
void HandleNGrams(NGramHandlers &handlers, util::stream::Streams &outputs) {
  PartialProbGamma unk_record(1, 0);
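  // Per-model lambda-weighted probabilities of <unk>, used as the fallback
  // for the unigram pass.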
  util::FixedArray<float> unk_probs(handlers[0].info.Models());

  lm::NGram<ProbBackoff> ngram = *handlers[0][0].Stream();
  unk_record.ReBase(outputs[0].Get());
  std::copy(ngram.begin(), ngram.end(), unk_record.begin());
  unk_record.Prob() = 0;

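  // Every unigram stream should begin with <unk>; sum the weighted
  // probabilities across models and advance (or retire) each stream.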
  assert(handlers[0].ActiveSize() == handlers[0].info.Models());
  for (std::size_t i = 0; i < handlers[0].info.Models();) {
    ngram = *handlers[0][i].Stream();
    unk_probs.push_back(handlers[0].info.lambdas[i] * ngram.Value().prob);
    unk_record.Prob() += unk_probs[i];
    assert(*ngram.begin() == kUNK);
    if (++handlers[0][i].Stream()) {
      ++i;
    } else {
      handlers[0].erase(i);
    }
  }
  float unk_combined = unk_record.Prob();
  unk_record.LowerProb() = unk_combined;

  ++outputs[0];

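  // For unigrams the only possible origin is order 0, i.e. the <unk> fallback.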
  util::FixedArray<uint8_t> unk_from(handlers[0].info.Models());
  for (std::size_t i = 0; i < handlers[0].info.Models(); ++i) {
    unk_from.push_back(0);
  }

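  // An empty suffix (two NULL pointers) starts the recursion at the unigram
  // level, with the interpolated <unk> values as the fallback.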
  HandleSuffix(handlers, NULL, NULL, unk_probs, unk_from, unk_combined, outputs);

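  // Every input stream should now be exhausted; poison the outputs so
  // downstream consumers know the data is complete.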
  for (std::size_t i = 0; i < handlers.size(); ++i) {
    UTIL_THROW_IF2(handlers[i].ActiveSize(),
                   "MergeProbabilities did not exhaust all ngram streams");
    outputs[i].Poison();
  }
}
}

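/* Builds one NGramHandler per output order, then merges all models'
 * probabilities into the output chains.
 */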
void MergeProbabilities::Run(const util::stream::ChainPositions &output_pos) {
  NGramHandlers handlers(output_pos.size());
  for (std::size_t i = 0; i < output_pos.size(); ++i) {
    handlers.push_back(i + 1, info_, models_by_order_);
  }

  util::stream::Streams outputs(output_pos);
  HandleNGrams(handlers, outputs);
}

}}