# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from jaxtyping import Float, Int


@dataclass
class ModelInfo:
    name: str

    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int

    n_layers: int
    n_heads: int
    d_model: int
    d_vocab: int


class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is supposed to be
    loaded at the class initialization.

    The internal state is the resulting tensors from the last call of the `run` method.
    Most of the methods could return values based on the state, but some may do cheap
    computations based on them.
    """

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before
        any calls to `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run inference on the given sentences in a single batch and store all the
        necessary info in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch used for the last call to `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream for a
        layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before the unembedding.
        Setting it to True and applying to output of the last layer gives the output of
        the model.
        """
        pass
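
    # A hedged sanity check of the relation stated above, given a concrete
    # instance `llm` (a sketch; `residual_out` is declared below, and exact
    # agreement depends on the implementation's numerics):
    #
    #   last = llm.model_info().n_layers - 1
    #   v = llm.residual_out(last)[batch_i, pos]
    #   torch.testing.assert_close(
    #       llm.unembed(v, normalize=True), llm.logits()[batch_i, pos]
    #   )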

    # ================= Methods related to the residual stream =================

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example,
        when layer == 0 this must be the embedded tokens (including positional
        embeddings).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in the
        given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is equivalent to the
        next layer's input.
        """
        pass

    # ================ Methods related to the feed-forward layer ===============

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        The collection of vectors that each neuron adds to the residual stream.
        This should equal the neuron activations multiplied by the neuron output
        directions.
        """
        pass
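
    # Hedged illustration of the intended invariant (a sketch): summing the
    # per-neuron contributions should recover `ffn_out` at that position, up
    # to numerical tolerance:
    #
    #   per_neuron = llm.decomposed_ffn_out(batch_i=0, layer=0, pos=0)
    #   torch.testing.assert_close(
    #       per_neuron.sum(dim=0), llm.ffn_out(layer=0)[0, 0]
    #   )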

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The contents of the hidden layer immediately after the activation function
        is applied.
        """
        pass

    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the output direction of the given neuron: the raw vector from the
        model parameters that the neuron writes to the residual stream, with no
        activation applied.
        """
        pass

    # ==================== Methods related to the attention ====================

    @abstractmethod
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-triangular (causal) attention matrix.
        """
        pass

    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the vector that the given head at the given layer and position
        added to the residual stream.
        """
        pass
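
    # Presumably (an assumption about the interface contract, not stated in
    # the original): summing `decomposed_attn` (declared below) over source
    # tokens for a fixed head should match `attention_output`:
    #
    #   decomposed = llm.decomposed_attn(batch_i, layer)  # source target head d_model
    #   torch.testing.assert_close(
    #       decomposed[:, pos, head].sum(dim=0),
    #       llm.attention_output(batch_i, layer, pos, head),
    #   )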

    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: the index of the token at the previous layer,
        - target: the index of the token at the current layer.
        The decomposed attention tells which vector from the source representation
        was used to contribute to the target representation.
        """
        pass
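

# Example usage (a sketch, not part of the interface): a "logit lens"-style
# helper that projects the residual stream after every layer into the
# vocabulary via `unembed` and reports the top predicted token for one batch
# element and position. It relies only on the abstract methods above, so any
# concrete subclass should work.
def top_prediction_per_layer(
    llm: TransparentLlm,
    batch_i: int,
    pos: int,
) -> List[str]:
    info = llm.model_info()
    predictions = []
    for layer in range(info.n_layers):
        # Residual stream after this layer, projected to vocabulary logits.
        residual = llm.residual_out(layer)[batch_i, pos]
        layer_logits = llm.unembed(residual, normalize=True)
        # argmax gives a 0-dim tensor; tokens_to_strings expects a 1-dim one.
        top_token = layer_logits.argmax().unsqueeze(0)
        predictions.append(llm.tokens_to_strings(top_token)[0])
    return predictions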