|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef LM_LEFT_H |
|
#define LM_LEFT_H |
|
|
|
#include "lm/max_order.hh" |
|
#include "lm/state.hh" |
|
#include "lm/return.hh" |
|
|
|
#include "util/murmur_hash.hh" |
|
|
|
#include <algorithm> |
|
|
|
namespace lm { |
|
namespace ngram { |
|
|
|
template <class M> class RuleScore { |
|
public: |
|
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) { |
|
out.left.length = 0; |
|
out.right.length = 0; |
|
} |
|
|
|
void BeginSentence() { |
|
out_->right = model_.BeginSentenceState(); |
|
|
|
left_done_ = true; |
|
} |
|
|
|
void Terminal(WordIndex word) { |
|
State copy(out_->right); |
|
FullScoreReturn ret(model_.FullScore(copy, word, out_->right)); |
|
if (left_done_) { prob_ += ret.prob; return; } |
|
if (ret.independent_left) { |
|
prob_ += ret.prob; |
|
left_done_ = true; |
|
return; |
|
} |
|
out_->left.pointers[out_->left.length++] = ret.extend_left; |
|
prob_ += ret.rest; |
|
if (out_->right.length != copy.length + 1) |
|
left_done_ = true; |
|
} |
|
|
|
|
|
void BeginNonTerminal(const ChartState &in, float prob = 0.0) { |
|
prob_ = prob; |
|
*out_ = in; |
|
left_done_ = in.left.full; |
|
} |
|
|
|
void NonTerminal(const ChartState &in, float prob = 0.0) { |
|
prob_ += prob; |
|
|
|
if (!in.left.length) { |
|
if (in.left.full) { |
|
for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i; |
|
left_done_ = true; |
|
out_->right = in.right; |
|
} |
|
return; |
|
} |
|
|
|
if (!out_->right.length) { |
|
out_->right = in.right; |
|
if (left_done_) { |
|
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1); |
|
return; |
|
} |
|
if (out_->left.length) { |
|
left_done_ = true; |
|
} else { |
|
out_->left = in.left; |
|
left_done_ = in.left.full; |
|
} |
|
return; |
|
} |
|
|
|
float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1]; |
|
float *back = backoffs, *back2 = backoffs2; |
|
unsigned char next_use = out_->right.length; |
|
|
|
|
|
if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return; |
|
|
|
|
|
for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) { |
|
if (ExtendLeft(in, next_use, extend_length, back, back2)) return; |
|
std::swap(back, back2); |
|
} |
|
|
|
if (in.left.full) { |
|
for (const float *i = back; i != back + next_use; ++i) prob_ += *i; |
|
left_done_ = true; |
|
out_->right = in.right; |
|
return; |
|
} |
|
|
|
|
|
if (in.right.length < in.left.length) { |
|
out_->right = in.right; |
|
return; |
|
} |
|
|
|
|
|
for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) { |
|
*(i + in.right.length) = *i; |
|
} |
|
|
|
std::copy(in.right.words, in.right.words + in.right.length, out_->right.words); |
|
|
|
std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff); |
|
std::copy(back, back + next_use, out_->right.backoff + in.right.length); |
|
out_->right.length = in.right.length + next_use; |
|
} |
|
|
|
float Finish() { |
|
|
|
out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1); |
|
return prob_; |
|
} |
|
|
|
void Reset() { |
|
prob_ = 0.0; |
|
left_done_ = false; |
|
out_->left.length = 0; |
|
out_->right.length = 0; |
|
} |
|
void Reset(ChartState &replacement) { |
|
out_ = &replacement; |
|
Reset(); |
|
} |
|
|
|
private: |
|
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) { |
|
ProcessRet(model_.ExtendLeft( |
|
out_->right.words, out_->right.words + next_use, |
|
back_in, |
|
in.left.pointers[extend_length - 1], extend_length, |
|
back_out, |
|
next_use)); |
|
if (next_use != out_->right.length) { |
|
left_done_ = true; |
|
if (!next_use) { |
|
|
|
out_->right = in.right; |
|
prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1); |
|
return true; |
|
} |
|
} |
|
|
|
return false; |
|
} |
|
|
|
void ProcessRet(const FullScoreReturn &ret) { |
|
if (left_done_) { |
|
prob_ += ret.prob; |
|
return; |
|
} |
|
if (ret.independent_left) { |
|
prob_ += ret.prob; |
|
left_done_ = true; |
|
return; |
|
} |
|
out_->left.pointers[out_->left.length++] = ret.extend_left; |
|
prob_ += ret.rest; |
|
} |
|
|
|
const M &model_; |
|
|
|
ChartState *out_; |
|
|
|
bool left_done_; |
|
|
|
float prob_; |
|
}; |
|
|
|
} |
|
} |
|
|
|
#endif |
|
|