File size: 3,969 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#include "tune_derivatives.hh"
#include "tune_instances.hh"
#include "../../util/stream/config.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/io.hh"
#include "../../util/stream/typed_stream.hh"
#include <cmath>
#define BOOST_TEST_MODULE DerivativeTest
#include <boost/test/unit_test.hpp>
namespace lm { namespace interpolate {
// Test double for Instances.
//
// Tests fill in the unigram table, backoffs, the <s> index and the summed
// correct-word log probabilities directly through the mutable accessors, then
// stream extensions in via AddExtension() and finish with DoneExtending().
// Extensions are written into a chain whose sink is a temporary file buffer
// stored in extensions_subsequent_.
class MockInstances : public Instances {
  public:
    MockInstances()
      : chain_(util::stream::ChainConfig(ReadExtensionsEntrySize(), 2, 100)),
        // Depends on chain_ already being constructed (declaration order below);
        // this writer is added to the chain before the sink in the body.
        write_(chain_.Add()) {
      // Everything written through write_ ends up in this temporary file buffer.
      extensions_subsequent_.reset(new util::stream::FileBuffer(util::MakeTemp("/tmp/")));
      chain_ >> extensions_subsequent_->Sink() >> util::stream::kRecycle;
    }

    // Mutable access to the ln unigram table (rows: words, columns: models).
    Matrix &LNUnigrams() {
      return ln_unigrams_;
    }

    // Mutable access to the ln backoff table (one row per instance).
    BackoffMatrix &LNBackoffs() {
      return ln_backoffs_;
    }

    // Mutable access to the WordIndex treated as <s>.
    WordIndex &BOS() {
      return bos_;
    }

    // Mutable access to the per-model negated sum of ln p_i(correct | context).
    Vector &NegLNCorrectSum() {
      return neg_ln_correct_sum_;
    }

    // Appends one extension to the stream.  Extensions must be provided sorted!
    void AddExtension(const Extension &extension) {
      *write_ = extension;
      ++write_;
    }

    // Terminates the extension stream and blocks until the chain has drained.
    void DoneExtending() {
      write_.Poison();
      chain_.Wait(true);
    }

  private:
    // Declaration order matters: write_ is initialized from chain_.
    util::stream::Chain chain_;
    util::stream::TypedStream<Extension> write_;
};
namespace {
// Checks Derivatives on a single tuning instance against closed-form values.
//
// Setup: three vocabulary words plus <s>, two models, one instance whose
// correct word is 1.  Model 0 has an extension only for word 2 and model 1
// only for word 1; every other (word, model) pair is modelled as
// unigram * backoff.
//
// Writing ln_p(x, i) = ln p_i(x | context) for vocabulary word x (<s> is not
// part of the distribution) and
//   p_I(x | context) proportional to exp(sum_i weights(i) * ln_p(x, i)),
// the expected values are
//   gradient(i)   = NegLNCorrectSum(i) + E_{p_I}[ln_p(., i)]
//   hessian(i, j) = E_{p_I}[ln_p(., i) * ln_p(., j)]
//                   - E_{p_I}[ln_p(., i)] * E_{p_I}[ln_p(., j)]
// i.e. the Hessian is the covariance of the models' log probabilities under
// p_I.  All four Hessian entries are checked, including the diagonal.
BOOST_AUTO_TEST_CASE(Small) {
  MockInstances mock;

  {
    // Three vocabulary words plus <s>, two models.
    Matrix unigrams(4, 2);
    unigrams <<
      0.1, 0.6,
      0.4, 0.3,
      0.5, 0.1,
      // <s>
      1.0, 1.0;
    mock.LNUnigrams() = unigrams.array().log();
  }
  mock.BOS() = 3;

  // One instance
  mock.LNBackoffs().resize(1, 2);
  mock.LNBackoffs() << 0.2, 0.4;
  mock.LNBackoffs() = mock.LNBackoffs().array().log();

  // Sparse extensions: model 0 word 2 and model 1 word 1.

  // Assuming that model 1 only matches word 1, this is p_1(1 | context)
  Accum model_1_word_1 = 1.0 - .6 * .4 - .1 * .4;

  mock.NegLNCorrectSum().resize(2);
  // We'll suppose correct has WordIndex 1, which backs off in model 0, and matches in model 1
  mock.NegLNCorrectSum() << (0.4 * 0.2), model_1_word_1;
  mock.NegLNCorrectSum() = -mock.NegLNCorrectSum().array().log();

  // Likewise, assuming model 0 only matches word 2, this is p_0(2 | context).
  Accum model_0_word_2 = 1.0 - .1 * .2 - .4 * .2;

  // Extensions are written sorted by word within the instance.
  Extension ext;

  ext.instance = 0;
  ext.word = 1;
  ext.model = 1;
  ext.ln_prob = log(model_1_word_1);
  mock.AddExtension(ext);

  ext.instance = 0;
  ext.word = 2;
  ext.model = 0;
  ext.ln_prob = log(model_0_word_2);
  mock.AddExtension(ext);

  mock.DoneExtending();

  Vector weights(2);
  weights << 0.9, 1.2;

  Vector gradient(2);
  Matrix hessian(2,2);
  Derivatives(mock, weights, gradient, hessian);
  // TODO: check perplexity value coming out.

  // Single source of truth for the expectations below:
  // ln_p(x, i) = ln p_i(x | context), rows are words 0..2, columns are models.
  Matrix ln_p(3, 2);
  ln_p <<
    log(0.1 * 0.2),      log(0.6 * 0.4),
    log(0.4 * 0.2),      log(model_1_word_1),
    log(model_0_word_2), log(0.1 * 0.4);

  // p_I(x | context): weighted log-linear combination, normalized over words.
  Vector p_I(3);
  for (int x = 0; x < 3; ++x) {
    p_I(x) = exp(weights(0) * ln_p(x, 0) + weights(1) * ln_p(x, 1));
  }
  p_I /= p_I.sum();

  // mean_ln(i) = E_{p_I}[ln p_i(x | context)]
  Vector mean_ln(2);
  for (int i = 0; i < 2; ++i) {
    mean_ln(i) = 0.0;
    for (int x = 0; x < 3; ++x) {
      mean_ln(i) += p_I(x) * ln_p(x, i);
    }
  }

  Vector expected_gradient = mock.NegLNCorrectSum() + mean_ln;
  BOOST_CHECK_CLOSE(expected_gradient(0), gradient(0), 0.01);
  BOOST_CHECK_CLOSE(expected_gradient(1), gradient(1), 0.01);

  // Covariance of ln p_i under p_I.  Every entry is filled so the diagonal
  // (the variances) is verified as well as the off-diagonal terms.
  Matrix expected_hessian(2, 2);
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      Accum second_moment = 0.0;
      for (int x = 0; x < 3; ++x) {
        second_moment += p_I(x) * ln_p(x, i) * ln_p(x, j);
      }
      expected_hessian(i, j) = second_moment - mean_ln(i) * mean_ln(j);
    }
  }
  BOOST_CHECK_CLOSE(expected_hessian(0, 0), hessian(0, 0), 0.01);
  BOOST_CHECK_CLOSE(expected_hessian(1, 0), hessian(1, 0), 0.01);
  BOOST_CHECK_CLOSE(expected_hessian(0, 1), hessian(0, 1), 0.01);
  BOOST_CHECK_CLOSE(expected_hessian(1, 1), hessian(1, 1), 0.01);
}
}}} // namespaces
|