from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from jaxtyping import Float, Int


@dataclass
class ModelInfo:
    name: str
    n_params_estimate: int
    n_layers: int
    n_heads: int
    d_model: int
    d_vocab: int
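    # Illustrative example only (these values are not prescribed by this interface);
    # for GPT-2 small one would expect roughly:
    #   ModelInfo(name="gpt2", n_params_estimate=124_000_000, n_layers=12,
    #             n_heads=12, d_model=768, d_vocab=50257)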


class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is expected to be
    loaded when the class is initialized.

    The internal state consists of the tensors produced by the last call to the `run`
    method. Most of the other methods return values directly from this state, but some
    may perform cheap computations on top of it.
    """

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available before any
        call to `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run inference on the given sentences as a single batch and store all
        necessary info in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch used for the last call to `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream for a
        layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before the unembedding.
        Setting it to True and applying it to the output of the last layer gives the
        output of the model.
        """
        pass
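    # Sketch of the intended relationship (assuming a concrete implementation `llm`
    # of this interface): unembedding the final residual stream state with
    # normalize=True should reproduce the model's logits, e.g.
    #   last = llm.model_info().n_layers - 1
    #   resid = llm.residual_out(last)[0, -1]
    #   assert torch.allclose(llm.unembed(resid, normalize=True),
    #                         llm.logits()[0, -1], atol=1e-4)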

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For example, when
        layer == 0 this must be the embedded tokens (including positional embeddings).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in the
        given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is equivalent to
        the next layer's input.
        """
        pass
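    # Implied consistency (illustrative, for a hypothetical implementation `llm`):
    # llm.residual_out(layer) should equal llm.residual_in(layer + 1) for every layer
    # except the last one.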

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It should
        be the same as neuron activations multiplied by neuron outputs.
        """
        pass
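    # Sketch of the decomposition described above (hypothetical implementation `llm`;
    # the treatment of the FFN output bias is left to the implementation):
    #   acts = llm.neuron_activations(batch_i, layer, pos)         # (d_ffn,)
    #   outs = torch.stack([llm.neuron_output(layer, n)
    #                       for n in range(acts.shape[0])])        # (d_ffn, d_model)
    #   decomposed = acts[:, None] * outs                          # (d_ffn, d_model)
    # `decomposed` should match llm.decomposed_ffn_out(batch_i, layer, pos), and
    # summing it over neurons should roughly give llm.ffn_out(layer) at that position.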

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was
        applied.
        """
        pass

    @abstractmethod
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return the value that the given neuron adds to the residual stream. It's a raw
        vector from the model parameters, no activation involved.
        """
        pass

    @abstractmethod
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        """
        Return a lower-triangular attention matrix.
        """
        pass
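    # If the matrix holds post-softmax attention weights (the usual convention for a
    # causal, decoder-only model), each row sums to 1 and entries with
    # key_pos > query_pos are zero, hence "lower-triangular".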

    @abstractmethod
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        """
        Return what the given head at the given layer and position added to the
        residual stream.
        """
        pass
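    # Sketch of how this relates to the residual stream (hypothetical implementation
    # `llm`; bias and normalization details depend on the architecture): summing
    # attention_output over all heads at a position should approximately equal
    #   llm.residual_after_attn(layer)[batch_i, pos] - llm.residual_in(layer)[batch_i, pos]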

    @abstractmethod
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "source target head d_model"]:
        """
        Here
        - source: index of the token from the previous layer,
        - target: index of the token on the current layer.
        The decomposed attention tells which vector from the source representation was
        used to contribute to the target representation.
        """
        pass
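    # Sketch of the relationship to the methods above (hypothetical implementation
    # `llm`; bias handling is implementation-specific): summing the decomposition over
    # source tokens and heads should approximately recover the total attention
    # contribution at each target position, e.g.
    #   d = llm.decomposed_attn(batch_i, layer)    # (source, target, head, d_model)
    #   total = d.sum(dim=(0, 2))                  # (target, d_model)
    # total[pos] should roughly match the sum over heads of
    # llm.attention_output(batch_i, layer, pos, head).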
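

# The helper below is an illustrative usage sketch, not part of the interface: it
# assumes some concrete subclass of TransparentLlm and shows how the methods are
# meant to compose for greedy next-token prediction. Padding is ignored for
# simplicity, so it is only meaningful when all sentences tokenize to the same length.
def example_next_tokens(llm: TransparentLlm, sentences: List[str]) -> List[str]:
    llm.run(sentences)
    last_layer = llm.model_info().n_layers - 1
    predictions = []
    for batch_i in range(llm.batch_size()):
        # Final residual stream state at the last position of this sentence.
        final_residual = llm.residual_out(last_layer)[batch_i, -1]
        # With normalize=True this should match llm.logits()[batch_i, -1].
        next_token_logits = llm.unembed(final_residual, normalize=True)
        next_token_id = next_token_logits.argmax().unsqueeze(0)
        predictions.append(llm.tokens_to_strings(next_token_id)[0])
    return predictions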