#include "backoff_reunification.hh"
#include "../common/ngram_stream.hh"

#define BOOST_TEST_MODULE InterpolateBackoffReunificationTest
#include <boost/test/unit_test.hpp>

#include <algorithm>
#include <cstddef>
#include <sstream>
#include <stdint.h>

namespace lm {
namespace interpolate {

namespace {

// none of this input actually makes sense, all we care about is making
// sure the merging works

/**
 * One fixture record of order N: the word ids of the n-gram together with the
 * probability and backoff values that the test expects to see paired with it.
 */
template <uint8_t N> struct Gram {
  WordIndex ids[N];
  float prob;
  float boff;
};

/**
 * Holder for the fixture table of order N. The array is declared here without
 * a bound and defined per order by the explicit specialisations below.
 */
template <uint8_t N> struct Grams {
  const static Gram<N> grams[];
};

template <>
const Gram<1> Grams<1>::grams[] = {
    {{0}, -0.1f, -0.1f},
    {{1}, -0.4f, -0.2f},
    {{2}, -0.5f, -0.1f}};

template <>
const Gram<2> Grams<2>::grams[] = {
    {{0, 0}, -0.05f, -0.05f},
    {{1, 0}, -0.05f, -0.02f},
    {{1, 1}, -0.2f, -0.04f},
    {{2, 2}, -0.2f, -0.01f}};

template <>
const Gram<3> Grams<3>::grams[] = {
    {{0, 0, 0}, -0.001f, -0.005f},
    {{1, 0, 0}, -0.001f, -0.002f},
    {{2, 0, 0}, -0.001f, -0.003f},
    {{0, 1, 0}, -0.1f, -0.008f},
    {{1, 1, 0}, -0.1f, -0.09f},
    {{1, 1, 1}, -0.2f, -0.08f}};

/**
 * Chain worker that writes the ids and probability of every fixture record of
 * order N into the chain, in table order, then poisons the stream.
 */
template <uint8_t N> class WriteInput {
public:
  void Run(const util::stream::ChainPosition &position) {
    lm::NGramStream<float> output(position);

    const std::size_t count = sizeof(Grams<N>::grams) / sizeof(Gram<N>);
    for (std::size_t i = 0; i < count; ++i, ++output) {
      std::copy(Grams<N>::grams[i].ids, Grams<N>::grams[i].ids + N,
                output->begin());
      output->Value() = Grams<N>::grams[i].prob;
    }
    output.Poison();
  }
};

/**
 * Chain worker that writes only the backoff of every fixture record of order
 * N, one float per entry, in table order, then poisons the stream.
 */
template <uint8_t N> class WriteBackoffs {
public:
  void Run(const util::stream::ChainPosition &position) {
    util::stream::Stream output(position);

    const std::size_t count = sizeof(Grams<N>::grams) / sizeof(Gram<N>);
    for (std::size_t i = 0; i < count; ++i, ++output) {
      // Each entry of this chain is sized as a single float by the test case.
      *reinterpret_cast<float *>(output.Get()) = Grams<N>::grams[i].boff;
    }
    output.Poison();
  }
};

/**
 * Chain worker that reads the output records of order N and checks each one
 * (ids, probability, backoff) against the fixture table, then checks that the
 * number of records read equals the table size.
 */
template <uint8_t N> class CheckOutput {
public:
  void Run(const util::stream::ChainPosition &position) {
    lm::NGramStream<ProbBackoff> stream(position);

    const std::size_t expected = sizeof(Grams<N>::grams) / sizeof(Gram<N>);

    std::size_t i = 0;
    for (; stream; ++stream, ++i) {
      // Surplus records must not index past the fixture table. Keep draining
      // the stream so the count check below reports the mismatch.
      if (i >= expected) continue;

      const Gram<N> &want = Grams<N>::grams[i];

      // Render the ids actually received so a failure shows what arrived.
      std::stringstream ss;
      for (WordIndex *idx = stream->begin(); idx != stream->end(); ++idx)
        ss << "(" << *idx << ")";

      BOOST_CHECK_MESSAGE(
          std::equal(stream->begin(), stream->end(), want.ids),
          "Mismatched id in CheckOutput<" << (int)N << ">: " << ss.str());

      // BOOST_CHECK_EQUAL already prints both operands on failure.
      BOOST_CHECK_EQUAL(stream->Value().prob, want.prob);
      BOOST_CHECK_EQUAL(stream->Value().backoff, want.boff);
    }
    BOOST_CHECK_EQUAL(i, expected);
  }
};

} // namespace

/**
 * Feeds probability records and backoff values for orders 1..3 on separate
 * chains into ReunifyBackoff and verifies that each output chain carries the
 * ids, probability and backoff of the corresponding fixture table.
 */
BOOST_AUTO_TEST_CASE(BackoffReunificationTest) {
  util::stream::ChainConfig config;
  config.total_memory = 100;
  config.block_count = 1;

  // Probability inputs: one chain per order.
  util::stream::Chains prob_chains(3);
  config.entry_size = NGram<float>::TotalSize(1);
  prob_chains.push_back(config);
  prob_chains.back() >> WriteInput<1>();

  config.entry_size = NGram<float>::TotalSize(2);
  prob_chains.push_back(config);
  prob_chains.back() >> WriteInput<2>();

  config.entry_size = NGram<float>::TotalSize(3);
  prob_chains.push_back(config);
  prob_chains.back() >> WriteInput<3>();

  // Backoff inputs: one float per entry, one chain per order.
  util::stream::Chains boff_chains(3);
  config.entry_size = sizeof(float);
  boff_chains.push_back(config);
  boff_chains.back() >> WriteBackoffs<1>();

  boff_chains.push_back(config);
  boff_chains.back() >> WriteBackoffs<2>();

  boff_chains.push_back(config);
  boff_chains.back() >> WriteBackoffs<3>();

  util::stream::ChainPositions prob_pos(prob_chains);
  util::stream::ChainPositions boff_pos(boff_chains);

  // Output chains: entries hold both probability and backoff.
  util::stream::Chains output_chains(3);
  for (std::size_t i = 0; i < 3; ++i) {
    config.entry_size = NGram<ProbBackoff>::TotalSize(i + 1);
    output_chains.push_back(config);
  }

  ReunifyBackoff(prob_pos, boff_pos, output_chains);
  output_chains[0] >> CheckOutput<1>();
  output_chains[1] >> CheckOutput<2>();
  output_chains[2] >> CheckOutput<3>();

  prob_chains >> util::stream::kRecycle;
  boff_chains >> util::stream::kRecycle;

  output_chains.Wait();
}

} // namespace interpolate
} // namespace lm