File size: 2,758 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
#include "normalize.hh"
#include "interpolate_info.hh"
#include "merge_probabilities.hh"
#include "../common/ngram_stream.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/multi_stream.hh"
#define BOOST_TEST_MODULE NormalizeTest
#include <boost/test/unit_test.hpp>
namespace lm { namespace interpolate { namespace {
// Base-10 log values without backoff, one per word index: WriteInput emits
// them in array order and CheckOutput exponentiates them with pow(10.0, x).
// The last entry is <s>; CheckOutput leaves it out of the normalization sum
// but still expects a record for it on the output stream.
const float kInputs[] = {-0.3, 1.2, -9.8, 4.0, -7.0, 0.0};
// Chain worker (attached with `chain >> WriteInput()`) that emits one record
// per kInputs entry and then poisons the stream.  Each record starts with the
// WordIndex of the entry, immediately followed by its float value; any further
// bytes of the record are left untouched.
class WriteInput {
  public:
    WriteInput() {}

    void Run(const util::stream::ChainPosition &to) {
      const WordIndex kCount = sizeof(kInputs) / sizeof(float);
      util::stream::Stream out(to);
      WordIndex word = 0;
      while (word < kCount) {
        uint8_t *record = static_cast<uint8_t*>(out.Get());
        // Word id first, value right behind it.
        memcpy(record, &word, sizeof(WordIndex));
        memcpy(record + sizeof(WordIndex), &kInputs[word], sizeof(float));
        ++word;
        ++out;
      }
      // Signal end of input to whoever reads this chain.
      out.Poison();
    }
};
// Reads the normalized stream at `from` and verifies it record by record.
//
// The expected value for entry i is kInputs[i] - log10(sum_j 10^kInputs[j]),
// where the sum skips the last entry of kInputs (<s>).  Exactly one record per
// kInputs entry must be present; a missing record aborts the test
// (BOOST_REQUIRE) and an extra one is reported as a failure.
//
// The number of checked records is derived from kInputs rather than spelled
// out index by index, so the checks stay in step with the table.
void CheckOutput(const util::stream::ChainPosition &from) {
  const WordIndex kCount = sizeof(kInputs) / sizeof(float);
  NGramStream<float> in(from);

  // Normalization constant in log10 space.  Kept in float, exactly as the
  // comparison below uses it.
  float sum = 0.0;
  for (WordIndex i = 0; i < kCount - 1 /* <s> at the end */; ++i) {
    sum += pow(10.0, kInputs[i]);
  }
  sum = log10(sum);

  for (WordIndex i = 0; i < kCount; ++i) {
    // The stream already points at its first record; advance only between
    // records.
    if (i != 0) ++in;
    BOOST_REQUIRE(in);
    BOOST_CHECK_CLOSE(kInputs[i] - sum, in->Value(), 0.0001);
  }
  // Nothing may follow the last expected record.
  BOOST_CHECK(!++in);
}
// End-to-end check of Normalize on unigrams only: WriteInput feeds kInputs
// into the merged-probabilities chain, Normalize writes to probabilities_out,
// and CheckOutput verifies what comes out.
BOOST_AUTO_TEST_CASE(Unigrams) {
// Two entries in info, both of order 1, with lambdas 2.0 and -0.1.
InterpolateInfo info;
info.lambdas.push_back(2.0);
info.lambdas.push_back(-0.1);
info.orders.push_back(1);
info.orders.push_back(1);
// For order 1 the encoder is expected to report an encoded length of zero.
BOOST_CHECK_EQUAL(0, MakeEncoder(info, 1).EncodedLength());
// No backoffs.
util::stream::Chains blank(0);
// Two slots (matching the two entries in info), both filled from the empty
// `blank` chains.
util::FixedArray<util::stream::ChainPositions> models_by_order(2);
models_by_order.push_back(blank);
models_by_order.push_back(blank);
// One chain in, one chain out, and no backoff output chains.
util::stream::Chains merged_probabilities(1);
util::stream::Chains probabilities_out(1);
util::stream::Chains backoffs_out(0);
// NOTE(review): the ChainConfig arguments look like (entry size, block count,
// total memory) -- confirm against util/stream.
// Input entries reserve room for a WordIndex and two floats; WriteInput only
// fills the WordIndex and the first float.
merged_probabilities.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float) + sizeof(float), 2, 24));
// Output entries hold a WordIndex and a single float.
probabilities_out.push_back(util::stream::ChainConfig(sizeof(WordIndex) + sizeof(float), 2, 100));
// Attach the producer before Normalize is hooked onto the chains.
merged_probabilities[0] >> WriteInput();
Normalize(info, models_by_order, merged_probabilities, probabilities_out, backoffs_out);
// Take the read position for CheckOutput before the recycle stage is
// appended to the output chain.
util::stream::ChainPosition checker(probabilities_out[0].Add());
merged_probabilities >> util::stream::kRecycle;
probabilities_out >> util::stream::kRecycle;
CheckOutput(checker);
probabilities_out.Wait();
}
}}} // namespaces
|