from typing import List, Optional, Tuple, Union, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from configuration_llama3_SAE import LLama3_SAE_Config
except ImportError:
    from .configuration_llama3_SAE import LLama3_SAE_Config

from transformers import (
    LlamaPreTrainedModel,
    LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    replace_return_docstrings,
)

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LLama3_SAE(LlamaPreTrainedModel):
    config_class = LLama3_SAE_Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LLama3_SAE_Config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.activation == "topk":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k))
            else:
                activation = TopK(config.activation_k)
        elif config.activation == "topk-tanh":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Tanh())
            else:
                activation = TopK(config.activation_k, nn.Tanh())
        elif config.activation == "topk-sigmoid":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Sigmoid())
            else:
                activation = TopK(config.activation_k, nn.Sigmoid())
        elif config.activation == "jumprelu":
            activation = JumpReLu()
        elif config.activation == "relu":
            activation = ACTIVATIONS_CLASSES["ReLU"]()
        elif config.activation == "identity":
            activation = ACTIVATIONS_CLASSES["Identity"]()
        else:
            raise NotImplementedError(
                f"Activation '{config.activation}' not implemented."
            )

        self.SAE = Autoencoder(
            n_inputs=config.n_inputs,
            n_latents=config.n_latents,
            activation=activation,
            tied=False,
            normalize=True,
        )

        self.hook = HookedTransformer_with_SAE_suppresion(
            block=config.hook_block_num,
            sae=self.SAE,
            mod_features=config.mod_features,
            mod_threshold=config.mod_threshold,
            mod_replacement=config.mod_replacement,
            mod_scaling=config.mod_scaling,
        ).register_with(self.model, config.site)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring).
                Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the
                tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            # Per-sequence mean over the tokens that actually contribute to the loss
            # (positions labeled `-100` receive zero loss from CrossEntropyLoss and are masked out here).
            loss = loss.view(logits.size(0), -1)
            mask = loss != 0
            loss = loss.sum(dim=-1) / mask.sum(dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = (
                    cache_position[0]
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                max_cache_length = (
                    torch.tensor(
                        past_key_values.get_max_length(), device=input_ids.device
                    )
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = (
                    past_length
                    if max_cache_length is None
                    else torch.min(max_cache_length, past_length)
                )
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past


def LN(
    x: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mu = x.mean(dim=-1, keepdim=True)
    x = x - mu
    std = x.std(dim=-1, keepdim=True)
    x = x / (std + eps)
    return x, mu, std
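
# Illustrative note (not from the original code): LN() standardizes the last
# dimension and also returns the per-sample statistics so that
# Autoencoder.decode() can undo the normalization, e.g.:
#
#     x = torch.randn(2, 16)
#     x_norm, mu, std = LN(x)
#     x_rec = x_norm * (std + 1e-5) + mu  # recovers x up to floating-point error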
class Autoencoder(nn.Module):
    """Sparse autoencoder

    Implements:
        latents = activation(encoder(x - pre_bias) + latent_bias)
        recons = decoder(latents) + pre_bias
    """

    def __init__(
        self,
        n_latents: int,
        n_inputs: int,
        activation: Callable = nn.ReLU(),
        tied: bool = False,
        normalize: bool = False,
    ) -> None:
        """
        :param n_latents: dimension of the autoencoder latent
        :param n_inputs: dimensionality of the original data (e.g. residual stream, number of MLP hidden units)
        :param activation: activation function
        :param tied: whether to tie the encoder and decoder weights
        """
        super().__init__()

        self.n_inputs = n_inputs
        self.n_latents = n_latents

        self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
        self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
        self.latent_bias = nn.Parameter(torch.zeros(n_latents))
        self.activation = activation

        if isinstance(activation, JumpReLu):
            self.threshold = nn.Parameter(torch.empty(n_latents))
            torch.nn.init.constant_(self.threshold, 0.001)
            self.forward = self.forward_jumprelu
        elif isinstance(activation, TopK):
            self.forward = self.forward_topk
        else:
            logger.warning(
                f"Using TopK forward function even if activation is not TopK, but is {activation}"
            )
            self.forward = self.forward_topk

        if tied:
            # self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
            self.decoder.weight.data = self.encoder.weight.data.T.clone()
        else:
            self.decoder = nn.Linear(n_latents, n_inputs, bias=False)

        self.normalize = normalize

    def encode_pre_act(
        self, x: torch.Tensor, latent_slice: slice = slice(None)
    ) -> torch.Tensor:
        """
        :param x: input data (shape: [batch, n_inputs])
        :param latent_slice: slice of latents to compute
            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
        :return: autoencoder latents before activation (shape: [batch, n_latents])
        """
        x = x - self.pre_bias
        latents_pre_act = F.linear(
            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
        )
        return latents_pre_act

    def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        if not self.normalize:
            return x, dict()
        x, mu, std = LN(x)
        return x, dict(mu=mu, std=std)

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents (shape: [batch, n_latents])
        """
        x, info = self.preprocess(x)
        return self.activation(self.encode_pre_act(x)), info

    def decode(
        self, latents: torch.Tensor, info: dict[str, Any] | None = None
    ) -> torch.Tensor:
        """
        :param latents: autoencoder latents (shape: [batch, n_latents])
        :return: reconstructed data (shape: [batch, n_inputs])
        """
        ret = self.decoder(latents) + self.pre_bias
        if self.normalize:
            assert info is not None
            ret = ret * info["std"] + info["mu"]
        return ret

    def forward_topk(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents pre activation (shape: [batch, n_latents])
                 autoencoder latents (shape: [batch, n_latents])
                 reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(latents_pre_act)
        recons = self.decode(latents, info)
        return latents_pre_act, latents, recons

    def forward_jumprelu(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: input data (shape: [batch, n_inputs])
        :return: autoencoder latents pre activation (shape: [batch, n_latents])
                 autoencoder latents (shape: [batch, n_latents])
                 reconstructed data (shape: [batch, n_inputs])
        """
        x, info = self.preprocess(x)
        latents_pre_act = self.encode_pre_act(x)
        latents = self.activation(F.relu(latents_pre_act), torch.exp(self.threshold))
        recons = self.decode(latents, info)
        return latents_pre_act, latents, recons
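
# Illustrative sketch (not from the original code): typical round trip through the
# sparse autoencoder with a TopK activation (TopK is defined further below).
#
#     sae = Autoencoder(n_latents=64, n_inputs=16, activation=TopK(4), normalize=True)
#     x = torch.randn(8, 16)
#     pre_act, latents, recons = sae(x)  # dispatches to forward_topk
#     # `latents` has at most 4 non-zero entries per row; `recons` approximates `x`.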
class TiedTranspose(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self.linear.bias is None
        # torch.nn.parameter.Parameter(layer_e.weights.T)
        return F.linear(x, self.linear.weight.t(), None)

    @property
    def weight(self) -> torch.Tensor:
        return self.linear.weight.t()

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias


class TopK(nn.Module):
    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
        super().__init__()
        self.k = k
        self.postact_fn = postact_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        topk = torch.topk(x, k=self.k, dim=-1)
        values = self.postact_fn(topk.values)
        # make all other values 0
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, values)
        return result


class JumpReLu(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return JumpReLUFunction.apply(input, threshold)


class HeavyStep(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, threshold):
        return HeavyStepFunction.apply(input, threshold)


def rectangle(x):
    return (x > -0.5) & (x < 0.5)


class JumpReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        output = input * (input > threshold)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        # bandwidth = 0.0001
        input, threshold = ctx.saved_tensors
        grad_input = grad_threshold = None
        # Straight-through estimator: pass the gradient through where the unit is
        # active; a rectangle window of width `bandwidth` approximates the
        # derivative with respect to the threshold.
        grad_input = (input > threshold) * grad_output
        grad_threshold = (
            -(threshold / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )
        return grad_input, grad_threshold


class HeavyStepFunction(torch.autograd.Function):
    @staticmethod
    def forward(input, threshold):
        # Heaviside step: 1 where the input exceeds the threshold, 0 elsewhere.
        output = (input > threshold).to(input.dtype)
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, threshold = inputs
        ctx.save_for_backward(input, threshold)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        # bandwidth = 0.0001
        input, threshold = ctx.saved_tensors
        grad_input = grad_threshold = None
        # The step is treated as constant with respect to the input; only the
        # threshold receives a (rectangle-window) pseudo-gradient.
        grad_input = torch.zeros_like(input)
        grad_threshold = (
            -(1.0 / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )
        return grad_input, grad_threshold


ACTIVATIONS_CLASSES = {
    "ReLU": nn.ReLU,
    "Identity": nn.Identity,
    "TopK": TopK,
    "JumpReLU": JumpReLu,
}
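
# Illustrative sketch (not from the original code): the activations above are
# plain callables on the latent pre-activations, e.g.
#
#     act = TopK(torch.tensor(2))
#     act(torch.tensor([[0.5, 3.0, -1.0, 2.0]]))
#     # -> tensor([[0., 3., 0., 2.]]): only the 2 largest entries survive,
#     #    passed through the ReLU post-activation.
#
#     JumpReLu()(torch.tensor([1.0, 0.2]), torch.tensor(0.5))
#     # -> tensor([1., 0.]): values at or below the threshold are zeroed.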
class HookedTransformer_with_SAE:
    """Auxiliary class used to extract MLP activations from transformer models."""

    def __init__(self, block: int, sae) -> None:
        self.block = block
        self.sae = sae
        self.remove_handle = (
            None  # Can be used to remove this hook from the model again
        )
        self._features = None

    def register_with(self, model):
        # At the moment only activations from the feed-forward MLP layer
        self.remove_handle = model.layers[self.block].mlp.register_forward_hook(self)
        return self

    def pop(self) -> torch.Tensor:
        """Remove and return the extracted feature from this hook.

        We only allow access to the features this way to not have any lingering
        references to them.
        """
        assert self._features is not None, "Feature extractor was not called yet!"
        features = self._features
        self._features = None
        return features

    def __call__(self, module, inp, outp) -> None:
        self._features = outp
        return self.sae(outp)[2]


class HookedTransformer_with_SAE_suppresion:
    """Auxiliary class used to extract MLP activations from transformer models
    and optionally modify selected SAE features before reconstruction."""

    def __init__(
        self,
        block: int,
        sae: Autoencoder,
        mod_features: list = None,
        mod_threshold: list = None,
        mod_replacement: list = None,
        mod_scaling: list = None,
        mod_balance: bool = False,
        multi_feature: bool = False,
    ) -> None:
        self.block = block
        self.sae = sae
        self.remove_handle = (
            None  # Can be used to remove this hook from the model again
        )
        self._features = None
        self.mod_features = mod_features
        self.mod_threshold = mod_threshold
        self.mod_replacement = mod_replacement
        self.mod_scaling = mod_scaling
        self.mod_balance = mod_balance
        self.mod_vector = None
        self.mod_vec_factor = 1.0

        if multi_feature:
            self.modify = self.modify_list
        else:
            self.modify = self.modify_single

        if isinstance(self.sae.activation, JumpReLu):
            logger.info("Setting __call__ function for JumpReLU.")
            setattr(self, "call", self.__call__jumprelu)
        elif isinstance(self.sae.activation, TopK):
            logger.info("Setting __call__ function for TopK.")
            setattr(self, "call", self.__call__topk)
        else:
            logger.warning(
                f"Using TopK forward function even if activation is not TopK, but is {self.sae.activation}"
            )
            setattr(self, "call", self.__call__topk)

    def register_with(self, model, site="mlp"):
        self.site = site
        # Decision on where to extract activations from
        if site == "mlp":
            # output of the FF module of the block
            self.remove_handle = model.layers[self.block].mlp.register_forward_hook(
                self
            )
        elif site == "block":
            # output of the residual connection AFTER it is added to the FF output
            self.remove_handle = model.layers[self.block].register_forward_hook(self)
        elif site == "attention":
            raise NotImplementedError
        else:
            raise NotImplementedError
        # self.remove_handle = model.model.layers[self.block].mlp.act_fn.register_forward_hook(self)

        return self

    def modify_list(self, latents: torch.Tensor) -> torch.Tensor:
        if self.mod_replacement is not None:
            for feat, thresh, mod in zip(
                self.mod_features, self.mod_threshold, self.mod_replacement
            ):
                latents[:, :, feat][latents[:, :, feat] > thresh] = mod
        elif self.mod_scaling is not None:
            for feat, thresh, mod in zip(
                self.mod_features, self.mod_threshold, self.mod_scaling
            ):
                latents[:, :, feat][latents[:, :, feat] > thresh] *= mod
        elif self.mod_vector is not None:
            latents = latents + self.mod_vec_factor * self.mod_vector
        else:
            pass
        return latents

    def modify_single(self, latents: torch.Tensor) -> torch.Tensor:
        old_cond_feats = latents[:, :, self.mod_features]
        if self.mod_replacement is not None:
            # latents[:, :, self.mod_features][
            #     latents[:, :, self.mod_features] > self.mod_threshold
            # ] = self.mod_replacement
            latents[:, :, self.mod_features] = self.mod_replacement
        elif self.mod_scaling is not None:
            latents_scaled = latents.clone()
            latents_scaled[:, :, self.mod_features][
                latents[:, :, self.mod_features] > 0
            ] *= self.mod_scaling
            latents_scaled[:, :, self.mod_features][
                latents[:, :, self.mod_features] < 0
            ] *= -1 * self.mod_scaling
            latents = latents_scaled
            # latents[:, :, self.mod_features] *= self.mod_scaling
        elif self.mod_vector is not None:
            latents = latents + self.mod_vec_factor * self.mod_vector
        else:
            pass

        if self.mod_balance:
            # logger.warning("The balancing does not work yet!!!")
            # TODO: Look into it more closely, not sure if this is correct
            num_feat = latents.shape[2] - 1
            diff = old_cond_feats - latents[:, :, self.mod_features]
            if self.mod_features != 0:
                latents[:, :, : self.mod_features] += (diff / num_feat)[:, :, None]
            latents[:, :, self.mod_features + 1 :] += (diff / num_feat)[:, :, None]

        return latents

    def pop(self) -> torch.Tensor:
        """Remove and return the extracted feature from this hook.

        We only allow access to the features this way to not have any lingering
        references to them.
        """
        assert self._features is not None, "Feature extractor was not called yet!"
        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features
        self._features = None
        return features

    def __call__topk(self, module, inp, outp) -> torch.Tensor:
        self._features = outp

        if isinstance(self._features, tuple):
            features = self._features[0]
        else:
            features = self._features

        if self.mod_features is None:
            recons = features
        else:
            x, info = self.sae.preprocess(features)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(latents_pre_act)
            # latents[:, :, self.mod_features] = F.sigmoid(
            #     latents_pre_act[:, :, self.mod_features]
            # )
            # latents[:, :, self.mod_features] = torch.abs(latents_pre_act[:, :, self.mod_features])
            # latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]

            mod_latents = self.modify(latents)

            # mod_latents[:, :, self.mod_features] = F.sigmoid(
            #     mod_latents[:, :, self.mod_features]
            # )

            recons = self.sae.decode(mod_latents, info)

        if isinstance(self._features, tuple):
            outp = list(outp)
            outp[0] = recons
            return tuple(outp)
        else:
            return recons

    def __call__jumprelu(self, module, inp, outp) -> torch.Tensor:
        self._features = outp

        if self.mod_features is None:
            recons = outp
        else:
            x, info = self.sae.preprocess(outp)
            latents_pre_act = self.sae.encode_pre_act(x)
            latents = self.sae.activation(
                F.relu(latents_pre_act), torch.exp(self.sae.threshold)
            )
            latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]

            mod_latents = self.modify(latents)

            recons = self.sae.decode(mod_latents, info)

        return recons

    def __call__(self, module, inp, outp) -> torch.Tensor:
        return self.call(module, inp, outp)
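

# Minimal smoke test (illustrative sketch, not part of the original module): the
# dimensions, the feature index, and the replacement value below are arbitrary
# and only demonstrate how the SAE and the suppression hook fit together.
if __name__ == "__main__":
    n_inputs, n_latents = 16, 64
    sae = Autoencoder(
        n_inputs=n_inputs,
        n_latents=n_latents,
        activation=TopK(4),
        tied=False,
        normalize=True,
    )

    # Round trip through the SAE on fake activations of shape [batch, seq, n_inputs].
    acts = torch.randn(2, 5, n_inputs)
    pre_act, latents, recons = sae(acts)
    print(latents.shape, recons.shape)  # torch.Size([2, 5, 64]) torch.Size([2, 5, 16])

    # Clamp latent feature 3 to 0.0 at every position; calling the hook directly
    # stands in for the forward hook that register_with() attaches to a Llama block.
    hook = HookedTransformer_with_SAE_suppresion(
        block=0, sae=sae, mod_features=3, mod_replacement=0.0
    )
    modified = hook(None, None, acts)
    print(modified.shape)  # torch.Size([2, 5, 16]): SAE reconstruction with feature 3 suppressed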