File size: 6,044 Bytes
1ce325b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include "interpolate.hh"

#include "hash_gamma.hh"
#include "payload.hh"
#include "../common/compare.hh"
#include "../common/joint_order.hh"
#include "../common/ngram_stream.hh"
#include "../lm_exception.hh"
#include "../../util/fixed_array.hh"
#include "../../util/murmur_hash.hh"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

namespace lm { namespace builder {
namespace {

/* Output policy that emits q, the collapsed probability and backoff, as
 * defined in
 * @inproceedings{Heafield-rest,
 *   author = {Kenneth Heafield and Philipp Koehn and Alon Lavie},
 *   title = {Language Model Rest Costs and Space-Efficient Storage},
 *   year = {2012},
 *   month = {July},
 *   booktitle = {Proceedings of the Joint Conference on Empirical Methods in Natural Language Processing and Computational Natural Language Learning},
 *   address = {Jeju Island, Korea},
 *   pages = {1169--1178},
 *   url = {http://kheafield.com/professional/edinburgh/rest\_paper.pdf},
 * }
 * Computing q here is cheap because interpolation already touches every
 * backoff term that q needs.
 *
 * Gram() keeps one running ratio per order and builds the order-k value from
 * the order-(k-1) value stored by the previous lower-order call.
 * NOTE(review): this assumes the most recently visited lower-order n-gram is
 * the current n-gram's suffix (the caller traverses in SuffixOrder) — confirm.
 */
class OutputQ {
  public:
    explicit OutputQ(std::size_t order) : q_delta_(order) {}

    // Overwrites out.prob with log10(out.prob * delta) and sets out.backoff to 0.
    // On entry, out.backoff carries the context's backoff, and full_backoff is
    // this n-gram's own backoff as supplied by the caller.
    void Gram(unsigned order_minus_1, float full_backoff, ProbBackoff &out) {
      const float delta = (order_minus_1 == 0)
          ? full_backoff
          // Divide by the context's backoff (arriving in out.backoff), then
          // multiply by this n-gram's own backoff.
          : q_delta_[order_minus_1 - 1] / out.backoff * full_backoff;
      q_delta_[order_minus_1] = delta;
      out.prob = log10f(out.prob * delta);
      // TODO: stop wastefully outputting this!
      out.backoff = 0.0;
    }

  private:
    // q_delta_[i] is the running ratio for the latest n-gram of order i + 1:
    // the product of the full backoffs passed in at orders 1..i+1, divided by
    // the context backoffs (out.backoff) received at orders 2..i+1.
    std::vector<float> q_delta_;
};

/* Default output policy: write log10 probability and log10 backoff. */
class OutputProbBackoff {
  public:
    explicit OutputProbBackoff(std::size_t /*order*/) {}

    // Converts out.prob and full_backoff to log10 space in place.  The
    // probability is clamped at log10(1) = 0, because floating-point error can
    // leave it slightly above 1 (take that IRST).  A NaN log probability also
    // becomes 0, since the comparison below is false for NaN.
    void Gram(unsigned /*order_minus_1*/, float full_backoff, ProbBackoff &out) const {
      const float log_prob = log10f(out.prob);
      out.prob = (log_prob < 0.0f) ? log_prob : 0.0f;
      out.backoff = log10f(full_backoff);
    }
};

template <class Output> class Callback {
  public:
    Callback(float uniform_prob, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t> &prune_thresholds, bool prune_vocab, const SpecialVocab &specials)
      : backoffs_(backoffs.size()), probs_(backoffs.size() + 2),
        prune_thresholds_(prune_thresholds),
        prune_vocab_(prune_vocab),
        output_(backoffs.size() + 1 /* order */),
        specials_(specials) {
      probs_[0] = uniform_prob;
      for (std::size_t i = 0; i < backoffs.size(); ++i) {
        backoffs_.push_back(backoffs[i]);
      }
    }

    ~Callback() {
      for (std::size_t i = 0; i < backoffs_.size(); ++i) {
        if(prune_vocab_ || prune_thresholds_[i + 1] > 0)
          while(backoffs_[i])
            ++backoffs_[i];

        if (backoffs_[i]) {
          std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl;
          abort();
        }
      }
    }

    void Enter(unsigned order_minus_1, void *data) {
      NGram<BuildingPayload> gram(data, order_minus_1 + 1);
      BuildingPayload &pay = gram.Value();
      pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1];
      probs_[order_minus_1 + 1] = pay.complete.prob;

      float out_backoff;
      if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != specials_.UNK() && *(gram.end() - 1) != specials_.EOS() && backoffs_[order_minus_1]) {
        if(prune_vocab_ || prune_thresholds_[order_minus_1 + 1] > 0) {
          //Compute hash value for current context
          uint64_t current_hash = util::MurmurHashNative(gram.begin(), gram.Order() * sizeof(WordIndex));

          const HashGamma *hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());
          while(current_hash != hashed_backoff->hash_value && ++backoffs_[order_minus_1])
            hashed_backoff = static_cast<const HashGamma*>(backoffs_[order_minus_1].Get());

          if(current_hash == hashed_backoff->hash_value) {
            out_backoff = hashed_backoff->gamma;
            ++backoffs_[order_minus_1];
          } else {
            // Has been pruned away so it is not a context anymore
            out_backoff = 1.0;
          }
        } else {
          out_backoff = *static_cast<const float*>(backoffs_[order_minus_1].Get());
          ++backoffs_[order_minus_1];
        }
      } else {
        // Not a context.
        out_backoff = 1.0;
      }

      output_.Gram(order_minus_1, out_backoff, pay.complete);
    }

    void Exit(unsigned, void *) const {}

  private:
    util::FixedArray<util::stream::Stream> backoffs_;

    std::vector<float> probs_;
    const std::vector<uint64_t>& prune_thresholds_;
    bool prune_vocab_;

    Output output_;
    const SpecialVocab specials_;
};
} // namespace

// Records the settings Run() needs.  uniform_prob_ = 1 / vocab_size is the
// order-0 probability that Run() passes to Callback, which uses it as the base
// of the interpolation recursion.
// NOTE(review): vocab_size == 0 would make uniform_prob_ infinite; presumably
// callers never pass 0 — confirm.
Interpolate::Interpolate(uint64_t vocab_size, const util::stream::ChainPositions &backoffs, const std::vector<uint64_t>& prune_thresholds, bool prune_vocab, bool output_q, const SpecialVocab &specials)
  : uniform_prob_(1.0 / static_cast<float>(vocab_size)), // Includes <unk> but excludes <s>.
    backoffs_(backoffs),
    prune_thresholds_(prune_thresholds),
    prune_vocab_(prune_vocab),
    output_q_(output_q),
    specials_(specials) {}

// Interpolates all orders in one pass.  JointOrder (presumably walking every
// order's chain together in SuffixOrder) feeds each n-gram to a Callback.
// The Callback's output policy is picked by output_q_: OutputQ when it is
// set, OutputProbBackoff otherwise.
void Interpolate::Run(const util::stream::ChainPositions &positions) {
  // positions has one chain more than backoffs_: the highest order has no backoffs.
  assert(positions.size() == backoffs_.size() + 1);
  if (!output_q_) {
    Callback<OutputProbBackoff> prob_backoff_callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
    JointOrder<Callback<OutputProbBackoff>, SuffixOrder>(positions, prob_backoff_callback);
    return;
  }
  Callback<OutputQ> rest_cost_callback(uniform_prob_, backoffs_, prune_thresholds_, prune_vocab_, specials_);
  JointOrder<Callback<OutputQ>, SuffixOrder>(positions, rest_cost_callback);
}

}} // namespaces