File size: 4,205 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
#ifndef LM_FILTER_THREAD_H
#define LM_FILTER_THREAD_H
#include "../../util/thread_pool.hh"
#include <boost/utility/in_place_factory.hpp>
#include <deque>
#include <stack>
namespace lm {
/**
 * One unit of work passed between the pipeline stages: an InputBuffer that
 * the reader fills, the OutputBuffer that receives whatever the filter emits
 * for that input, and a sequence number identifying the batch.
 *
 * The comments on the methods below name the thread each one is meant to be
 * called from; the class itself performs no locking.
 */
template <class OutputBuffer> class ThreadBatch {
  public:
    // Give the sequence number a defined starting value.  Without this,
    // calling Sequence() before the first Fill(), or copying a freshly
    // constructed batch, would read an indeterminate value.
    ThreadBatch() : sequence_(0) {}

    // Forwards `size` to Reserve() on both the input and the output buffer.
    void Reserve(size_t size) {
      input_.Reserve(size);
      output_.Reserve(size);
    }

    // File reading thread.
    // Tags the batch with `sequence`, clears the input buffer and returns it
    // so the caller can refill it.
    InputBuffer &Fill(uint64_t sequence) {
      sequence_ = sequence;
      // Why wait until now to clear instead of after output?  free in the same
      // thread as allocated.
      input_.Clear();
      return input_;
    }

    // Filter worker thread.
    // Runs `filter` over the buffered input, directing results to output_.
    template <class Filter> void CallFilter(Filter &filter) {
      input_.CallFilter(filter, output_);
    }

    // Sequence number given to the most recent Fill() (0 before any Fill()).
    uint64_t Sequence() const { return sequence_; }

    // File writing thread.
    // Passes `output` to the output buffer's Flush().
    template <class RealOutput> void Flush(RealOutput &output) {
      output_.Flush(output);
    }

  private:
    InputBuffer input_;
    OutputBuffer output_;

    uint64_t sequence_;
};
/**
 * Functor for the filtering stage.  Each call runs this worker's filter over
 * one batch and then forwards the batch pointer to the `done` queue.
 */
template <class Batch, class Filter> class FilterWorker {
  public:
    // Work items are plain pointers; this class never deletes them.
    typedef Batch *Request;

    // `filter` is copied into this worker; `done` is kept by reference and
    // therefore must outlive the worker.
    FilterWorker(const Filter &filter, util::PCQueue<Request> &done)
      : filter_(filter), done_(done) {}

    // Filter the batch, then pass it on.
    void operator()(Request request) {
      Batch &batch = *request;
      batch.CallFilter(filter_);
      done_.Produce(request);
    }

  private:
    Filter filter_;

    util::PCQueue<Request> &done_;
};
/**
 * Functor for the writing stage.  Batches may arrive in any order; they are
 * parked until every earlier sequence number has been seen, then flushed to
 * the output in sequence order and forwarded to the `done` queue.
 *
 * There should only be one OutputWorker: ordering_ and base_sequence_ are
 * modified without any locking.
 */
template <class Batch, class Output> class OutputWorker {
  public:
    typedef Batch *Request;

    // Both arguments are kept by reference and must outlive the worker.
    OutputWorker(Output &output, util::PCQueue<Request> &done)
      : output_(output), done_(done), base_sequence_(0) {}

    void operator()(Request request) {
      const uint64_t sequence = request->Sequence();
      assert(sequence >= base_sequence_);

      // Slot i of ordering_ belongs to sequence number base_sequence_ + i.
      const uint64_t slot = sequence - base_sequence_;
      if (ordering_.size() <= slot) {
        ordering_.resize(slot + 1, NULL);
      }
      ordering_[slot] = request;

      // Release the longest run of consecutive batches now available at the
      // front; stop at the first slot whose batch has not arrived yet.
      while (!ordering_.empty()) {
        Request ready = ordering_.front();
        if (!ready) break;
        ready->Flush(output_);
        done_.Produce(ready);
        ordering_.pop_front();
        ++base_sequence_;
      }
    }

  private:
    Output &output_;

    util::PCQueue<Request> &done_;

    // Batches waiting for their turn; NULL marks a sequence not yet received.
    std::deque<Request> ordering_;

    // Sequence number corresponding to ordering_.front().
    uint64_t base_sequence_;
};
/**
 * Front end of the threaded filter.  The caller feeds n-grams through
 * AddNGram(); they are collected into batches, each batch is stamped with a
 * consecutive sequence number and handed to the filter_ pool, and finished
 * batches come back through to_read_ for reuse.
 *
 * AddNGram() and Flush() touch local_read_, input_ and sequence_ without any
 * locking, so they are meant to be called from a single thread.
 */
template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
  private:
    typedef ThreadBatch<OutputBuffer> Batch;

