File size: 2,534 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#ifndef LM_BUILDER_OUTPUT_H
#define LM_BUILDER_OUTPUT_H
#include "lm/builder/header_info.hh"
#include "util/file.hh"
#include <boost/ptr_container/ptr_vector.hpp>
#include <boost/utility.hpp>
#include <map>
namespace util { namespace stream { class Chains; class ChainPositions; } }
/* Outputs from lmplz: ARPA, sharded files, etc. */
namespace lm { namespace builder {
// The different kinds of output hook.  Enumerators are numbered consecutively
// from zero so that a HookType can be used directly as an array index.
enum HookType {
  // Raw N-gram counts, highest order only.
  COUNT_HOOK = 0,

  // Probability and backoff (or just q).  Output must process the orders in
  // parallel or there will be a deadlock.
  PROB_PARALLEL_HOOK,

  // Probability and backoff (or just q).  Output can process orders any way
  // it likes.  This requires writing the data to disk then reading.  Useful
  // for ARPA files, which put unigrams first etc.
  PROB_SEQUENTIAL_HOOK,

  // Keep this last so we know how many values there are.
  NUMBER_OF_HOOKS
};
class Output;
// Abstract base for a single output hook.  Each hook is tagged with the
// HookType it was constructed with.  Passing a hook to Output::Add hands
// ownership to that Output and records it as the hook's master.
class OutputHook {
  // Output::Add writes master_ and reads type_.
  friend class Output;

  public:
    // A freshly constructed hook has no master yet (master_ is NULL).
    explicit OutputHook(HookType hook_type)
      : type_(hook_type),
        master_(NULL) {}

    virtual ~OutputHook();

    // Defined out of line; overridable.
    // NOTE(review): presumably arranges for Run to be invoked on the given
    // chains -- confirm against the implementation file.
    virtual void Apply(util::stream::Chains &chains);

    // Every concrete hook must implement this.
    virtual void Run(const util::stream::ChainPositions &positions) = 0;

  protected:
    // Both accessors forward to the owning Output through master_.  They
    // dereference master_ unconditionally, so they may only be called after
    // the hook has been passed to Output::Add.
    const HeaderInfo &GetHeader() const;
    int GetVocabFD() const;

  private:
    const HookType type_;
    const Output *master_;
};
// Owns every registered OutputHook, grouped by HookType, together with the
// vocabulary file descriptor and header that hooks read back through
// OutputHook::GetVocabFD / OutputHook::GetHeader.
class Output : boost::noncopyable {
  public:
    // vocab_fd_ is a plain int, so it must be initialised explicitly;
    // otherwise GetVocabFD() before SetVocabFD() would read an indeterminate
    // value.  -1 is used as the "not set" value.
    Output() : vocab_fd_(-1) {}

    // Takes ownership.  The hook is stored under its own type and this
    // object is recorded as its master.  hook must not be NULL: it is
    // dereferenced immediately.
    void Add(OutputHook *hook) {
      hook->master_ = this;
      outputs_[hook->type_].push_back(hook);
    }

    // True if at least one hook of the given type has been added.
    bool Have(HookType hook_type) const {
      return !outputs_[hook_type].empty();
    }

    // File descriptor accessors.  GetVocabFD returns -1 until SetVocabFD has
    // been called.
    void SetVocabFD(int to) { vocab_fd_ = to; }
    int GetVocabFD() const { return vocab_fd_; }

    // Header accessors.  SetHeader stores a copy; GetHeader returns a
    // reference to that stored copy.
    void SetHeader(const HeaderInfo &header) { header_ = header; }
    const HeaderInfo &GetHeader() const { return header_; }

    // Calls Apply(chains) on each hook of the given type, in the order the
    // hooks were added.
    void Apply(HookType hook_type, util::stream::Chains &chains) {
      for (boost::ptr_vector<OutputHook>::iterator entry = outputs_[hook_type].begin(); entry != outputs_[hook_type].end(); ++entry) {
        entry->Apply(chains);
      }
    }

  private:
    // One owning list per HookType, indexed by the enumerator value.
    boost::ptr_vector<OutputHook> outputs_[NUMBER_OF_HOOKS];
    int vocab_fd_;
    HeaderInfo header_;
};
// Returns the header held by the Output this hook was added to.  master_ is
// dereferenced without a check, so the hook must already have been passed to
// Output::Add.
inline const HeaderInfo &OutputHook::GetHeader() const {
  const Output &owner = *master_;
  return owner.GetHeader();
}
inline int OutputHook::GetVocabFD() const {
return master_->GetVocabFD();
}
}} // namespaces
#endif // LM_BUILDER_OUTPUT_H
|