#include "merge_probabilities.hh"
#include "../common/ngram_stream.hh"
#include "bounded_sequence_encoding.hh"
#include "interpolate_info.hh"

// NOTE(review): the names of the three system headers were missing from the
// copy under review. The headers below are the ones this file's visible
// standard-library calls need (std::min / std::equal / std::copy,
// std::distance, std::accumulate) -- confirm against the canonical source.
#include <algorithm>
#include <iterator>
#include <numeric>

namespace lm { namespace interpolate {

/**
 * Helper to generate the BoundedSequenceEncoding used for writing the
 * from values.
 *
 * One bound is produced per entry of info.orders: entry i is
 * min(order, info.orders[i]).  (Presumably info.orders[i] is the highest
 * ngram order of component model i -- confirm against InterpolateInfo.)
 *
 * @param info  interpolation info supplying the per-model orders
 * @param order the ngram order the encoder is being built for
 * @return an encoding constructed from the sequence of per-model bounds
 */
BoundedSequenceEncoding MakeEncoder(const InterpolateInfo &info, uint8_t order) {
  // NOTE(review): the element type was missing from the reviewed copy; it is
  // reconstructed as uint8_t because std::min(order, info.orders[i]) requires
  // both arguments to have the type of `order`.
  util::FixedArray<uint8_t> max_orders(info.orders.size());
  for (std::size_t i = 0; i < info.orders.size(); ++i) {
    max_orders.push_back(std::min(order, info.orders[i]));
  }
  return BoundedSequenceEncoding(max_orders.begin(), max_orders.end());
}

namespace {

/**
 * A simple wrapper class that holds information needed to read and write
 * the ngrams of a particular order. This class has the memory needed to
 * buffer the data needed for the recursive process of computing the
 * probabilities and "from" values for each component model.
 *
 * "From" values indicate, for each model, what order (as an index, so -1)
 * was backed off to in order to arrive at a probability. For example, if a
 * 5-gram model (order index 4) backed off twice, we would write a 2.
*/ class NGramHandler { public: NGramHandler(uint8_t order, const InterpolateInfo &ifo, util::FixedArray &models_by_order) : info(ifo), encoder(MakeEncoder(info, order)), out_record(order, encoder.EncodedLength()) { std::size_t count_has_order = 0; for (std::size_t i = 0; i < models_by_order.size(); ++i) { count_has_order += (models_by_order[i].size() >= order); } inputs_.Init(count_has_order); for (std::size_t i = 0; i < models_by_order.size(); ++i) { if (models_by_order[i].size() < order) continue; inputs_.push_back(models_by_order[i][order - 1]); if (inputs_.back()) { active_.resize(active_.size() + 1); active_.back().model = i; active_.back().stream = &inputs_.back(); } } // have to init outside since NGramStreams doesn't forward to // GenericStreams ctor given a ChainPositions probs.Init(info.Models()); from.Init(info.Models()); for (std::size_t i = 0; i < info.Models(); ++i) { probs.push_back(0.0); from.push_back(0); } } struct StreamIndex { NGramStream *stream; NGramStream &Stream() { return *stream; } std::size_t model; }; std::size_t ActiveSize() const { return active_.size(); } /** * @return the input stream for a particular model that corresponds to * this ngram order */ StreamIndex &operator[](std::size_t idx) { return active_[idx]; } void erase(std::size_t idx) { active_.erase(active_.begin() + idx); } const InterpolateInfo &info; BoundedSequenceEncoding encoder; PartialProbGamma out_record; util::FixedArray probs; util::FixedArray from; private: std::vector active_; NGramStreams inputs_; }; /** * A collection of NGramHandlers. */ class NGramHandlers : public util::FixedArray { public: explicit NGramHandlers(std::size_t num) : util::FixedArray(num) { } void push_back( std::size_t order, const InterpolateInfo &info, util::FixedArray &models_by_order) { new (end()) NGramHandler(order, info, models_by_order); Constructed(); } }; /** * The recursive helper function that computes probability and "from" * values for all ngrams matching a particular suffix. 
* * The current order can be computed as the suffix length + 1. Note that * the suffix could be empty (suffix_begin == suffix_end == NULL), in which * case we are handling unigrams with the UNK token as the fallback * probability. * * @param handlers The full collection of handlers * @param suffix_begin A start iterator for the suffix * @param suffix_end An end iterator for the suffix * @param fallback_probs The probabilities of this ngram if we need to * back off (that is, the probability of the suffix) * @param fallback_from The order that the corresponding fallback * probability in the fallback_probs is from * @param combined_fallback interpolated fallback_probs * @param outputs The output streams, one for each order */ void HandleSuffix(NGramHandlers &handlers, WordIndex *suffix_begin, WordIndex *suffix_end, const util::FixedArray &fallback_probs, const util::FixedArray &fallback_from, float combined_fallback, util::stream::Streams &outputs) { uint8_t order = std::distance(suffix_begin, suffix_end) + 1; if (order > outputs.size()) return; util::stream::Stream &output = outputs[order - 1]; NGramHandler &handler = handlers[order - 1]; while (true) { // find the next smallest ngram which matches our suffix // TODO: priority queue driven. WordIndex *minimum = NULL; for (std::size_t i = 0; i < handler.ActiveSize(); ++i) { if (!std::equal(suffix_begin, suffix_end, handler[i].Stream()->begin() + 1)) continue; // if we either haven't set a minimum yet or this one is smaller than // the minimum we found before, replace it WordIndex *last = handler[i].Stream()->begin(); if (!minimum || *last < *minimum) { minimum = handler[i].Stream()->begin(); } } // no more ngrams of this order match our suffix, so we're done if (!minimum) return; handler.out_record.ReBase(output.Get()); std::copy(minimum, minimum + order, handler.out_record.begin()); // Default case is having backed off. 
std::copy(fallback_probs.begin(), fallback_probs.end(), handler.probs.begin()); std::copy(fallback_from.begin(), fallback_from.end(), handler.from.begin()); for (std::size_t i = 0; i < handler.ActiveSize();) { if (std::equal(handler.out_record.begin(), handler.out_record.end(), handler[i].Stream()->begin())) { handler.probs[handler[i].model] = handler.info.lambdas[handler[i].model] * handler[i].Stream()->Value().prob; handler.from[handler[i].model] = order - 1; if (++handler[i].Stream()) { ++i; } else { handler.erase(i); } } else { ++i; } } handler.out_record.Prob() = std::accumulate(handler.probs.begin(), handler.probs.end(), 0.0); handler.out_record.LowerProb() = combined_fallback; handler.encoder.Encode(handler.from.begin(), handler.out_record.FromBegin()); // we've handled this particular ngram, so now recurse to the higher // order using the current ngram as the suffix HandleSuffix(handlers, handler.out_record.begin(), handler.out_record.end(), handler.probs, handler.from, handler.out_record.Prob(), outputs); // consume the output ++output; } } /** * Kicks off the recursion for computing the probabilities and "from" * values for each ngram order. We begin by handling the UNK token that * should be at the front of each of the unigram input streams. This is * then output to the stream and it is used as the fallback for handling * our unigram case, the unigram used as the fallback for the bigram case, * etc. 
*/ void HandleNGrams(NGramHandlers &handlers, util::stream::Streams &outputs) { PartialProbGamma unk_record(1, 0); // First: populate the unk probabilities by reading the first unigram // from each stream util::FixedArray unk_probs(handlers[0].info.Models()); // start by populating the ngram id from the first stream lm::NGram ngram = *handlers[0][0].Stream(); unk_record.ReBase(outputs[0].Get()); std::copy(ngram.begin(), ngram.end(), unk_record.begin()); unk_record.Prob() = 0; // then populate the probabilities into unk_probs while "multiply" the // model probabilities together into the unk record // // note that from doesn't need to be set for unigrams assert(handlers[0].ActiveSize() == handlers[0].info.Models()); for (std::size_t i = 0; i < handlers[0].info.Models();) { ngram = *handlers[0][i].Stream(); unk_probs.push_back(handlers[0].info.lambdas[i] * ngram.Value().prob); unk_record.Prob() += unk_probs[i]; assert(*ngram.begin() == kUNK); if (++handlers[0][i].Stream()) { ++i; } else { handlers[0].erase(i); } } float unk_combined = unk_record.Prob(); unk_record.LowerProb() = unk_combined; // flush the unk output record ++outputs[0]; // Then, begin outputting everything in lexicographic order: first we'll // get the unigram then the first bigram with that context, then the // first trigram with that bigram context, etc., until we exhaust all of // the ngrams, then all of the (n-1)grams, etc. // // This function is the "root" of this recursive process. util::FixedArray unk_from(handlers[0].info.Models()); for (std::size_t i = 0; i < handlers[0].info.Models(); ++i) { unk_from.push_back(0); } // the two nulls are to encode that our "fallback" word is the "0-gram" // case, e.g. we "backed off" to UNK // TODO: stop generating vocab ids and LowerProb for unigrams. HandleSuffix(handlers, NULL, NULL, unk_probs, unk_from, unk_combined, outputs); // Verify we reached the end. And poison! 
for (std::size_t i = 0; i < handlers.size(); ++i) { UTIL_THROW_IF2(handlers[i].ActiveSize(), "MergeProbabilities did not exhaust all ngram streams"); outputs[i].Poison(); } } } // namespace void MergeProbabilities::Run(const util::stream::ChainPositions &output_pos) { NGramHandlers handlers(output_pos.size()); for (std::size_t i = 0; i < output_pos.size(); ++i) { handlers.push_back(i + 1, info_, models_by_order_); } util::stream::Streams outputs(output_pos); HandleNGrams(handlers, outputs); } }} // namespaces