xls-r-300m-sv-robust / kenlm /lm /kenlm_benchmark_main.cc
marinone94's picture
Training in progress, epoch 0
1ce325b
#include "model.hh"
#include "../util/file_stream.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include "../util/thread_pool.hh"
#include <boost/range/iterator_range.hpp>
#include <boost/program_options.hpp>
#include <iostream>
#include <stdint.h>
namespace {
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
util::FilePiece in(fd_in);
util::FileStream out(1);
Width width;
StringPiece word;
const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
while (true) {
while (in.ReadWordSameLine(word)) {
width = (Width)model.GetVocabulary().Index(word);
out.write(&width, sizeof(Width));
}
if (!in.ReadLineOrEOF(word)) break;
out.write(&end_sentence, sizeof(Width));
}
}
template <class Model, class Width> class Worker {
public:
explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {}
// Destructors happen in the main thread, so there's no race for add_total_.
~Worker() { add_total_ += total_; }
typedef boost::iterator_range<Width *> Request;
void operator()(Request request) {
const lm::ngram::State *const begin_state = &model_.BeginSentenceState();
const lm::ngram::State *next_state = begin_state;
const Width kEOS = model_.GetVocabulary().EndSentence();
float sum = 0.0;
// Do even stuff first.
const Width *even_end = request.begin() + (request.size() & ~1);
// Alternating states
const Width *i;
for (i = request.begin(); i != even_end;) {
sum += model_.FullScore(*next_state, *i, state_[1]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[1];
sum += model_.FullScore(*next_state, *i, state_[0]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[0];
}
// Odd corner case.
if (request.size() & 1) {
sum += model_.FullScore(*next_state, *i, state_[2]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[2];
}
total_ += sum;
}
private:
const Model &model_;
double total_;
double &add_total_;
lm::ngram::State state_[3];
};
struct Config {
int fd_in;
std::size_t threads;
std::size_t buf_per_thread;
bool query;
};
template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) {
util::FileStream out(1);
out << "Threads: " << config.threads << '\n';
const Width kEOS = model.GetVocabulary().EndSentence();
double total = 0.0;
// Number of items to have in queue in addition to everything in flight.
const std::size_t kInQueue = 3;
std::size_t total_queue = config.threads + kInQueue;
std::vector<Width> backing(config.buf_per_thread * total_queue);
double loaded_cpu;
double loaded_wall;
uint64_t queries = 0;
{
util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0));
for (std::size_t i = 0; i < total_queue; ++i) {
pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread]));
}
loaded_cpu = util::CPUTime();
loaded_wall = util::WallTime();
out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n';
boost::iterator_range<Width *> overhang((Width*)0, (Width*)0);
while (true) {
boost::iterator_range<Width *> buf = pool.Consume();
std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width));
std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width));
if (!got && overhang.empty()) break;
UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width);
Width *last_eos;
for (last_eos = read_end - 1; ; --last_eos) {
UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words. Rerun with increased buffer size. TODO: adaptable buffer");
if (*last_eos == kEOS) break;
}
buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1);
overhang = boost::iterator_range<Width*>(last_eos + 1, read_end);
queries += buf.size();
pool.Produce(buf);
}
} // Drain pool.
double after_cpu = util::CPUTime();
double after_wall = util::WallTime();
util::FileStream(2, 70) << "Probability sum: " << total << '\n';
out << "Queries: " << queries << '\n';
out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n';
double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries));
double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries));
out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n';
out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n';
out << "RSSMax: " << util::RSSMax() << '\n';
}
template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) {
if (config.query) {
QueryFromBytes<Model, Width>(model, config);
} else {
ConvertToBytes<Model, Width>(model, config.fd_in);
}
}
template <class Model> void DispatchWidth(const char *file, const Config &config) {
lm::ngram::Config model_config;
model_config.load_method = util::READ;
Model model(file, model_config);
uint64_t bound = model.GetVocabulary().Bound();
if (bound <= 256) {
DispatchFunction<Model, uint8_t>(model, config);
} else if (bound <= 65536) {
DispatchFunction<Model, uint16_t>(model, config);
} else if (bound <= (1ULL << 32)) {
DispatchFunction<Model, uint32_t>(model, config);
} else {
DispatchFunction<Model, uint64_t>(model, config);
}
}
void Dispatch(const char *file, const Config &config) {
using namespace lm::ngram;
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
DispatchWidth<lm::ngram::ProbingModel>(file, config);
break;
case REST_PROBING:
DispatchWidth<lm::ngram::RestProbingModel>(file, config);
break;
case TRIE:
DispatchWidth<lm::ngram::TrieModel>(file, config);
break;
case QUANT_TRIE:
DispatchWidth<lm::ngram::QuantTrieModel>(file, config);
break;
case ARRAY_TRIE:
DispatchWidth<lm::ngram::ArrayTrieModel>(file, config);
break;
case QUANT_ARRAY_TRIE:
DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config);
break;
default:
UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
}
} else {
UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
}
}
} // namespace
int main(int argc, char *argv[]) {
try {
Config config;
config.fd_in = 0;
std::string model;
namespace po = boost::program_options;
po::options_description options("Benchmark options");
options.add_options()
("help,h", po::bool_switch(), "Show help message")
("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids")
("threads,t", po::value<std::size_t>(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)")
("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.")
("vocab,v", po::bool_switch(), "Convert strings to vocab ids")
("query,q", po::bool_switch(), "Query from vocab ids");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, options), vm);
if (argc == 1 || vm["help"].as<bool>()) {
std::cerr << "Benchmark program for KenLM. Intended usage:\n"
<< "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
<< argv[0] << " -v -m $model <$text >$text.vocab\n"
<< "#Ensure files are in RAM.\n"
<< "cat $text.vocab $model >/dev/null\n"
<< "#Timed query against the model.\n"
<< argv[0] << " -q -m $model <$text.vocab\n";
return 0;
}
po::notify(vm);
if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) {
std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl;
return 0;
}
config.query = vm["query"].as<bool>();
if (!config.threads) {
std::cerr << "Specify a non-zero number of threads with -t." << std::endl;
}
Dispatch(model.c_str(), config);
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
return 1;
}
return 0;
}