#include "model_buffer.hh"
#include "compare.hh"
#include "../state.hh"
#include "../weights.hh"
#include "../../util/exception.hh"
#include "../../util/file_stream.hh"
#include "../../util/file.hh"
#include "../../util/file_piece.hh"
#include "../../util/stream/io.hh"
#include "../../util/stream/multi_stream.hh"
#include <boost/lexical_cast.hpp>
#include <numeric>
namespace lm {
namespace {
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace
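
// Construct a buffer for writing. The vocabulary file persists as
// <file_base>.vocab only when keep_buffer is set; output_q records which
// payload tag ("q" or "pb") is written to the metadata.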
ModelBuffer::ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer), output_q_(output_q),
    vocab_file_(keep_buffer ? util::CreateOrThrow((file_base_ + ".vocab").c_str()) : util::MakeTemp(file_base_)) {}
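
// Reload a buffer previously written with keep_buffer: parse the
// .kenlm_intermediate metadata (header line, counts, payload type), then open
// the vocabulary and per-order files.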
ModelBuffer::ModelBuffer(StringPiece file_base)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(false) {
  const std::string full_name = file_base_ + ".kenlm_intermediate";
  util::FilePiece in(full_name.c_str());
  StringPiece token = in.ReadLine();
  UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);

  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Counts", "Expected Counts, got \"" << token << "\" in " << full_name);
  char got;
  while ((got = in.get()) == ' ') {
    counts_.push_back(in.ReadULong());
  }
  UTIL_THROW_IF2(got != '\n', "Expected newline at end of counts.");

  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
  token = in.ReadDelimited();
  if (token == "q") {
    output_q_ = true;
  } else if (token == "pb") {
    output_q_ = false;
  } else {
    UTIL_THROW(util::Exception, "Unknown payload " << token);
  }
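
  // The vocabulary and one file per n-gram order are expected alongside the metadata file.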
  vocab_file_.reset(util::OpenReadOrThrow((file_base_ + ".vocab").c_str()));

  files_.Init(counts_.size());
  for (unsigned long i = 0; i < counts_.size(); ++i) {
    files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
  }
}
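
// Route each chain's output to a file: named <file_base>.<order> when the
// buffer is kept, otherwise an anonymous temporary.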
void ModelBuffer::Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts) {
  counts_ = counts;
  // Open files.
  files_.Init(chains.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    if (keep_buffer_) {
      files_.push_back(util::CreateOrThrow(
            (file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()
            ));
    } else {
      files_.push_back(util::MakeTemp(file_base_));
    }
    chains[i] >> util::stream::Write(files_.back().get());
  }
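  // Record the metadata needed to reload this buffer with the reading constructor.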
  if (keep_buffer_) {
    util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
    util::FileStream meta(metadata.get(), 200);
    meta << kMetadataHeader << "\nCounts";
    for (std::vector<uint64_t>::const_iterator i = counts_.begin(); i != counts_.end(); ++i) {
      meta << ' ' << *i;
    }
    meta << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
  }
}
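
// Feed the buffered files back into a set of chains, one per n-gram order,
// with progress reporting based on file size.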
void ModelBuffer::Source(util::stream::Chains &chains) {
  assert(chains.size() <= files_.size());
  for (unsigned int i = 0; i < chains.size(); ++i) {
    chains[i].SetProgressTarget(util::SizeOrThrow(files_[i].get()));
    chains[i] >> util::stream::PRead(files_[i].get());
  }
}
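
// Stream a single order's file (0-based index) into the given chain.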
void ModelBuffer::Source(std::size_t order_minus_1, util::stream::Chain &chain) {
  chain >> util::stream::PRead(files_[order_minus_1].get());
}
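
// Slow lookup that binary-searches the per-order files directly on disk rather
// than using an in-memory structure; applies backoff when the full n-gram is absent.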
float ModelBuffer::SlowQuery(const ngram::State &context, WordIndex word, ngram::State &out) const {
  // Lookup unigram.
  ProbBackoff value;
  util::ErsatzPRead(RawFile(0), &value, sizeof(value), word * (sizeof(WordIndex) + sizeof(value)) + sizeof(WordIndex));
  out.backoff[0] = value.backoff;
  out.words[0] = word;
  out.length = 1;

  std::vector<WordIndex> buffer(context.length + 1), query(context.length + 1);
  std::reverse_copy(context.words, context.words + context.length, query.begin());
  query[context.length] = word;
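
  // For each order from bigram upward, search that order's file (assumed sorted
  // by SuffixOrder) for the n-gram given by the last `order` entries of query.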
  for (std::size_t order = 2; order <= query.size() && order <= context.length + 1; ++order) {
    SuffixOrder less(order);
    const WordIndex *key = &*query.end() - order;
    int file = RawFile(order - 1);
    std::size_t length = order * sizeof(WordIndex) + sizeof(ProbBackoff);
    // TODO: cache file size?
    uint64_t begin = 0, end = util::SizeOrThrow(file) / length;
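    // Binary search over fixed-size records: `order` word indices followed by a ProbBackoff.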
    while (true) {
      if (end <= begin) {
        // Did not find for order.
        return std::accumulate(context.backoff + out.length - 1, context.backoff + context.length, value.prob);
      }
      uint64_t test = begin + (end - begin) / 2;
      util::ErsatzPRead(file, &*buffer.begin(), sizeof(WordIndex) * order, test * length);

      if (less(&*buffer.begin(), key)) {
        begin = test + 1;
      } else if (less(key, &*buffer.begin())) {
        end = test;
      } else {
        // Found it.
        util::ErsatzPRead(file, &value, sizeof(value), test * length + sizeof(WordIndex) * order);
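        // Unless this is the highest order, record the backoff in the output state.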
        if (order != Order()) {
          out.length = order;
          out.backoff[order - 1] = value.backoff;
          out.words[order - 1] = *key;
        }
        break;
      }
    }
  }
  return value.prob;
}
} // namespace lm