#define BOOST_TEST_MODULE InterpolateMergeVocabTest
#include <boost/test/unit_test.hpp>

#include "lm/enumerate_vocab.hh"
#include "lm/interpolate/merge_vocab.hh"
#include "lm/interpolate/universal_vocab.hh"
#include "lm/lm_exception.hh"
#include "lm/vocab.hh"
#include "lm/word_index.hh"
#include "util/file.hh"
#include "util/file_piece.hh"
#include "util/file_stream.hh"
#include "util/fixed_array.hh"
#include "util/murmur_hash.hh"
#include "util/tokenize_piece.hh"

#include <algorithm>
#include <cstring>
#include <vector>

namespace lm {
namespace interpolate {
namespace {
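
// A vocabulary word paired with its 64-bit MurmurHash.  MergeVocab expects
// each input vocabulary sorted by hash (see the BadOrder fixture and the
// MergeVocabWrongOrderTest below), so the fixtures sort the same way.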
struct VocabEntry {
  explicit VocabEntry(StringPiece value) :
    str(value), hash(util::MurmurHash64A(value.data(), value.size())) {}

  StringPiece str;
  uint64_t hash;

  bool operator<(const VocabEntry &other) const {
    return hash < other.hash;
  }
};
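
// Writes the entries to a temporary file as NUL-delimited strings, the
// layout MergeVocab reads, then rewinds and returns the descriptor.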
int WriteVocabFile(const std::vector<VocabEntry> &vocab, util::scoped_fd &file) {
  file.reset(util::MakeTemp(util::DefaultTempDirectory()));
  {
    util::FileStream out(file.get(), 128);
    for (std::vector<VocabEntry>::const_iterator i = vocab.begin(); i != vocab.end(); ++i) {
      out << i->str << '\0';
    }
  }
  util::SeekOrThrow(file.get(), 0);
  return file.get();
}
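
// Splits a tab-delimited word list, prepends <unk>, and sorts everything
// after <unk> by hash so the fixture mimics a real vocabulary file.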
std::vector<VocabEntry> ParseVocab(StringPiece words) {
  std::vector<VocabEntry> entries;
  entries.push_back(VocabEntry("<unk>"));
  for (util::TokenIter<util::SingleCharacter> i(words, '\t'); i; ++i) {
    entries.push_back(VocabEntry(*i));
  }
  std::sort(entries.begin() + 1, entries.end());
  return entries;
}
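
// Convenience overload: parse a tab-delimited list and write it out.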
int WriteVocabFile(StringPiece words, util::scoped_fd &file) {
  return WriteVocabFile(ParseVocab(words), file);
}
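
// Fixture owning the temporary files backing each test vocabulary; the
// descriptors stay valid for the lifetime of the fixture.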
class TestFiles {
  public:
    TestFiles() {}

    int Test0() {
      return WriteVocabFile("this\tis\ta\tfirst\tcut", test[0]);
    }
    int Test1() {
      return WriteVocabFile("is this\tthis a\tfirst cut\ta first", test[1]);
    }
    int Test2() {
      return WriteVocabFile("is\tsecd\ti", test[2]);
    }

    // Deliberately omits <unk>, which MergeVocab expects as the first entry.
    int NoUNK() {
      std::vector<VocabEntry> no_unk_vec;
      no_unk_vec.push_back(VocabEntry("toto"));
      return WriteVocabFile(no_unk_vec, no_unk);
    }

    // Entries pushed in lexical order, which is not MurmurHash order, so the
    // resulting file violates the sorted-by-hash requirement.
    int BadOrder() {
      std::vector<VocabEntry> bad_order_vec;
      bad_order_vec.push_back(VocabEntry("<unk>"));
      bad_order_vec.push_back(VocabEntry("0"));
      bad_order_vec.push_back(VocabEntry("1"));
      bad_order_vec.push_back(VocabEntry("2"));
      bad_order_vec.push_back(VocabEntry("a"));
      return WriteVocabFile(bad_order_vec, bad_order);
    }

  private:
    util::scoped_fd test[3], no_unk, bad_order;
};
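
// EnumerateVocab callback that discards the merged words; used by the tests
// that only care about exception behavior.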
class DoNothingEnumerate : public EnumerateVocab {
  public:
    void Add(WordIndex, const StringPiece &) {}
};
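
// Merges three vocabularies, checks that each (model, local index) pair maps
// to the expected universal index, then verifies the combined word list the
// wrapper wrote.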
BOOST_AUTO_TEST_CASE(MergeVocabTest) {
  TestFiles files;
  util::FixedArray<int> used_files(3);
  used_files.push_back(files.Test0());
  used_files.push_back(files.Test1());
  used_files.push_back(files.Test2());

  std::vector<lm::WordIndex> model_max_idx;
  model_max_idx.push_back(10);
  model_max_idx.push_back(10);
  model_max_idx.push_back(10);

  util::scoped_fd combined(util::MakeTemp(util::DefaultTempDirectory()));

  UniversalVocab universal_vocab(model_max_idx);
  {
    ngram::ImmediateWriteWordsWrapper writer(NULL, combined.get(), 0);
    MergeVocab(used_files, universal_vocab, writer);
  }
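
  // GetUniversalIdx(model, local index) yields the word's index in the merged
  // vocabulary.  <unk> is local index 0 in every model and maps to universal
  // index 0.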
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 0), 0);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 0), 0);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 0), 0);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 1), 1);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 1), 2);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 1), 8);
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(0, 5), 11);
#if BYTE_ORDER == LITTLE_ENDIAN
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 3), 4);
#elif BYTE_ORDER == BIG_ENDIAN
  // MurmurHash values differ by byte order, so the vocabulary sorts
  // differently on big-endian machines.
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(1, 3), 5);
#endif
  BOOST_CHECK_EQUAL(universal_vocab.GetUniversalIdx(2, 3), 10);

  // The merged file should contain every distinct word exactly once, in the
  // same hash order that ParseVocab produces, with nothing left over.
  util::SeekOrThrow(combined.get(), 0);
  util::FilePiece f(combined.release());
  std::vector<VocabEntry> expected = ParseVocab("a\tis this\tthis a\tfirst cut\tthis\ta first\tcut\tis\ti\tsecd\tfirst");
  for (std::vector<VocabEntry>::const_iterator i = expected.begin(); i != expected.end(); ++i) {
    BOOST_CHECK_EQUAL(i->str, f.ReadLine('\0'));
  }
  BOOST_CHECK_THROW(f.ReadLine('\0'), util::EndOfFileException);
}
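
// A vocabulary file that does not start with <unk> is rejected.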
BOOST_AUTO_TEST_CASE(MergeVocabNoUnkTest) {
  TestFiles files;
  util::FixedArray<int> used_files(1);
  used_files.push_back(files.NoUNK());

  std::vector<lm::WordIndex> model_max_idx;
  model_max_idx.push_back(10);

  UniversalVocab universal_vocab(model_max_idx);
  DoNothingEnumerate nothing;
  BOOST_CHECK_THROW(MergeVocab(used_files, universal_vocab, nothing), FormatLoadException);
}
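
// A vocabulary whose entries are not sorted by hash is rejected.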
BOOST_AUTO_TEST_CASE(MergeVocabWrongOrderTest) {
  TestFiles files;
  util::FixedArray<int> used_files(2);
  used_files.push_back(files.Test0());
  used_files.push_back(files.BadOrder());

  std::vector<lm::WordIndex> model_max_idx;
  model_max_idx.push_back(10);
  model_max_idx.push_back(10);

  UniversalVocab universal_vocab(model_max_idx);
  DoNothingEnumerate nothing;
  BOOST_CHECK_THROW(MergeVocab(used_files, universal_vocab, nothing), FormatLoadException);
}
}}} // namespaces