# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional
import torch
import transformer_lens
import transformers
from fancy_einsum import einsum
from jaxtyping import Float, Int
from typeguard import typechecked
import streamlit as st
from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm
@dataclass
class _RunInfo:
tokens: Int[torch.Tensor, "batch pos"]
logits: Float[torch.Tensor, "batch pos d_vocab"]
cache: transformer_lens.ActivationCache
@st.cache_resource(
max_entries=1,
show_spinner=True,
hash_funcs={
transformers.PreTrainedModel: id,
transformers.PreTrainedTokenizer: id
}
)
def load_hooked_transformer(
model_name: str,
hf_model: Optional[transformers.PreTrainedModel] = None,
tlens_device: str = "cuda",
dtype: torch.dtype = torch.float32,
):
# if tlens_device == "cuda":
# n_devices = torch.cuda.device_count()
# else:
# n_devices = 1
tlens_model = transformer_lens.HookedTransformer.from_pretrained(
model_name,
hf_model=hf_model,
fold_ln=False, # Keep layer norm where it is.
center_writing_weights=False,
center_unembed=False,
device=tlens_device,
# n_devices=n_devices,
dtype=dtype,
)
tlens_model.eval()
return tlens_model
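
# Usage sketch (comment only, nothing is executed at import time). The model name
# "gpt2" is an illustrative assumption; any model known to transformer_lens works
# the same way:
#
#     tlens_model = load_hooked_transformer("gpt2", tlens_device="cpu")
#     tokens = tlens_model.to_tokens("Hello world")
#     logits = tlens_model(tokens)  # [batch, pos, d_vocab]
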
# TODO(igortufanov): If we want to scale the app to multiple users, we need a more
# careful, thread-safe implementation. The simplest option would be to wrap the
# existing methods in mutexes.
class TransformerLensTransparentLlm(TransparentLlm):
"""
    Implementation of TransparentLlm based on TransformerLens.
Args:
- model_name: The official name of the model from HuggingFace. Even if the model was
patched or loaded locally, the name should still be official because that's how
transformer_lens treats the model.
- hf_model: The language model as a HuggingFace class.
    - tokenizer: Optional tokenizer as a HuggingFace class. If omitted, the one
        loaded by transformer_lens for `model_name` is used.
    - device: "gpu" or "cpu".
"""
def __init__(
self,
model_name: str,
hf_model: Optional[transformers.PreTrainedModel] = None,
tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
device: str = "gpu",
dtype: torch.dtype = torch.float32,
):
if device == "gpu":
self.device = "cuda"
if not torch.cuda.is_available():
                raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
elif device == "cpu":
self.device = "cpu"
else:
raise RuntimeError(f"Specified device {device} is not a valid option")
self.dtype = dtype
self.hf_tokenizer = tokenizer
self.hf_model = hf_model
# self._model = tlens_model
self._model_name = model_name
self._prepend_bos = True
self._last_run = None
self._run_exception = RuntimeError(
"Tried to use the model output before calling the `run` method"
)
def copy(self):
import copy
return copy.copy(self)
@property
def _model(self):
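        # Construction is memoized: load_hooked_transformer is wrapped in
        # st.cache_resource, so repeated accesses of this property reuse the same
        # HookedTransformer instance instead of reloading the weights.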
tlens_model = load_hooked_transformer(
self._model_name,
hf_model=self.hf_model,
tlens_device=self.device,
dtype=self.dtype,
)
if self.hf_tokenizer is not None:
tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
tlens_model.set_use_attn_result(True)
tlens_model.set_use_attn_in(False)
tlens_model.set_use_split_qkv_input(False)
return tlens_model
def model_info(self) -> ModelInfo:
cfg = self._model.cfg
return ModelInfo(
name=self._model_name,
n_params_estimate=cfg.n_params,
n_layers=cfg.n_layers,
n_heads=cfg.n_heads,
d_model=cfg.d_model,
d_vocab=cfg.d_vocab,
)
@torch.no_grad()
def run(self, sentences: List[str]) -> None:
tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
logits, cache = self._model.run_with_cache(tokens)
self._last_run = _RunInfo(
tokens=tokens,
logits=logits,
cache=cache,
)
def batch_size(self) -> int:
if not self._last_run:
raise self._run_exception
return self._last_run.logits.shape[0]
@typechecked
def tokens(self) -> Int[torch.Tensor, "batch pos"]:
if not self._last_run:
raise self._run_exception
return self._last_run.tokens
@typechecked
def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
return self._model.to_str_tokens(tokens)
@typechecked
def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
if not self._last_run:
raise self._run_exception
return self._last_run.logits
@torch.no_grad()
@typechecked
def unembed(
self,
t: Float[torch.Tensor, "d_model"],
normalize: bool,
) -> Float[torch.Tensor, "vocab"]:
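        # With normalize=True the final layer norm is applied before unembedding,
        # matching how the residual stream is turned into output logits.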
# t: [d_model] -> [batch, pos, d_model]
tdim = t.unsqueeze(0).unsqueeze(0)
if normalize:
normalized = self._model.ln_final(tdim)
result = self._model.unembed(normalized)
else:
result = self._model.unembed(tdim)
return result[0][0]
    def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
if not self._last_run:
raise self._run_exception
return self._last_run.cache[f"blocks.{layer}.{block_name}"]
# ================= Methods related to the residual stream =================
@typechecked
def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
if not self._last_run:
raise self._run_exception
return self._get_block(layer, "hook_resid_pre")
@typechecked
def residual_after_attn(
self, layer: int
) -> Float[torch.Tensor, "batch pos d_model"]:
if not self._last_run:
raise self._run_exception
return self._get_block(layer, "hook_resid_mid")
@typechecked
def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
if not self._last_run:
raise self._run_exception
return self._get_block(layer, "hook_resid_post")
# ================ Methods related to the feed-forward layer ===============
@typechecked
def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
if not self._last_run:
raise self._run_exception
return self._get_block(layer, "hook_mlp_out")
@torch.no_grad()
@typechecked
def decomposed_ffn_out(
self,
batch_i: int,
layer: int,
pos: int,
) -> Float[torch.Tensor, "hidden d_model"]:
# Take activations right before they're multiplied by W_out, i.e. non-linearity
# and layer norm are already applied.
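        # Summing the result over the hidden dimension reproduces ffn_out(layer) at
        # this position, up to the output bias b_out[layer].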
processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])
@typechecked
def neuron_activations(
self,
batch_i: int,
layer: int,
pos: int,
) -> Float[torch.Tensor, "hidden"]:
return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]
@typechecked
def neuron_output(
self,
layer: int,
neuron: int,
) -> Float[torch.Tensor, "d_model"]:
return self._model.W_out[layer][neuron]
# ==================== Methods related to the attention ====================
@typechecked
def attention_matrix(
self, batch_i: int, layer: int, head: int
) -> Float[torch.Tensor, "query_pos key_pos"]:
return self._get_block(layer, "attn.hook_pattern")[batch_i][head]
@typechecked
def attention_output_per_head(
self,
batch_i: int,
layer: int,
pos: int,
head: int,
) -> Float[torch.Tensor, "d_model"]:
return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]
@typechecked
def attention_output(
self,
batch_i: int,
layer: int,
pos: int,
) -> Float[torch.Tensor, "d_model"]:
return self._get_block(layer, "hook_attn_out")[batch_i][pos]
@torch.no_grad()
@typechecked
def decomposed_attn(
self, batch_i: int, layer: int
) -> Float[torch.Tensor, "pos key_pos head d_model"]:
if not self._last_run:
raise self._run_exception
hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
b_v = self._model.b_V[layer]
v = hook_v + b_v
pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
z = einsum(
"key_pos head d_head, "
"head query_pos key_pos -> "
"query_pos key_pos head d_head",
v,
pattern,
)
decomposed_attn = einsum(
"pos key_pos head d_head, "
"head d_head d_model -> "
"pos key_pos head d_model",
z,
self._model.W_O[layer],
)
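        # Each [pos, key_pos, head, :] slice is the write into the residual stream at
        # query position `pos` attributed to source position `key_pos` and that head.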
return decomposed_attn
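

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. The model name "gpt2", the
    # CPU device and the example sentence are illustrative assumptions; substitute
    # any model supported by transformer_lens.
    llm = TransformerLensTransparentLlm("gpt2", device="cpu")
    llm.run(["The quick brown fox jumps over the lazy dog"])
    info = llm.model_info()
    print(f"{info.name}: {info.n_layers} layers, {info.n_heads} heads")
    print("tokens:", llm.tokens().shape)    # [batch, pos]
    print("logits:", llm.logits().shape)    # [batch, pos, d_vocab]
    print("residual out (layer 0):", llm.residual_out(0).shape)  # [batch, pos, d_model]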