File size: 3,975 Bytes
8652957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include "lm/interpolate/tune_derivatives.hh"

#include "lm/interpolate/tune_instances.hh"

#include "util/stream/config.hh"
#include "util/stream/chain.hh"
#include "util/stream/io.hh"
#include "util/stream/typed_stream.hh"

#define BOOST_TEST_MODULE DerivativeTest
#include <boost/test/unit_test.hpp>

namespace lm { namespace interpolate {

// Test stand-in for Instances.  The mutable accessors let a test write the
// inherited matrices directly, and AddExtension()/DoneExtending() feed
// Extension records through a stream chain into extensions_subsequent_.
class MockInstances : public Instances {
  public:
    MockInstances()
      : chain_(util::stream::ChainConfig(ReadExtensionsEntrySize(), 2, 100)),
        write_(chain_.Add()) {
      // Buffer the records in a file obtained from util::MakeTemp
      // (presumably a temporary file under /tmp/ -- confirm against util).
      extensions_subsequent_.reset(
          new util::stream::FileBuffer(util::MakeTemp("/tmp/")));
      // Whatever is written via write_ goes to that buffer's sink.
      chain_ >> extensions_subsequent_->Sink() >> util::stream::kRecycle;
    }

    // Mutable access to the BOS word index.
    WordIndex &BOS() {
      return bos_;
    }

    // Mutable access to the log unigram matrix.
    Matrix &LNUnigrams() {
      return ln_unigrams_;
    }

    // Mutable access to the log backoff matrix.
    BackoffMatrix &LNBackoffs() {
      return ln_backoffs_;
    }

    // Mutable access to the negated log probabilities of the correct words.
    Vector &NegLNCorrectSum() {
      return neg_ln_correct_sum_;
    }

    // Append one record to the stream.  Callers are responsible for handing
    // the records over in sorted order; nothing here sorts them.
    void AddExtension(const Extension &extension) {
      *write_ = extension;
      ++write_;
    }

    // Mark the end of the records and wait for the chain to finish.
    void DoneExtending() {
      write_.Poison();
      chain_.Wait(true);
    }

  private:
    // Declaration order matters: chain_ must be constructed before write_,
    // which is obtained from chain_.Add().
    util::stream::Chain chain_;
    util::stream::TypedStream<Extension> write_;
};

namespace {

// Hand-computed check of Derivatives() on a tiny problem: three vocabulary
// words plus <s>, two models, one instance.  The gradient and the
// off-diagonal hessian entries are compared with values worked out below.
BOOST_AUTO_TEST_CASE(Small) {
  MockInstances instances;

  // Unigram probabilities: rows are the three vocabulary words followed by
  // <s>, columns are the two models.  Stored as natural logs.
  Matrix unigram_probs(4, 2);
  unigram_probs <<
    0.1, 0.6,
    0.4, 0.3,
    0.5, 0.1,
    // <s>
    1.0, 1.0;
  instances.LNUnigrams() = unigram_probs.array().log();
  instances.BOS() = 3;

  // A single instance with backoff 0.2 in model 0 and 0.4 in model 1.
  instances.LNBackoffs().resize(1, 2);
  instances.LNBackoffs() << 0.2, 0.4;
  instances.LNBackoffs() = instances.LNBackoffs().array().log();

  // The extensions are sparse: model 0 matches word 2, model 1 matches word 1.
  // Supposing model 1 matches nothing but word 1, this is p_1(1 | context).
  Accum prob_m1_w1 = 1.0 - .6 * .4 - .1 * .4;
  // The analogous value for model 0 and word 2.
  Accum prob_m0_w2 = 1.0 - .1 * .2 - .4 * .2;

  // The correct word is taken to be WordIndex 1: model 0 backs off to the
  // unigram for it while model 1 matches it.
  Vector &neg_ln_correct = instances.NegLNCorrectSum();
  neg_ln_correct.resize(2);
  neg_ln_correct << (0.4 * 0.2), prob_m1_w1;
  neg_ln_correct = -neg_ln_correct.array().log();

  // Feed the two extension records, in the order the mock requires (sorted).
  Extension record;

  record.instance = 0;
  record.word = 1;
  record.model = 1;
  record.ln_prob = log(prob_m1_w1);
  instances.AddExtension(record);

  record.instance = 0;
  record.word = 2;
  record.model = 0;
  record.ln_prob = log(prob_m0_w2);
  instances.AddExtension(record);

  instances.DoneExtending();

  // Run the code under test.
  Vector got_gradient(2);
  Matrix got_hessian(2,2);
  Vector interp_weights(2);
  interp_weights << 0.9, 1.2;
  Derivatives(instances, interp_weights, got_gradient, got_hessian);
  // TODO: check perplexity value coming out.

  // Interpolated distribution p_I(x | context) over the three words: each
  // model's probability raised to its weight, multiplied together, and then
  // normalized to sum to one.
  Vector interp(3);
  interp <<
    pow(0.1 * 0.2, 0.9) * pow(0.6 * 0.4, 1.2),
    pow(0.4 * 0.2, 0.9) * pow(prob_m1_w1, 1.2),
    pow(prob_m0_w2, 0.9) * pow(0.1 * 0.4, 1.2);
  interp /= interp.sum();

  // Expected gradient for each model: the negated log probability of the
  // correct word plus the p_I-weighted sum of that model's log probabilities.
  Vector want_gradient = instances.NegLNCorrectSum();

  want_gradient(0) += interp(0) * log(0.1 * 0.2);
  want_gradient(0) += interp(1) * log(0.4 * 0.2);
  want_gradient(0) += interp(2) * log(prob_m0_w2);

  want_gradient(1) += interp(0) * log(0.6 * 0.4);
  want_gradient(1) += interp(1) * log(prob_m1_w1);
  want_gradient(1) += interp(2) * log(0.1 * 0.4);

  BOOST_CHECK_CLOSE(want_gradient(0), got_gradient(0), 0.01);
  BOOST_CHECK_CLOSE(want_gradient(1), got_gradient(1), 0.01);

  // Expected off-diagonal hessian entry.  Only (1, 0) and (0, 1) are filled
  // in and checked; the diagonal is left untested here.
  Matrix want_hessian(2, 2);
  // p_I-weighted sum of the products of the two models' log probabilities...
  want_hessian(1, 0) =
    interp(0) * log(0.1 * 0.2) * log(0.6 * 0.4) +
    interp(1) * log(0.4 * 0.2) * log(prob_m1_w1) +
    interp(2) * log(prob_m0_w2) * log(0.1 * 0.4);
  // ...minus the product of the two p_I-weighted sums.
  want_hessian(1, 0) -=
    (interp(0) * log(0.1 * 0.2) + interp(1) * log(0.4 * 0.2) + interp(2) * log(prob_m0_w2)) *
    (interp(0) * log(0.6 * 0.4) + interp(1) * log(prob_m1_w1) + interp(2) * log(0.1 * 0.4));
  // The expected matrix is symmetric.
  want_hessian(0, 1) = want_hessian(1, 0);
  BOOST_CHECK_CLOSE(want_hessian(1, 0), got_hessian(1, 0), 0.01);
  BOOST_CHECK_CLOSE(want_hessian(0, 1), got_hessian(0, 1), 0.01);
}

}}} // namespaces