#define BOOST_TEST_MODULE InterpolateMergeVocabTest #include #include "../enumerate_vocab.hh" #include "merge_vocab.hh" #include "universal_vocab.hh" #include "../lm_exception.hh" #include "../vocab.hh" #include "../word_index.hh" #include "../../util/file.hh" #include "../../util/file_piece.hh" #include "../../util/file_stream.hh" #include "../../util/tokenize_piece.hh" #include #include #include namespace lm { namespace interpolate { namespace { struct VocabEntry { explicit VocabEntry(StringPiece value) : str(value), hash(util::MurmurHash64A(value.data(), value.size())) {} StringPiece str; uint64_t hash; bool operator<(const VocabEntry &other) const { return hash < other.hash; } }; int WriteVocabFile(const std::vector &vocab, util::scoped_fd &file) { file.reset(util::MakeTemp(util::DefaultTempDirectory())); { util::FileStream out(file.get(), 128); for (std::vector::const_iterator i = vocab.begin(); i != vocab.end(); ++i) { out << i->str << '\0'; } } util::SeekOrThrow(file.get(), 0); return file.get(); } std::vector ParseVocab(StringPiece words) { std::vector entries; entries.push_back(VocabEntry("")); for (util::TokenIter i(words, '\t'); i; ++i) { entries.push_back(VocabEntry(*i)); } std::sort(entries.begin() + 1, entries.end()); return entries; } int WriteVocabFile(StringPiece words, util::scoped_fd &file) { return WriteVocabFile(ParseVocab(words), file); } class TestFiles { public: TestFiles() {} int Test0() { return WriteVocabFile("this\tis\ta\tfirst\tcut", test[0]); } int Test1() { return WriteVocabFile("is this\tthis a\tfirst cut\ta first", test[1]); } int Test2() { return WriteVocabFile("is\tsecd\ti", test[2]); } int NoUNK() { std::vector no_unk_vec; no_unk_vec.push_back(VocabEntry("toto")); return WriteVocabFile(no_unk_vec, no_unk); } int BadOrder() { std::vector bad_order_vec; bad_order_vec.push_back(VocabEntry("")); bad_order_vec.push_back(VocabEntry("0")); bad_order_vec.push_back(VocabEntry("1")); bad_order_vec.push_back(VocabEntry("2")); bad_order_vec.push_back(VocabEntry("a")); return WriteVocabFile(bad_order_vec, bad_order); } private: util::scoped_fd test[3], no_unk, bad_order; }; class DoNothingEnumerate : public EnumerateVocab { public: void Add(WordIndex, const StringPiece &) {} }; BOOST_AUTO_TEST_CASE(MergeVocabTest) { TestFiles files; util::FixedArray used_files(3); used_files.push_back(files.Test0()); used_files.push_back(files.Test1()); used_files.push_back(files.Test2()); std::vector model_max_idx; model_max_idx.push_back(10); model_max_idx.push_back(10); model_max_idx.push_back(10); util::scoped_fd combined(util::MakeTemp(util::DefaultTempDirectory())); UniversalVocab universal_vocab(model_max_idx); { ngram::ImmediateWriteWordsWrapper writer(NULL, combined.get(), 0); MergeVocab(used_files, universal_vocab, writer); } BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 0), 0); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 0), 0); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 0), 0); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 1), 1); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 1), 2); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 1), 8); BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 5), 11); #if BYTE_ORDER == LITTLE_ENDIAN BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 3), 4); #elif BYTE_ORDER == BIG_ENDIAN // MurmurHash has a different ordering of the vocabulary. BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 3), 5); #endif BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 3), 10); util::SeekOrThrow(combined.get(), 0); util::FilePiece f(combined.release()); std::vector expected = ParseVocab("a\tis this\tthis a\tfirst cut\tthis\ta first\tcut\tis\ti\tsecd\tfirst"); for (std::vector::const_iterator i = expected.begin(); i != expected.end(); ++i) { BOOST_CHECK_EQUAL(i->str, f.ReadLine('\0')); } BOOST_CHECK_THROW(f.ReadLine('\0'), util::EndOfFileException); } BOOST_AUTO_TEST_CASE(MergeVocabNoUnkTest) { TestFiles files; util::FixedArray used_files(1); used_files.push_back(files.NoUNK()); std::vector model_max_idx; model_max_idx.push_back(10); UniversalVocab universal_vocab(model_max_idx); DoNothingEnumerate nothing; BOOST_CHECK_THROW(MergeVocab(used_files, universal_vocab, nothing), FormatLoadException); } BOOST_AUTO_TEST_CASE(MergeVocabWrongOrderTest) { TestFiles files; util::FixedArray used_files(2); used_files.push_back(files.Test0()); used_files.push_back(files.BadOrder()); std::vector model_max_idx; model_max_idx.push_back(10); model_max_idx.push_back(10); lm::interpolate::UniversalVocab universal_vocab(model_max_idx); DoNothingEnumerate nothing; BOOST_CHECK_THROW(MergeVocab(used_files, universal_vocab, nothing), FormatLoadException); } }}} // namespaces