# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

import torch
import transformer_lens
import transformers
from fancy_einsum import einsum
from jaxtyping import Float, Int
from typeguard import typechecked

import streamlit as st

from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm


@dataclass
class _RunInfo:
    tokens: Int[torch.Tensor, "batch pos"]
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    cache: transformer_lens.ActivationCache


@st.cache_resource(
    max_entries=1,
    show_spinner=True,
    hash_funcs={
        transformers.PreTrainedModel: id,
        transformers.PreTrainedTokenizer: id,
    },
)
def load_hooked_transformer(
    model_name: str,
    hf_model: Optional[transformers.PreTrainedModel] = None,
    tlens_device: str = "cuda",
    dtype: torch.dtype = torch.float32,
):
    # if tlens_device == "cuda":
    #     n_devices = torch.cuda.device_count()
    # else:
    #     n_devices = 1

    tlens_model = transformer_lens.HookedTransformer.from_pretrained(
        model_name,
        hf_model=hf_model,
        fold_ln=False,  # Keep layer norm where it is.
        center_writing_weights=False,
        center_unembed=False,
        device=tlens_device,
        # n_devices=n_devices,
        dtype=dtype,
    )
    tlens_model.eval()
    return tlens_model
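

# `st.cache_resource` above keeps at most one loaded HookedTransformer alive
# (`max_entries=1`), and the `hash_funcs` entries map the HuggingFace model and
# tokenizer arguments to `id`, so the cache key is based on object identity
# rather than on a hash of the full weight tensors.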


# TODO(igortufanov): If we want to scale the app to multiple users, we need a more
# careful thread-safe implementation. The simplest option could be to wrap the
# existing methods in mutexes.
class TransformerLensTransparentLlm(TransparentLlm):
    """
    Implementation of TransparentLlm based on transformer_lens.

    Args:
    - model_name: The official name of the model from HuggingFace. Even if the model
        was patched or loaded locally, the name should still be the official one
        because that's how transformer_lens treats the model.
    - hf_model: The language model as a HuggingFace class.
    - tokenizer: The HuggingFace tokenizer to use. Optional; if omitted, the tokenizer
        loaded by transformer_lens is used.
    - device: "gpu" or "cpu"
    """

    def __init__(
        self,
        model_name: str,
        hf_model: Optional[transformers.PreTrainedModel] = None,
        tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
        device: str = "gpu",
        dtype: torch.dtype = torch.float32,
    ):
        if device == "gpu":
            self.device = "cuda"
            if not torch.cuda.is_available():
                raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
        elif device == "cpu":
            self.device = "cpu"
        else:
            raise RuntimeError(f"Specified device {device} is not a valid option")

        self.dtype = dtype
        self.hf_tokenizer = tokenizer
        self.hf_model = hf_model

        # self._model = tlens_model
        self._model_name = model_name
        self._prepend_bos = True
        self._last_run = None
        self._run_exception = RuntimeError(
            "Tried to use the model output before calling the `run` method"
        )

    def copy(self):
        import copy

        return copy.copy(self)

    @property
    def _model(self):
        tlens_model = load_hooked_transformer(
            self._model_name,
            hf_model=self.hf_model,
            tlens_device=self.device,
            dtype=self.dtype,
        )

        if self.hf_tokenizer is not None:
            tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")

        tlens_model.set_use_attn_result(True)
        tlens_model.set_use_attn_in(False)
        tlens_model.set_use_split_qkv_input(False)

        return tlens_model

    def model_info(self) -> ModelInfo:
        cfg = self._model.cfg
        return ModelInfo(
            name=self._model_name,
            n_params_estimate=cfg.n_params,
            n_layers=cfg.n_layers,
            n_heads=cfg.n_heads,
            d_model=cfg.d_model,
            d_vocab=cfg.d_vocab,
        )

    @torch.no_grad()
    def run(self, sentences: List[str]) -> None:
        tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
        logits, cache = self._model.run_with_cache(tokens)
        self._last_run = _RunInfo(
            tokens=tokens,
            logits=logits,
            cache=cache,
        )

    def batch_size(self) -> int:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits.shape[0]

    @typechecked
    def tokens(self) -> Int[torch.Tensor, "batch pos"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.tokens

    @typechecked
    def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
        return self._model.to_str_tokens(tokens)

    @typechecked
    def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.logits

    @torch.no_grad()
    @typechecked
    def unembed(
        self,
        t: Float[torch.Tensor, "d_model"],
        normalize: bool,
    ) -> Float[torch.Tensor, "vocab"]:
        # t: [d_model] -> [batch, pos, d_model]
        tdim = t.unsqueeze(0).unsqueeze(0)
        if normalize:
            normalized = self._model.ln_final(tdim)
            result = self._model.unembed(normalized)
        else:
            result = self._model.unembed(tdim)
        return result[0][0]

    def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
        if not self._last_run:
            raise self._run_exception
        return self._last_run.cache[f"blocks.{layer}.{block_name}"]

    # ================= Methods related to the residual stream =================

    @typechecked
    def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_pre")

    @typechecked
    def residual_after_attn(
        self, layer: int
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_mid")

    @typechecked
    def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_resid_post")
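
    # In transformer_lens terms, "hook_resid_pre" is the residual stream entering
    # block `layer`, "hook_resid_mid" is the stream after the attention output has
    # been added, and "hook_resid_post" is the stream after the MLP output has been
    # added as well (standard sequential block layout).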

    # ================ Methods related to the feed-forward layer ===============

    @typechecked
    def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
        if not self._last_run:
            raise self._run_exception
        return self._get_block(layer, "hook_mlp_out")

    @torch.no_grad()
    @typechecked
    def decomposed_ffn_out(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden d_model"]:
        # Take the activations right before they're multiplied by W_out, i.e. the
        # non-linearity and layer norm are already applied.
        processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
        return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])

    @typechecked
    def neuron_activations(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "hidden"]:
        return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]

    @typechecked
    def neuron_output(
        self,
        layer: int,
        neuron: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._model.W_out[layer][neuron]

    # ==================== Methods related to the attention ====================

    @typechecked
    def attention_matrix(
        self, batch_i: int, layer: int, head: int
    ) -> Float[torch.Tensor, "query_pos key_pos"]:
        return self._get_block(layer, "attn.hook_pattern")[batch_i][head]

    @typechecked
    def attention_output_per_head(
        self,
        batch_i: int,
        layer: int,
        pos: int,
        head: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]

    @typechecked
    def attention_output(
        self,
        batch_i: int,
        layer: int,
        pos: int,
    ) -> Float[torch.Tensor, "d_model"]:
        return self._get_block(layer, "hook_attn_out")[batch_i][pos]

    @torch.no_grad()
    @typechecked
    def decomposed_attn(
        self, batch_i: int, layer: int
    ) -> Float[torch.Tensor, "pos key_pos head d_model"]:
        if not self._last_run:
            raise self._run_exception
        hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
        b_v = self._model.b_V[layer]
        v = hook_v + b_v
        pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
        # Weight each value vector by the attention pattern of its head.
        z = einsum(
            "key_pos head d_head, "
            "head query_pos key_pos -> "
            "query_pos key_pos head d_head",
            v,
            pattern,
        )
        # Project each head's weighted values into the residual stream basis.
        decomposed_attn = einsum(
            "pos key_pos head d_head, "
            "head d_head d_model -> "
            "pos key_pos head d_model",
            z,
            self._model.W_O[layer],
        )
        return decomposed_attn
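

# Minimal usage sketch, assuming the "gpt2" checkpoint can be downloaded by
# transformer_lens and that a float32 CPU run is acceptable. Outside a Streamlit
# app, `st.cache_resource` should fall back to a plain in-process cache (possibly
# logging a warning).
if __name__ == "__main__":
    llm = TransformerLensTransparentLlm("gpt2", device="cpu")
    llm.run(["The quick brown fox"])

    info = llm.model_info()
    print(f"{info.name}: {info.n_layers} layers, d_model={info.d_model}")

    # Residual stream entering layer 0: [batch, pos, d_model].
    print("residual_in shape:", tuple(llm.residual_in(0).shape))

    # Per-head, per-source-position attention contribution at layer 0:
    # [query_pos, key_pos, head, d_model].
    print("decomposed_attn shape:", tuple(llm.decomposed_attn(batch_i=0, layer=0).shape))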