# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from jaxtyping import Float, Int


@dataclass
class ModelInfo:
    name: str

    # Not the actual number of parameters, but rather the order of magnitude
    n_params_estimate: int

    n_layers: int
    n_heads: int
    d_model: int
    d_vocab: int


class TransparentLlm(ABC):
    """
    An abstract stateful interface for a language model. The model is supposed
    to be loaded at class initialization.

    The internal state consists of the tensors produced by the last call of the
    `run` method. Most of the methods return values straight from that state,
    but some may perform cheap computations on top of it.
    """

    @abstractmethod
    def model_info(self) -> ModelInfo:
        """
        Gives general info about the model. This method must be available
        before any call of `run`.
        """
        pass

    @abstractmethod
    def run(self, sentences: List[str]) -> None:
        """
        Run the inference on the given sentences in a single batch and store
        all necessary info in the internal state.
        """
        pass

    @abstractmethod
    def batch_size(self) -> int:
        """
        The size of the batch that was used for the last call of `run`.
        """
        pass

    @abstractmethod
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        pass

    @abstractmethod
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        pass

    @abstractmethod
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        pass

    @abstractmethod
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        """
        Project the given vector (for example, the state of the residual stream
        for a layer and token) into the output vocabulary.

        normalize: whether to apply the final normalization before the
        unembedding. Setting it to True and applying it to the output of the
        last layer gives the output of the model.
        """
        pass

    # ================= Methods related to the residual stream =================

    @abstractmethod
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream before entering the layer. For
        example, when layer == 0 these must be the embedded tokens (including
        positional embedding).
        """
        pass

    @abstractmethod
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after attention, but before the FFN in
        the given layer.
        """
        pass

    @abstractmethod
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The state of the residual stream after the given layer. This is
        equivalent to the next layer's input.
        """
        pass

    # ================ Methods related to the feed-forward layer ===============

    @abstractmethod
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        """
        The output of the FFN layer, before it gets merged into the residual
        stream.
        """
        pass

    @abstractmethod
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        """
        A collection of vectors added to the residual stream by each neuron. It
        should be the same as the neuron activations multiplied by the neuron
        outputs.
        """
        pass

    @abstractmethod
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_ffn"]:
        """
        The content of the hidden layer right after the activation function was
        applied.
        """
        pass
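
    # Note: per the contracts above, `decomposed_ffn_out(b, layer, pos)` is
    # expected to equal `neuron_activations(b, layer, pos)` scaled neuron-wise
    # by `neuron_output(layer, n)` (defined below), and to sum over neurons to
    # `ffn_out(layer)[b, pos]`, possibly up to an output bias that an
    # implementation may or may not distribute across neurons.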
""" pass @abstractmethod def neuron_output( self, layer: int, neuron: int, ) -> Float[torch.Tensor, "d_model"]: """ Return the value that the given neuron adds to the residual stream. It's a raw vector from the model parameters, no activation involved. """ pass # ==================== Methods related to the attention ==================== @abstractmethod def attention_matrix( self, batch_i, layer: int, head: int ) -> Float[torch.Tensor, "query_pos key_pos"]: """ Return a lower-diagonal attention matrix. """ pass @abstractmethod def attention_output( self, batch_i: int, layer: int, pos: int, head: int, ) -> Float[torch.Tensor, "d_model"]: """ Return what the given head at the given layer and pos added to the residual stream. """ pass @abstractmethod def decomposed_attn( self, batch_i: int, layer: int ) -> Float[torch.Tensor, "source target head d_model"]: """ Here - source: index of token from the previous layer - target: index of token on the current layer The decomposed attention tells what vector from source representation was used in order to contribute to the taget representation. """ pass