#include "model.hh" #include "../util/file_stream.hh" #include "../util/file.hh" #include "../util/file_piece.hh" #include "../util/usage.hh" #include "../util/thread_pool.hh" #include #include #include #include namespace { template void ConvertToBytes(const Model &model, int fd_in) { util::FilePiece in(fd_in); util::FileStream out(1); Width width; StringPiece word; const Width end_sentence = (Width)model.GetVocabulary().EndSentence(); while (true) { while (in.ReadWordSameLine(word)) { width = (Width)model.GetVocabulary().Index(word); out.write(&width, sizeof(Width)); } if (!in.ReadLineOrEOF(word)) break; out.write(&end_sentence, sizeof(Width)); } } template class Worker { public: explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {} // Destructors happen in the main thread, so there's no race for add_total_. ~Worker() { add_total_ += total_; } typedef boost::iterator_range Request; void operator()(Request request) { const lm::ngram::State *const begin_state = &model_.BeginSentenceState(); const lm::ngram::State *next_state = begin_state; const Width kEOS = model_.GetVocabulary().EndSentence(); float sum = 0.0; // Do even stuff first. const Width *even_end = request.begin() + (request.size() & ~1); // Alternating states const Width *i; for (i = request.begin(); i != even_end;) { sum += model_.FullScore(*next_state, *i, state_[1]).prob; next_state = (*i++ == kEOS) ? begin_state : &state_[1]; sum += model_.FullScore(*next_state, *i, state_[0]).prob; next_state = (*i++ == kEOS) ? begin_state : &state_[0]; } // Odd corner case. if (request.size() & 1) { sum += model_.FullScore(*next_state, *i, state_[2]).prob; next_state = (*i++ == kEOS) ? begin_state : &state_[2]; } total_ += sum; } private: const Model &model_; double total_; double &add_total_; lm::ngram::State state_[3]; }; struct Config { int fd_in; std::size_t threads; std::size_t buf_per_thread; bool query; }; template void QueryFromBytes(const Model &model, const Config &config) { util::FileStream out(1); out << "Threads: " << config.threads << '\n'; const Width kEOS = model.GetVocabulary().EndSentence(); double total = 0.0; // Number of items to have in queue in addition to everything in flight. const std::size_t kInQueue = 3; std::size_t total_queue = config.threads + kInQueue; std::vector backing(config.buf_per_thread * total_queue); double loaded_cpu; double loaded_wall; uint64_t queries = 0; { util::RecyclingThreadPool > pool(total_queue, config.threads, Worker(model, total), boost::iterator_range((Width*)0, (Width*)0)); for (std::size_t i = 0; i < total_queue; ++i) { pool.PopulateRecycling(boost::iterator_range(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread])); } loaded_cpu = util::CPUTime(); loaded_wall = util::WallTime(); out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n'; boost::iterator_range overhang((Width*)0, (Width*)0); while (true) { boost::iterator_range buf = pool.Consume(); std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width)); std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width)); if (!got && overhang.empty()) break; UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width)); Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width); Width *last_eos; for (last_eos = read_end - 1; ; --last_eos) { UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words. Rerun with increased buffer size. TODO: adaptable buffer"); if (*last_eos == kEOS) break; } buf = boost::iterator_range(buf.begin(), last_eos + 1); overhang = boost::iterator_range(last_eos + 1, read_end); queries += buf.size(); pool.Produce(buf); } } // Drain pool. double after_cpu = util::CPUTime(); double after_wall = util::WallTime(); util::FileStream(2, 70) << "Probability sum: " << total << '\n'; out << "Queries: " << queries << '\n'; out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n'; double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast(queries)); double wall_per_entry = ((after_wall - loaded_wall) / static_cast(queries)); out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n'; out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n'; out << "RSSMax: " << util::RSSMax() << '\n'; } template void DispatchFunction(const Model &model, const Config &config) { if (config.query) { QueryFromBytes(model, config); } else { ConvertToBytes(model, config.fd_in); } } template void DispatchWidth(const char *file, const Config &config) { lm::ngram::Config model_config; model_config.load_method = util::READ; Model model(file, model_config); uint64_t bound = model.GetVocabulary().Bound(); if (bound <= 256) { DispatchFunction(model, config); } else if (bound <= 65536) { DispatchFunction(model, config); } else if (bound <= (1ULL << 32)) { DispatchFunction(model, config); } else { DispatchFunction(model, config); } } void Dispatch(const char *file, const Config &config) { using namespace lm::ngram; lm::ngram::ModelType model_type; if (lm::ngram::RecognizeBinary(file, model_type)) { switch(model_type) { case PROBING: DispatchWidth(file, config); break; case REST_PROBING: DispatchWidth(file, config); break; case TRIE: DispatchWidth(file, config); break; case QUANT_TRIE: DispatchWidth(file, config); break; case ARRAY_TRIE: DispatchWidth(file, config); break; case QUANT_ARRAY_TRIE: DispatchWidth(file, config); break; default: UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type); } } else { UTIL_THROW(util::Exception, "Binarize before running benchmarks."); } } } // namespace int main(int argc, char *argv[]) { try { Config config; config.fd_in = 0; std::string model; namespace po = boost::program_options; po::options_description options("Benchmark options"); options.add_options() ("help,h", po::bool_switch(), "Show help message") ("model,m", po::value(&model)->required(), "Model to query or convert vocab ids") ("threads,t", po::value(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)") ("buffer,b", po::value(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.") ("vocab,v", po::bool_switch(), "Convert strings to vocab ids") ("query,q", po::bool_switch(), "Query from vocab ids"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, options), vm); if (argc == 1 || vm["help"].as()) { std::cerr << "Benchmark program for KenLM. Intended usage:\n" << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n" << argv[0] << " -v -m $model <$text >$text.vocab\n" << "#Ensure files are in RAM.\n" << "cat $text.vocab $model >/dev/null\n" << "#Timed query against the model.\n" << argv[0] << " -q -m $model <$text.vocab\n"; return 0; } po::notify(vm); if (!(vm["vocab"].as() ^ vm["query"].as())) { std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl; return 0; } config.query = vm["query"].as(); if (!config.threads) { std::cerr << "Specify a non-zero number of threads with -t." << std::endl; } Dispatch(model.c_str(), config); } catch (const std::exception &e) { std::cerr << e.what() << std::endl; return 1; } return 0; }