File size: 7,051 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
#ifndef LM_MODEL_H
#define LM_MODEL_H
#include "bhiksha.hh"
#include "binary_format.hh"
#include "config.hh"
#include "facade.hh"
#include "quantize.hh"
#include "search_hashed.hh"
#include "search_trie.hh"
#include "state.hh"
#include "value.hh"
#include "vocab.hh"
#include "weights.hh"
#include "../util/murmur_hash.hh"
#include <algorithm>
#include <vector>
#include <cstring>
// Forward declaration only; the full definition of util::FilePiece is not
// needed by (and not referenced in) the declarations visible in this header.
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
namespace detail {
// N-gram language model parameterized on its lookup structure (Search) and
// its vocabulary implementation (VocabularyT).
// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
// The class hands itself to base::ModelFacade as the first template argument
// (CRTP), together with the State type and the vocabulary type.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
private:
// Shorthand for the ModelFacade base class instantiation.
typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
public:
// This is the model type returned by RecognizeBinary.
// Declared here only; the definition lives outside this header.
static const ModelType kModelType;
// Version number, taken directly from the Search policy.
static const unsigned int kVersion = Search::kVersion;
/* Get the size of memory that will be mapped given ngram counts. This
 * does not include small non-mapped control structures, such as this class
 * itself.
 * counts: n-gram counts (NOTE(review): presumably one entry per order - confirm).
 * config: loading options; defaults to a default-constructed Config.
 */
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());
/* Load the model from a file. It may be an ARPA or binary file. Binary
 * files must have the format expected by this class or you'll get an
 * exception. So TrieModel can only load ARPA or binary created by
 * TrieModel. To classify binary files, call RecognizeBinary in
 * lm/binary_format.hh.
 * The constructor is explicit, so a path is never implicitly converted
 * into a GenericModel.
 */
explicit GenericModel(const char *file, const Config &config = Config());
/* Score p(new_word | in_state) and incorporate new_word into out_state.
 * Note that in_state and out_state must be different references:
 * &in_state != &out_state.
 */
FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;
/* Slower call without in_state. Try to remember state, but sometimes it
 * would cost too much memory or your decoder isn't set up properly.
 * To use this function, make an array of WordIndex containing the context
 * vocabulary ids in reverse order. Then, pass the bounds of the array:
 * [context_rbegin, context_rend). The new_word is not part of the context
 * array unless you intend to repeat words.
 */
FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;
/* Get the state for a context. Don't use this if you can avoid it. Use
 * BeginSentenceState or NullContextState and extend from those. If
 * you're only going to use this state to call FullScore once, use
 * FullScoreForgotState.
 * To use this function, make an array of WordIndex containing the context
 * vocabulary ids in reverse order. Then, pass the bounds of the array:
 * [context_rbegin, context_rend).
 * The resulting state is written to out_state.
 */
void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;
/* More efficient version of FullScore where a partial n-gram has already
 * been scored.
 * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
 */
FullScoreReturn ExtendLeft(
// Additional context in reverse order: [add_rbegin, add_rend). Both
// pointers are passed by value, so the caller's copies are never modified.
const WordIndex *add_rbegin, const WordIndex *add_rend,
// Backoff weights to use.
const float *backoff_in,
// extend_left returned by a previous query.
uint64_t extend_pointer,
// Length of n-gram that the pointer corresponds to.
unsigned char extend_length,
// Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
float *backoff_out,
// Amount of additional content that should be considered by the next call.
// Written through the reference.
unsigned char &next_use) const;
/* Return probabilities minus rest costs for an array of pointers. The
 * first length should be the length of the n-gram to which pointers_begin
 * points.
 * The pointers occupy [pointers_begin, pointers_end).
 * When Search::kDifferentRest is false this returns 0.0 without calling
 * InternalUnRest.
 */
float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
// The condition depends only on the Search policy (presumably a
// compile-time constant), so the compiler should be able to optimize the
// conditional away.
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
}
private:
// NOTE(review): judging by the name, scores new_word in the given reversed
// context [context_rbegin, context_rend) without applying backoff, writing
// the new state to out_state - confirm against the definition in the .cc file.
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
// Score bigrams and above. Do not include backoff.
// node, backoff_out, next_use and ret are all writable by the callee.
void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;
// Appears after Size in the cc file.
void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);
// NOTE(review): presumably builds the model from ARPA input available on
// file descriptor fd, with `file` being the corresponding path - confirm
// against the definition in the .cc file.
void InitializeFromARPA(int fd, const char *file, const Config &config);
// Out-of-line worker behind UnRest; only called when
// Search::kDifferentRest is true.
float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;
// Data members. C++ constructs them in this declaration order.
// NOTE(review): check the constructor before reordering them.
BinaryFormat backing_;
VocabularyT vocab_;
Search search_;
};
} // namespace detail
// Instead of typedef, inherit. This allows the Model etc to be forward declared.
// Oh the joys of C and C++.
//
// LM_COMMA() expands to a bare comma. It is used to pass template argument
// lists such as GenericModel<A, B> as a single macro argument: a literal
// comma would otherwise be treated by the preprocessor as an argument
// separator.
#define LM_COMMA() ,
// LM_NAME_MODEL(name, from) defines a class `name` that publicly derives from
// `from` and whose only declared constructor forwards (file, config) to
// `from`. The expansion already ends with "};".
// NOTE(review): the generated constructor is not `explicit`, unlike the
// GenericModel constructor it forwards to - confirm whether that is intended.
#define LM_NAME_MODEL(name, from)\
class name : public from {\
public:\
name(const char *file, const Config &config = Config()) : from(file, config) {}\
};
// Concrete, forward-declarable model classes. Each one names a single
// detail::GenericModel instantiation (search structure + vocabulary type).
// The trailing ';' after each invocation is redundant, because the macro
// expansion already terminates the class definition.
// HashedSearch over BackoffValue, with ProbingVocabulary.
LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
// HashedSearch over RestValue, with ProbingVocabulary.
LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
// TrieSearch with DontQuantize and DontBhiksha, with SortedVocabulary.
LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
// TrieSearch with DontQuantize and ArrayBhiksha, with SortedVocabulary.
LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
// TrieSearch with SeparatelyQuantize and DontBhiksha, with SortedVocabulary.
LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
// TrieSearch with SeparatelyQuantize and ArrayBhiksha, with SortedVocabulary.
LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
// Default implementation. No real reason for it to be the default.
// Vocabulary is the vocabulary type used by ProbingModel, so the two
// defaults below match each other.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;
/* Autorecognize the file type, load, and return the virtual base class. Don't
 * use the virtual base class if you can avoid it. Instead, use the above
 * classes as template arguments to your own virtual feature function.
 * file_name: path of the model file to load.
 * config: loading options; defaults to a default-constructed Config.
 * if_arpa: defaults to PROBING. NOTE(review): presumably the model type to
 * build when the file turns out to be ARPA rather than binary - confirm
 * against the definition.
 * Returns a raw base::Model pointer. NOTE(review): ownership is not stated
 * here; presumably the caller is responsible for deleting it - confirm. */
base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
} // namespace ngram
} // namespace lm
#endif // LM_MODEL_H
|