File size: 1,347 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
#ifndef LM_INTERPOLATE_UNIVERSAL_VOCAB_H
#define LM_INTERPOLATE_UNIVERSAL_VOCAB_H
#include "../word_index.hh"
#include <algorithm>
#include <cstddef>
#include <vector>
namespace lm {
namespace interpolate {
// Holds, for each model, a table indexed by that model's own word index whose
// entries are the corresponding word indices in a shared ("universal")
// vocabulary.  None of the accessors perform bounds checking.
class UniversalVocab {
public:
  // Defined out of line.
  // NOTE(review): presumably allocates one table per entry of
  // model_vocab_sizes, sized by that entry -- confirm against the .cc file.
  explicit UniversalVocab(const std::vector<WordIndex>& model_vocab_sizes);

  // GetUniversalIdx takes the model number and the word index within that
  // specific model and returns the universal word index stored for it.
  WordIndex GetUniversalIdx(std::size_t model_num, WordIndex model_word_index) const {
    return model_index_map_[model_num][model_word_index];
  }

  // Returns a pointer to the first entry of the given model's table
  // (model word index -> universal word index), or NULL when the table is
  // empty.  The pointer refers to storage owned by this object.
  const WordIndex *Mapping(std::size_t model) const {
    const std::vector<WordIndex> &table = model_index_map_[model];
    // Dereferencing begin() of an empty vector is undefined behaviour, so the
    // empty case is handled explicitly.
    if (table.empty()) return NULL;
    return &table[0];
  }

  // Reverse lookup: given a universal word index, returns the model-specific
  // word index that maps to it, or 0 when no entry of the model's table equals
  // index.  Uses binary search, so the model's table must be sorted in
  // ascending order for the result to be meaningful.
  // NOTE(review): 0 doubles as the "not found" value; presumably word index 0
  // is the unknown-word entry -- confirm with callers.
  WordIndex SlowConvertToModel(std::size_t model, WordIndex index) const {
    const std::vector<WordIndex> &table = model_index_map_[model];
    // Qualified call: an unqualified lower_bound only resolves through ADL
    // when the vector iterator happens to be a class type in namespace std.
    const std::vector<WordIndex>::const_iterator found =
        std::lower_bound(table.begin(), table.end(), index);
    if (found == table.end() || *found != index) return 0;
    return static_cast<WordIndex>(found - table.begin());
  }

  // Records that word word_index of model model_num corresponds to
  // universal_word_index.  The slot must already exist; the table is not
  // resized here.
  void InsertUniversalIdx(std::size_t model_num, WordIndex word_index,
                          WordIndex universal_word_index) {
    model_index_map_[model_num][word_index] = universal_word_index;
  }

private:
  // model_index_map_[model][model word index] == universal word index.
  std::vector<std::vector<WordIndex> > model_index_map_;
};
} // namespace interpolate
} // namespace lm
#endif // LM_INTERPOLATE_UNIVERSAL_VOCAB_H
|