File size: 5,667 Bytes
ce00289 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
import torch
from jaxtyping import Float, Int
@dataclass
class ModelInfo:
    """Static metadata about a loaded language model, reported by `TransparentLlm.model_info`."""
    # Human-readable model name.
    name: str
    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int
    # Number of transformer layers.
    n_layers: int
    # Number of attention heads (per layer, presumably — confirm with implementations).
    n_heads: int
    # Dimensionality of the residual stream.
    d_model: int
    # Size of the output vocabulary.
    d_vocab: int
class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is supposed to be
    loaded at the class initialization.
    The internal state is the resulting tensors from the last call of the `run` method.
    Most of the methods could return values based on the state, but some may do cheap
    computations based on them.
    """
    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before any
        calls of the `run`.
        """
        pass
    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run the inference on the given sentences in a single batch and store all
        necessary info in the internal state.
        """
        pass
    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch that was used for the last call of `run`.
        """
        pass
    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        """
        The token ids of the sentences from the last call of `run`,
        as a [batch, pos] integer tensor.
        """
        pass
    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        """
        Convert a 1D tensor of token ids into a list of their string representations,
        one string per token.
        """
        pass
    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        """
        The output logits of the model from the last call of `run`, for every
        batch element and position.
        """
        pass
    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream for a
        layer and token) into the output vocabulary.
        normalize: whether to apply the final normalization before the unembedding.
        Setting it to True and applying to output of the last layer gives the output of
        the model.
        """
        pass
    # ================= Methods related to the residual stream =================
    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example, when
        layer == 0 these must be the embedded tokens (including positional embedding).
        """
        pass
    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in the
        given layer.
        """
        pass
    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is equivalent to the
        next layer's input.
        """
        pass
    # ================ Methods related to the feed-forward layer ===============
    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual stream.
        """
        pass
    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It should
        be the same as neuron activations multiplied by neuron outputs.
        """
        pass
    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was applied.
        """
        pass
    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the value that the given neuron adds to the residual stream. It's a raw
        vector from the model parameters, no activation involved.
        """
        pass
    # ==================== Methods related to the attention ====================
    @abstractmethod
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-triangular (causal) attention matrix for the given head.
        """
        pass
    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return what the given head at the given layer and pos added to the residual
        stream.
        """
        pass
    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: index of token from the previous layer
        - target: index of token on the current layer
        The decomposed attention tells what vector from source representation was used
        in order to contribute to the target representation.
        """
        pass
|