  public:
    /**
     * batch_size: the current batch is handed off once its Size() equals this.
     * queue:      number of batches allocated; also passed as the first
     *             argument of both pools and of to_read_ (presumably their
     *             capacity).  Must be at least 1, because the constructor
     *             immediately takes a batch from local_read_.
     * workers:    second argument of the filter pool (presumably its thread
     *             count); the output pool is given 1.
     * filter:     forwarded by reference to the filter pool's workers.
     * output:     forwarded by reference to the output pool's worker.
     */
    Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
      : batch_size_(batch_size), queue_size_(queue),
        batches_(queue),
        to_read_(queue),
        output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
        filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
        sequence_(0) {
      // Every batch starts out available to the reader.
      for (size_t i = 0; i < queue; ++i) {
        batches_[i].Reserve(batch_size);
        local_read_.push(&batches_[i]);
      }
      NewInput();
    }

    // Appends one n-gram to the current batch and hands the batch off as soon
    // as it holds batch_size_ entries.
    void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
      input_->AddNGram(ngram, line, output);
      if (input_->Size() == batch_size_) {
        FlushInput();
        NewInput();
      }
    }

    // Hands off whatever has been buffered, waits until every batch has been
    // returned through to_read_, then starts a fresh batch.
    void Flush() {
      // FlushInput() submits the current batch only when it is non-empty.
      const bool submitted = !input_->Empty();
      FlushInput();
      while (local_read_.size() < queue_size_) {
        MoveRead();
      }
      // An empty batch was never submitted, so the sequence number it was
      // given is still unused.  Hand that number out again: OutputWorker
      // releases batches strictly in sequence order, so a skipped number would
      // park every later batch forever and eventually block MoveRead().
      // sequence_ is at least 1 here because the constructor calls NewInput().
      if (!submitted) --sequence_;
      NewInput();
    }

  private:
    // Submits the current batch (always local_read_.top()) to the filter pool
    // unless it is empty.  Guarantees local_read_ is non-empty on return,
    // blocking in MoveRead() if necessary.
    void FlushInput() {
      if (input_->Empty()) return;
      filter_.Produce(local_read_.top());
      local_read_.pop();
      if (local_read_.empty()) MoveRead();
    }

    // Makes local_read_.top() the current batch and stamps it with the next
    // sequence number.
    void NewInput() {
      input_ = &local_read_.top()->Fill(sequence_++);
    }

    // Takes one returned batch from to_read_ and makes it available locally.
    void MoveRead() {
      local_read_.push(to_read_.Consume());
    }

    const size_t batch_size_;
    const size_t queue_size_;

    // Storage for all batches; the queues and the stack hold pointers into it.
    std::vector<Batch> batches_;

    // Finished batches handed back by the output worker.
    util::PCQueue<Batch*> to_read_;
    // Batches currently available to this thread; top() is the one being filled.
    std::stack<Batch*> local_read_;

    // Declared after to_read_ and before filter_ on purpose: output_ is built
    // from to_read_, and filter_ is built from output_.In().
    util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
    util::ThreadPool<FilterWorker<Batch, Filter> > filter_;

    // Next sequence number to hand out.
    uint64_t sequence_;
    // Input buffer of the batch currently being filled.
    InputBuffer *input_;
};
} // namespace lm
#endif // LM_FILTER_THREAD_H
|