#ifndef LM_FILTER_THREAD_H
#define LM_FILTER_THREAD_H

#include "../../util/thread_pool.hh"

#include <boost/noncopyable.hpp>
#include <boost/ref.hpp>
#include <boost/utility/in_place_factory.hpp>

#include <cassert>
#include <deque>
#include <stack>
#include <vector>

#include <stdint.h>

namespace lm {

// A batch of n-grams moving through the pipeline: the reading thread fills
// input_, a filter worker writes matches to output_, and the writing thread
// flushes output_ in sequence order.
template <class OutputBuffer> class ThreadBatch {
  public:
    ThreadBatch() {}

    void Reserve(size_t size) {
      input_.Reserve(size);
      output_.Reserve(size);
    }

    // File reading thread.
    InputBuffer &Fill(uint64_t sequence) {
      sequence_ = sequence;
      // Why wait until now to clear instead of after output?  Free in the same
      // thread as allocated.
      input_.Clear();
      return input_;
    }

    // Filter worker thread.
    template <class Filter> void CallFilter(Filter &filter) {
      input_.CallFilter(filter, output_);
    }

    uint64_t Sequence() const { return sequence_; }

    // File writing thread.
    template <class RealOutput> void Flush(RealOutput &output) {
      output_.Flush(output);
    }

  private:
    InputBuffer input_;
    OutputBuffer output_;

    uint64_t sequence_;
};

// Runs the filter over each batch it receives, then hands the batch to the
// output queue.
template <class Batch, class Filter> class FilterWorker {
  public:
    typedef Batch *Request;

    FilterWorker(const Filter &filter, util::PCQueue<Request> &done)
      : filter_(filter), done_(done) {}

    void operator()(Request request) {
      request->CallFilter(filter_);
      done_.Produce(request);
    }

  private:
    Filter filter_;

    util::PCQueue<Request> &done_;
};

// There should only be one OutputWorker.
template <class Batch, class Output> class OutputWorker {
  public:
    typedef Batch *Request;

    OutputWorker(Output &output, util::PCQueue<Request> &done)
      : output_(output), done_(done), base_sequence_(0) {}

    void operator()(Request request) {
      assert(request->Sequence() >= base_sequence_);
      // Assemble the output in order.
      uint64_t pos = request->Sequence() - base_sequence_;
      if (pos >= ordering_.size()) {
        ordering_.resize(pos + 1, NULL);
      }
      ordering_[pos] = request;
      while (!ordering_.empty() && ordering_.front()) {
        ordering_.front()->Flush(output_);
        done_.Produce(ordering_.front());
        ordering_.pop_front();
        ++base_sequence_;
      }
    }

  private:
    Output &output_;

    util::PCQueue<Request> &done_;

    std::deque<Request> ordering_;

    uint64_t base_sequence_;
};

// Ties the pipeline together: owns the batches, feeds them to the filter
// thread pool, and recycles them after the output worker has flushed them.
template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
  private:
    typedef ThreadBatch<OutputBuffer> Batch;

  public:
    Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
      : batch_size_(batch_size), queue_size_(queue),
        batches_(queue),
        to_read_(queue),
        output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
        filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
        sequence_(0) {
      for (size_t i = 0; i < queue; ++i) {
        batches_[i].Reserve(batch_size);
        local_read_.push(&batches_[i]);
      }
      NewInput();
    }

    void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
      input_->AddNGram(ngram, line, output);
      if (input_->Size() == batch_size_) {
        FlushInput();
        NewInput();
      }
    }

    void Flush() {
      FlushInput();
      while (local_read_.size() < queue_size_) {
        MoveRead();
      }
      NewInput();
    }

  private:
    void FlushInput() {
      if (input_->Empty()) return;
      filter_.Produce(local_read_.top());
      local_read_.pop();
      if (local_read_.empty()) MoveRead();
    }

    void NewInput() {
      input_ = &local_read_.top()->Fill(sequence_++);
    }

    void MoveRead() {
      local_read_.push(to_read_.Consume());
    }

    const size_t batch_size_;
    const size_t queue_size_;

    std::vector<Batch> batches_;

    util::PCQueue<Batch*> to_read_;
    std::stack<Batch*> local_read_;
    util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
    util::ThreadPool<FilterWorker<Batch, Filter> > filter_;

    uint64_t sequence_;
    InputBuffer *input_;
};

} // namespace lm

#endif // LM_FILTER_THREAD_H
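
// Usage sketch (not part of the header proper): MyFilter, MyOutputBuffer,
// MyOutput, and ReadNGramLine are hypothetical stand-ins for the filter,
// buffer, and output classes defined elsewhere in lm/filter; only the
// Controller calls themselves come from this file.
//
//   MyFilter filter(vocab);
//   MyOutput output("filtered.arpa");
//   lm::Controller<MyFilter, MyOutputBuffer, MyOutput>
//       controller(batch_size, queue_length, worker_count, filter, output);
//   StringPiece ngram, line;
//   while (ReadNGramLine(in, ngram, line)) {
//     controller.AddNGram(ngram, line, output);  // batched internally
//   }
//   controller.Flush();  // drain the pipeline before destruction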