File size: 7,180 Bytes
8652957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
/* Efficient left and right language model state for sentence fragments.
 * Intended usage:
 * Store ChartState with every chart entry.
 * To do a rule application:
 * 1. Make a ChartState object for your new entry.
 * 2. Construct RuleScore.
 * 3. Going from left to right, call Terminal or NonTerminal.
 *   For terminals, just pass the vocab id.
 *   For non-terminals, pass that non-terminal's ChartState.
 *     If your decoder expects scores inclusive of subtree scores (i.e. you
 *     label entries with the highest-scoring path), pass the non-terminal's
 *     score as prob.
 *     If your decoder expects relative scores and will walk the chart later,
 *     pass prob = 0.0.
 *     In other words, the only effect of prob is that it gets added to the
 *     returned log probability.
 * 4. Call Finish.  It returns the log probability.
 *
 * There are a couple more details:
 * Do not pass <s> to Terminal as it is formally not a word in the sentence,
 * only context.  Instead, call BeginSentence.  If called, it should be the
 * first call after RuleScore is constructed (since <s> is always the
 * leftmost).
 *
 * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
 *
 * Hashing and sorting comparison operators are provided.   All state objects
 * are POD.  If you intend to use memcmp on raw state objects, you must call
 * ZeroRemaining first, as the value of array entries beyond length is
 * otherwise undefined.
 *
 * Usage is of course not limited to chart decoding.  Anything that generates
 * sentence fragments missing left context could benefit.  For example, a
 * phrase-based decoder could pre-score phrases, storing ChartState with each
 * phrase, even if hypotheses are generated left-to-right.
 */

#ifndef LM_LEFT_H
#define LM_LEFT_H

#include "lm/max_order.hh"
#include "lm/state.hh"
#include "lm/return.hh"

#include "util/murmur_hash.hh"

#include <algorithm>

namespace lm {
namespace ngram {

/* Scores one rule application against model M and writes the resulting
 * left/right language model state into a caller-owned ChartState.
 *
 * The object keeps a reference to the model and a pointer to the output
 * ChartState, so both must outlive it.  Calls are expected in the order
 * described in the comment at the top of this file: optionally BeginSentence
 * or BeginNonTerminal, then Terminal/NonTerminal from left to right, then
 * Finish.
 */
template <class M> class RuleScore {
  public:
    // Binds to model and out, and empties out's left and right state lengths.
    // out.left.full is not written here; Finish() sets it.
    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
      out.left.length = 0;
      out.right.length = 0;
    }

    // Start the fragment with <s>: the right state becomes the model's
    // begin-of-sentence state and the left state is closed off.
    void BeginSentence() {
      out_->right = model_.BeginSentenceState();
      // out_->left is empty.
      left_done_ = true;
    }

    // Append a terminal (vocabulary id) on the right of what has been scored
    // so far.
    void Terminal(WordIndex word) {
      // FullScore reads the context from copy and writes the new right state
      // directly into out_->right, so the old state must be copied first.
      State copy(out_->right);
      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
      // Left state already closed: the score counts in full.
      if (left_done_) { prob_ += ret.prob; return; }
      // The model reports the score is independent of further left context:
      // count it in full and close the left state.
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      // Otherwise record the pointer for later extension to the left and
      // accumulate ret.rest instead of ret.prob.
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
      // The right state did not grow by exactly one word relative to the
      // previous state, so stop collecting left pointers.
      // NOTE(review): presumably this means the model did not use the full
      // previous context -- confirm against FullScore's contract.
      if (out_->right.length != copy.length + 1)
        left_done_ = true;
    }

    // Faster version of NonTerminal for the case where the rule begins with a non-terminal.
    // This overwrites (rather than adds to) the accumulated score and replaces
    // the whole output state with in, so it is only meaningful as the first
    // scoring call.
    void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ = prob;
      *out_ = in;
      left_done_ = in.left.full;
    }

    // Append a non-terminal whose state is in on the right of what has been
    // scored so far.  prob is simply added to the running score.
    void NonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ += prob;

      // Case 1: in has no left pointers to extend.
      if (!in.left.length) {
        if (in.left.full) {
          // in's left side is closed: charge every backoff of the current
          // right state, close our left state and adopt in's right state.
          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
          left_done_ = true;
          out_->right = in.right;
        }
        // If in.left is empty and not full, the output state is left as is.
        return;
      }

