# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
import torch
from jaxtyping import Float, Int
@dataclass
class ModelInfo:
name: str
# Not the actual number of parameters, but rather the order of magnitude
n_params_estimate: int
n_layers: int
n_heads: int
d_model: int
d_vocab: int
class TransparentLlm(ABC):
"""
An abstract, stateful interface to a language model. The model is expected to be
loaded at class initialization.
The internal state consists of the tensors produced by the most recent call to
the `run` method. Most methods simply read values from that state; some perform
cheap computations on top of it.
"""
@abstractmethod
def model_info(self) -> ModelInfo:
"""
Give general information about the model. This method must be available before
any calls to `run`.
"""
pass
@abstractmethod
def run(self, sentences: List[str]) -> None:
"""
Run inference on the given sentences as a single batch and store all the
necessary tensors in the internal state.
"""
pass
@abstractmethod
def batch_size(self) -> int:
"""
The size of the batch that was used for the last call to `run`.
"""
pass
@abstractmethod
def tokens(self) -> Int[torch.Tensor, "batch pos"]:
pass
@abstractmethod
def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
pass
@abstractmethod
def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
pass
@abstractmethod
def unembed(
self,
t: Float[torch.Tensor, "d_model"],
normalize: bool,
) -> Float[torch.Tensor, "vocab"]:
"""
Project the given vector (for example, the residual stream state at some layer
and token position) into the output vocabulary.

normalize: whether to apply the final normalization before unembedding.
Setting it to True and applying the projection to the output of the last layer
reproduces the model's output logits.
"""
pass
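# Illustrative relation (assuming `run` has been called on a concrete
# implementation `llm`): unembedding the final residual state with
# normalize=True is expected to reproduce the model's logits at that position:
#
#   last = llm.model_info().n_layers - 1
#   vec = llm.residual_out(last)[0, -1]                        # [d_model]
#   torch.allclose(llm.unembed(vec, normalize=True), llm.logits()[0, -1])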
# ================= Methods related to the residual stream =================
@abstractmethod
def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
"""
The state of the residual stream before entering the given layer. For example,
when layer == 0 this must be the embedded tokens (including the positional
embedding).
"""
pass
@abstractmethod
def residual_after_attn(
self, layer: int
) -> Float[torch.Tensor, "batch pos d_model"]:
"""
The state of the residual stream after attention, but before the FFN in the
given layer.
"""
pass
@abstractmethod
def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
"""
The state of the residual stream after the given layer. This is equivalent to the
next layer's input.
"""
pass
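# Illustrative invariants of the three accessors above (for a concrete
# implementation `llm`): residual_in(0) holds the embedded tokens, and the
# stream is expected to be continuous across layers, i.e. for every layer
#
#   torch.equal(llm.residual_out(layer), llm.residual_in(layer + 1))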
# ================ Methods related to the feed-forward layer ===============
@abstractmethod
def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
"""
The output of the FFN layer, before it gets merged into the residual stream.
"""
pass
@abstractmethod
def decomposed_ffn_out(
self,
batch_i: int,
layer: int,
pos: int,
) -> Float[torch.Tensor, "hidden d_model"]:
"""
The per-neuron contributions to the residual stream: one vector for each hidden
neuron. Each row should equal the neuron's activation multiplied by the
neuron's output direction.
"""
pass
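# Illustrative decomposition (assuming `run` has been called on a concrete
# implementation `llm`; `b`, `layer`, `pos` and `n` are placeholder indices):
# row n equals the neuron's activation times its output direction, and the
# rows sum back to the FFN output, possibly up to an output bias:
#
#   acts = llm.neuron_activations(b, layer, pos)               # [d_ffn]
#   row_n = acts[n] * llm.neuron_output(layer, n)              # [d_model]
#   total = llm.decomposed_ffn_out(b, layer, pos).sum(dim=0)   # ~ llm.ffn_out(layer)[b, pos]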
@abstractmethod
def neuron_activations(
self,
batch_i: int,
layer: int,
pos: int,
) -> Float[torch.Tensor, "d_ffn"]:
"""
The content of the hidden layer right after the activation function was applied.
"""
pass
@abstractmethod
def neuron_output(
self,
layer: int,
neuron: int,
) -> Float[torch.Tensor, "d_model"]:
"""
Return the output direction of the given neuron, i.e. the vector it writes to
the residual stream. This is a raw vector from the model parameters; no
activation is involved.
"""
pass
# ==================== Methods related to the attention ====================
@abstractmethod
def attention_matrix(
self, batch_i: int, layer: int, head: int
) -> Float[torch.Tensor, "query_pos key_pos"]:
"""
Return the attention matrix for the given head; it is lower-triangular because
of the causal mask.
"""
pass
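# Illustrative properties of the returned matrix A (softmax weights under a
# causal mask): every row is a distribution over non-future positions, so
#
#   A.sum(dim=-1)            # ~ all ones
#   A.triu(diagonal=1)       # all zeros: no attention to future positions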
@abstractmethod
def attention_output(
self,
batch_i: int,
layer: int,
pos: int,
head: int,
) -> Float[torch.Tensor, "d_model"]:
"""
Return what the given head at the given layer and pos added to the residual
stream.
"""
pass
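# Illustrative relation: summing attention_output over all heads at a fixed
# layer and position should roughly reconstruct what the whole attention block
# wrote to the residual stream there, i.e.
#
#   residual_after_attn(layer) - residual_in(layer)
#
# up to bias terms, depending on the implementation.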
@abstractmethod
def decomposed_attn(
self, batch_i: int, layer: int
) -> Float[torch.Tensor, "source target head d_model"]:
"""
Here
- source: the index of the token in the previous layer,
- target: the index of the token in the current layer.
The decomposed attention tells which vector from the source representation was
used to contribute to the target representation.
"""
pass
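# ================= Illustrative consistency checks (sketch) ================
# The function below is not part of the abstract interface. It is a minimal
# sketch, assuming a concrete implementation of `TransparentLlm` has been
# constructed and `run` has already been called. The decompositions are only
# expected to match up to bias terms and numerical precision, so discrepancies
# are printed rather than asserted.
def sketch_consistency_checks(llm: TransparentLlm, batch_i: int = 0, pos: int = -1) -> None:
    info = llm.model_info()

    # Unembedding the final residual state with normalization should
    # reproduce the model's logits at this position.
    final_residual = llm.residual_out(info.n_layers - 1)[batch_i, pos]
    unembedded = llm.unembed(final_residual, normalize=True)
    print("unembed vs logits:", (unembedded - llm.logits()[batch_i, pos]).abs().max().item())

    for layer in range(info.n_layers):
        # The per-neuron FFN decomposition should sum back to the FFN output.
        ffn_sum = llm.decomposed_ffn_out(batch_i, layer, pos).sum(dim=0)
        ffn_direct = llm.ffn_out(layer)[batch_i, pos]
        print(f"ffn decomposition, layer {layer}:", (ffn_sum - ffn_direct).abs().max().item())

        # The per-source, per-head attention decomposition, summed over sources
        # and heads, should roughly give what the attention block wrote to the
        # residual stream at this position.
        attn_sum = llm.decomposed_attn(batch_i, layer).sum(dim=(0, 2))[pos]
        attn_direct = (llm.residual_after_attn(layer) - llm.residual_in(layer))[batch_i, pos]
        print(f"attn decomposition, layer {layer}:", (attn_sum - attn_direct).abs().max().item())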