File size: 2,956 Bytes
1ce325b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#ifndef LM_INTERPOLATE_TUNE_INSTANCE_H
#define LM_INTERPOLATE_TUNE_INSTANCE_H

#include "tune_matrix.hh"
#include "../word_index.hh"
#include "../../util/scoped.hh"
#include "../../util/stream/config.hh"
#include "../../util/string_piece.hh"

#include <boost/optional.hpp>

#include <cstddef>
#include <string>
#include <vector>

#include <stdint.h>

// Forward declarations.  This header only mentions these types by reference
// (Chain) or through a util::scoped_ptr member (FileBuffer).
namespace util { namespace stream {
class Chain;
class FileBuffer;
}} // namespaces

namespace lm { namespace interpolate {

// Index of a tuning instance; selects a row of Instances::ln_backoffs_.
typedef uint32_t InstanceIndex;
// Index of a model; selects a column of Instances::ln_backoffs_.
typedef uint32_t ModelIndex;

// One record of extension data: the log probability a single model assigns to
// a single word, given the context of a single tuning instance.
struct Extension {
  // Which tuning instance does this belong to?
  InstanceIndex instance;
  // The word whose probability is recorded.
  WordIndex word;
  // The model that produced ln_prob.
  ModelIndex model;
  // ln p_{model} (word | context(instance))
  float ln_prob;

  // Strict ordering over records, defined out of line.
  // NOTE(review): presumably the order used when extensions are sorted
  // (see InstancesConfig::sort) -- confirm against the implementation.
  bool operator<(const Extension &other) const;
};

// Forward declaration; Instances only holds this through a util::scoped_ptr.
class ExtensionsFirstIteration;

// Memory and sorting settings taken by the Instances constructor.
struct InstancesConfig {
  // For batching the model reads.  This is per order.
  std::size_t model_read_chain_mem;
  // Memory for the chain that writes extensions.
  // This is being sorted, make it larger.
  std::size_t extension_write_chain_mem;
  // NOTE(review): undocumented in this header; presumably a memory budget for
  // lazily-processed data -- confirm against the implementation.
  std::size_t lazy_memory;
  // Sort settings; the comment on extension_write_chain_mem indicates the
  // extensions are what gets sorted.
  util::stream::SortConfig sort;
};

// Holds the per-instance and per-model quantities exposed below for tuning
// (full backoffs, unigram log probabilities, the correct-word gradient term)
// and provides the stream of Extension records via ReadExtensions.
class Instances {
  private:
    // Row-major so that each instance's backoffs (one row) can be handed out
    // as a row expression by LNBackoffs.
    typedef Eigen::Matrix<Accum, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> BackoffMatrix;

  public:
    // tune_file: NOTE(review) an int, presumably an open file descriptor for
    //   the tuning data -- confirm against the implementation.
    // model_names: one entry per model being interpolated.
    // config: memory and sort settings, see InstancesConfig.
    Instances(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config);

    // For destruction of forward-declared classes: the scoped_ptr members
    // below point at types that are incomplete in this header, so the
    // destructor is only declared here and defined out of line.
    ~Instances();

    // Full backoff from unigram for each model.
    typedef BackoffMatrix::ConstRowXpr FullBackoffs;
    // Returns the row of ln_backoffs_ for the given instance (one column per model).
    FullBackoffs LNBackoffs(InstanceIndex instance) const {
      return ln_backoffs_.row(instance);
    }

    // Number of tuning instances, i.e. the number of rows of ln_backoffs_.
    InstanceIndex NumInstances() const { return ln_backoffs_.rows(); }

    // Returns neg_ln_correct_sum_; see the member comment for its definition.
    const Vector &CorrectGradientTerm() const { return neg_ln_correct_sum_; }

    // Returns ln_unigrams_, indexed as (word, model).
    const Matrix &LNUnigrams() const { return ln_unigrams_; }

    // Entry size to use to configure the chain (since in practice order is needed).
    std::size_t ReadExtensionsEntrySize() const;
    // Attaches the source of Extension data to the given chain.
    // NOTE(review): presumably reads from extensions_first_ on the first call
    // and extensions_subsequent_ afterwards -- confirm against the implementation.
    void ReadExtensions(util::stream::Chain &chain);

    // Vocab id of the beginning of sentence.  Used to ignore it for normalization.
    WordIndex BOS() const { return bos_; }

  private:
    // Allow the derivatives test to get access.
    friend class MockInstances;
    // Private default constructor, reachable only through the friend above.
    Instances();

    // ln_backoffs_(instance, model) is the backoff all the way to unigrams.
    BackoffMatrix ln_backoffs_;

    // neg_ln_correct_sum_(model) = -\sum_{instances} ln p_{model}(correct(instance) | context(instance)).
    // This appears as a term in the gradient.
    Vector neg_ln_correct_sum_;

    // ln_unigrams_(word, model) = ln p_{model}(word).
    Matrix ln_unigrams_;

    // This is the source of data for the first iteration.
    util::scoped_ptr<ExtensionsFirstIteration> extensions_first_;

    // Source of data for subsequent iterations.  This contains already-sorted data.
    util::scoped_ptr<util::stream::FileBuffer> extensions_subsequent_;

    // Vocab id of the beginning of sentence, returned by BOS().
    WordIndex bos_;

    // NOTE(review): undocumented; presumably the path prefix for temporary
    // files -- confirm against the implementation.
    std::string temp_prefix_;
};

}} // namespaces
#endif // LM_INTERPOLATE_TUNE_INSTANCE_H