      // Case 2: we have no right context to extend in's left pointers with.
      if (!out_->right.length) {
        out_->right = in.right;
        if (left_done_) {
          // Left state is closed, so add the model's UnRest adjustment for
          // in's left pointers, starting at length 1.
          prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
          return;
        }
        if (out_->left.length) {
          // We already hold left pointers of our own; stop collecting more.
          left_done_ = true;
        } else {
          // Nothing collected yet: inherit in's left state wholesale.
          out_->left = in.left;
          left_done_ = in.left.full;
        }
        return;
      }

      // Case 3: extend each of in's left pointers using our right context.
      // Two scratch buffers are alternated: each ExtendLeft call reads the
      // backoffs produced by the previous one and writes fresh ones.
      float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
      float *back = backoffs, *back2 = backoffs2;
      // Number of our right-state words the next extension should use;
      // updated by every ExtendLeft call.
      unsigned char next_use = out_->right.length;

      // First word
      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;

      // Words after the first, so extending a bigram to begin with
      for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
        if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
        std::swap(back, back2);
      }

      if (in.left.full) {
        // in's left side is closed: charge the remaining backoffs, close our
        // left state and adopt in's right state.
        for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
        left_done_ = true;
        out_->right = in.right;
        return;
      }

      // Right state was minimized, so it's already independent of the new words to the left.
      if (in.right.length < in.left.length) {
        out_->right = in.right;
        return;
      }

      // Shift existing words down to make room for in.right's words in front.
      // Here in.right.length >= in.left.length >= 1, so the destination end
      // lies strictly past the source end, as std::copy_backward requires.
      // (A hand-written descending loop would step a pointer to before the
      // start of the array, which is undefined behaviour.)
      std::copy_backward(out_->right.words, out_->right.words + next_use,
                         out_->right.words + next_use + in.right.length);
      // Add words from in.right.
      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
      // Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
      std::copy(back, back + next_use, out_->right.backoff + in.right.length);
      out_->right.length = in.right.length + next_use;
    }

    // Finalize the output state and return the accumulated log probability.
    float Finish() {
      // An (N-1)-gram might extend left and right but we should still set full to true because it's an (N-1)-gram.
      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
      return prob_;
    }

    // Return to the freshly-constructed condition, reusing the same output
    // ChartState.
    void Reset() {
      prob_ = 0.0;
      left_done_ = false;
      out_->left.length = 0;
      out_->right.length = 0;
    }
    // As Reset(), but redirect output to replacement first.
    void Reset(ChartState &replacement) {
      out_ = &replacement;
      Reset();
    }

  private:
    // Extend in.left.pointers[extend_length - 1] with the first next_use words
    // of our right state.
    //   back_in  - backoffs handed to the model for this extension.
    //   back_out - receives the backoffs for the next extension.
    //   next_use - in: words to extend into; out: words to use next time.
    // Returns true when the caller should stop (no context remains and the
    // output state has been finalized here), false to continue scoring.
    bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
      ProcessRet(model_.ExtendLeft(
            out_->right.words, out_->right.words + next_use, // Words to extend into
            back_in, // Backoffs to use
            in.left.pointers[extend_length - 1], extend_length, // Words to be extended
            back_out, // Backoffs for the next score
            next_use)); // Length of n-gram to use in next scoring.
      // The model used fewer words than our whole right state, so stop
      // collecting left pointers.
      if (next_use != out_->right.length) {
        left_done_ = true;
        if (!next_use) {
          // Early exit.
          // Nothing more can be extended: adopt in's right state and add the
          // UnRest adjustment for the left pointers not yet processed.
          out_->right = in.right;
          prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
          return true;
        }
      }
      // Continue scoring.
      return false;
    }

    // Fold one scoring result into the running total, mirroring the logic in
    // Terminal(): take ret.prob when the left state is closed or the result is
    // independent of left context, otherwise store the extension pointer and
    // take ret.rest.
    void ProcessRet(const FullScoreReturn &ret) {
      if (left_done_) {
        prob_ += ret.prob;
        return;
      }
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
    }

    // Language model used for all scoring calls.
    const M &model_;

    // State being built; owned by the caller.
    ChartState *out_;

    // True once no further left pointers should be recorded in out_->left.
    bool left_done_;

    // Running log probability returned by Finish().
    float prob_;
};

} // namespace ngram
} // namespace lm

#endif // LM_LEFT_H