import math import warnings import hashlib import os from typing import List, Optional, Tuple, Union from dataclasses import dataclass import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LlamaForCausalLM from transformers.models.gemma.modeling_gemma import GemmaModel, GemmaPreTrainedModel, GemmaForCausalLM from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, Qwen2PreTrainedModel, Qwen2ForCausalLM from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.cache_utils import Cache from models.modeling_gemma import GemmaForCausalLM from models.modeling_qwen2 import Qwen2ForCausalLM def svd_with_cache(matrix, cache_dir, max_rank=1024): """ SVD with cache mechanism to avoid repeated SVD computation. SVD can be very slow for large matrices, so we cache the results. """ in_dim, out_dim = matrix.shape # slice_weight = matrix[::1000, :] # too sensitive to precision # weight_hash = hashlib.md5(slice_weight.detach().cpu().numpy().tobytes()).hexdigest() weight_hash = in_dim * out_dim cache_file = os.path.join(cache_dir, f'{weight_hash}.pt') if not os.path.exists(cache_dir): os.makedirs(cache_dir) if os.path.exists(cache_file): # Load cached SVD results U, S, Vh = torch.load(cache_file) else: # Perform SVD and cache the results U, S, Vh = torch.linalg.svd(matrix.float()) U = U[:, :max_rank].clone() # Shape: [out_features, rank] S = S[:max_rank].clone() # Shape: [rank] Vh = Vh[:max_rank, :].clone() # Shape: [rank, in_features] # Save the SVD results to cache torch.save((U, S, Vh), cache_file) return U, S, Vh def create_factorized_compression_for_linear(source_linear, rank, svd_cache_dir='experiment_cache/'): """ Adapt from: https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/cli_svd.py Create a factorized compression for a given linear layer using SVD. Args: source_linear (nn.Linear): The original linear layer to be compressed. rank (int, optional): The rank for the factorization. If None, it will be calculated based on rank_factor. rank_factor (float, optional): The factor to determine the rank if rank is not provided. Default is 0.3. Returns: nn.Sequential: A sequential container of the compressed linear layers. """ with torch.no_grad(): dtype = source_linear.weight.dtype # Check if the source linear layer has a bias term if hasattr(source_linear, 'bias'): bias = source_linear.bias else: bias = None # Calculate the total number of parameters in the source linear layer source_num_params = sum(param.numel() for param in source_linear.parameters()) # Get the weight matrix of the source linear layer source_linear_weight = source_linear.weight.data # Ensure rank is less than the minimum dimension of the weight matrix assert rank < min(source_linear_weight.shape) # Perform SVD on the weight matrix # U, S, Vh = torch.linalg.svd(source_linear_weight.float()) U, S, Vh = svd_with_cache(source_linear_weight, svd_cache_dir) # Truncate U, S, Vh to the specified rank U = U[:, :rank].contiguous() # Shape: [out_features, rank] S = S[:rank].contiguous() # Shape: [rank] Vh = Vh[:rank, :].contiguous() # Shape: [rank, in_features] # Incorporate singular values into U U = U @ torch.diag(S) # Shape: [out_features, rank] # Flatten U and Vh for quantile computation U_flatten = U.flatten() Vh_flatten = Vh.flatten() # Define the maximum quantization size max_quant_size = 2**23 # Compute high and low quantile values for clamping if len(U_flatten) + len(Vh_flatten) >= max_quant_size: dist2 = U_flatten[:min(len(U_flatten), max_quant_size)] dist3 = Vh_flatten[:min(len(Vh_flatten), max_quant_size)] hi_val = max(torch.quantile(dist3, 1), torch.quantile(dist2, 1)) else: dist = torch.cat([U_flatten, Vh_flatten]) hi_val = torch.quantile(dist, 1) low_val = -hi_val # Clamp U and Vh to the quantile values U = U.clamp(low_val, hi_val) Vh = Vh.clamp(low_val, hi_val) # Create the down projection linear layer (Vh) lora_down = nn.Linear(Vh.shape[1], Vh.shape[0], dtype=dtype, bias=False, device=source_linear_weight.device) lora_down.weight.data = Vh.to(device=source_linear_weight.device, dtype=dtype) # Create the up projection linear layer (U) lora_up = nn.Linear(U.shape[1], U.shape[0], dtype=dtype, bias=bias is not None, device=source_linear_weight.device) lora_up.weight.data = U.to(device=source_linear_weight.device, dtype=dtype) # If the original linear layer had a bias, copy it to the up projection layer if bias is not None: lora_up.bias = nn.Parameter(bias.clone()) # Print compression ratio (for debugging purposes) #print('compression', sum(param.numel() for param in ret.parameters()) / source_num_params) return lora_down, lora_up @dataclass class AdaCausalLMOutputWithPast(CausalLMOutputWithPast): # keep original `loss` for `training_step` and `predictions_step`, # Add 3 sub losses: `lm_loss`, `mask_loss`, `topk_loss` # add `lm_head_logits` for original lm_head logits, which is optional (required for train and eval, not required for generation) lm_head_logits: Optional[torch.FloatTensor] = None lm_loss: Optional[torch.FloatTensor] = None mask_loss: Optional[torch.FloatTensor] = None topk_loss: Optional[torch.FloatTensor] = None class AdaVocabHead_MLP(nn.Module): # No improvement compare to LoRA solution def __init__(self, lm_head, sub_vocab_dim, activation_func=torch.nn.GELU()): hidden_size, vocab_size = lm_head.in_features, lm_head.out_features super().__init__() self.A = nn.Linear(hidden_size, sub_vocab_dim, bias=False) self.B = nn.Linear(sub_vocab_dim, sub_vocab_dim, bias=True) self.C = nn.Linear(sub_vocab_dim, vocab_size, bias=False) std_dev = 1 / math.sqrt(sub_vocab_dim) nn.init.normal_(self.A.weight, 0, std_dev) nn.init.normal_(self.B.weight, 0, std_dev) nn.init.zeros_(self.C.weight) self.activation_func = activation_func def forward(self, x): # x.shape: (..., hidden_size), # A.shape: (hidden_size, sub_vocab_dim) # B.shape: (sub_vocab_dim, sub_vocab_dim) # C.shape: (sub_vocab_dim, vocab_size) logits = self.A(x) # logits.shape: (..., sub_vocab_dim) logits = self.activation_func(logits) logits = self.B(logits) # logits.shape: (..., sub_vocab_dim) # logits = self.activation_func(logits) ada_vocab_logits = self.C(logits) # ada_vocab_logits.shape: (..., vocab_size) return ada_vocab_logits class AdaVocabHead_LORA(nn.Module): def __init__(self, lm_head, sub_vocab_dim, svd=False): hidden_size, vocab_size = lm_head.in_features, lm_head.out_features super().__init__() if svd: # SVD initialization self.A, self.B = create_factorized_compression_for_linear(lm_head, sub_vocab_dim) else: # Random initialization self.A = nn.Linear(hidden_size, sub_vocab_dim, bias=False) self.B = nn.Linear(sub_vocab_dim, vocab_size, bias=False) std_dev = 1 / math.sqrt(sub_vocab_dim) nn.init.normal_(self.A.weight, 0, std_dev) nn.init.zeros_(self.B.weight) def forward(self, x): # x.shape: (..., hidden_size), A.shape: (hidden_size, sub_vocab_dim), B.shape: (sub_vocab_dim, vocab_size) logits = self.A(x) ada_vocab_logits = self.B(logits) # ada_vocab_logits.shape: (..., vocab_size) return ada_vocab_logits def create_AdaVocabCausalLM(base_class): # Support LLama, Qwen2, Gemma class AdaVocabCausalLM(base_class): # TODO: Check the function of this variable and if it affects the AdaVocab Head model _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.sub_vocab_dim = config.ADA_DIM self.offload_tag = False # AdaVocabHead is already initialized with random weights/ SVD weights # so no need to use `self.post_init` method after this if config.ADA_ACT: self.adavocab_head = AdaVocabHead_MLP(self.lm_head, self.sub_vocab_dim, activation_func=nn.GELU()) else: self.adavocab_head = AdaVocabHead_LORA(self.lm_head, self.sub_vocab_dim, svd=config.ADA_SVD) self.freeze_original_model() def freeze_original_model(self): # freeze orginal llama except AdaVocabHead for param in self.model.parameters(): param.requires_grad = False for param in self.lm_head.parameters(): param.requires_grad = False for param in self.adavocab_head.parameters(): param.requires_grad = True def offload_lm_head(self): self.offload_tag = True self.lm_head = self.lm_head.to(torch.device('cpu')) def topk_mask(self, logits): # logits.shape: (batch_size, seq_len, vocab_size) topk_values, topk_indices = torch.topk(logits, self.config.ADA_TOPK, dim=-1) # topk_values.shape, topk_indices.shape: (batch_size, seq_len, topK) mask = torch.zeros_like(logits) # (batch_size, seq_len, vocab_size) # Only in top-k positions, put 1 to the corresponding position mask.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(mask)) return mask def pred_with_sliced_lm_head_simple(self, ada_logits, hidden_states): # nll_loss = None # Limit activated tokens to ADA_TOPK during inference # ada_logits_mask = self.topk_mask(ada_logits) # (batch_size, seq_len, vocab_size) ada_logits, topk_indices = torch.topk(ada_logits, self.config.ADA_TOPK, dim=-1) # ada_logits: # (batch_size, seq_len, vocab_size) = # (batch_size, 1, vocab_size) # ada_logits = ada_logits * ada_logits_mask # (batch_size, seq_len, vocab_size) # ada_logits = topk_values # batch_size, seq_len, vocab_size = ada_logits.size() gt_zero_pos = torch.nonzero(ada_logits[:, -1, :] > 0, as_tuple=True)[-1].shape[0] ada_index_slice = topk_indices[:, :, :gt_zero_pos].flatten().to(self.lm_head.weight.device) # equivalent to `sigmoid(ada_logits) > 0.5` # union_ada_index_slice = torch.unique(ada_index_slice).to(self.lm_head.weight.device) # torch_size([union_size]) sliced_lm_head_weight = self.lm_head.weight[ada_index_slice, :].contiguous().to(hidden_states.device) # torch.Size([union_size, hidden_size]) lm_logits_sliced = hidden_states @ sliced_lm_head_weight.T # (batch_size, seq_len, union_size) return lm_logits_sliced, ada_index_slice def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, # TODO: check the effect of this new variable ) -> Union[Tuple, CausalLMOutputWithPast]: # TODO: How does forward know whether is training or inference? output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] # hidden_states.shape: (batch_size, seq_len, hidden_size) batch_size, seq_len, _ = hidden_states.size() vocab_size = self.lm_head.weight.shape[0] # This activation could be very large during training if vocab_size is large, # but in inference, storing activation is not needed # TINGYUAN self.adavocab_head.A.to(hidden_states.device) self.adavocab_head.B.to(hidden_states.device) ada_logits = self.adavocab_head(hidden_states[:, -1:, :]) # (batch_size, seq_len, vocab_size) # ada_logits = ada_logits.float() self.adavocab_head.A.to("cpu") self.adavocab_head.B.to("cpu") lm_head_logits = None lm_loss, mask_loss, topk_loss = None, None, None loss = None if labels is not None: # For prediction_step, training_step. Not for generation # ------ Only for Training and Eval Loop------ # During Inference, we don't need self.lm_head in GPU memory lm_head_logits = self.lm_head(hidden_states) # (batch_size, seq_len, vocab_size) lm_head_logits = lm_head_logits.float() # ------------------------------- # Supervised Signal of `self.adavocab_head` from two sources: # 1. (Primary) BCEWithLogitsLoss between ada_logits and topk_gt_mask (distillation signal) # 2. CrossEntropyLoss between ada_logits and labels with constraint (from ground truth labels) if self.training: # training_step # Loss from the second source # Shift so that tokens < n predict n shift_logits = ada_logits[..., :-1, :].contiguous() # (batch_size, seq_len - 1, vocab_size) shift_labels = labels[..., 1:].contiguous() # (batch_size, seq_len - 1) # Flatten the tokens loss_fct = CrossEntropyLoss() # CE loss includes the softmax function shift_logits = shift_logits.view(-1, self.config.vocab_size) # (batch_size * (seq_len - 1), vocab_size) shift_labels = shift_labels.view(-1) # (batch_size * seq_len) shift_labels = shift_labels.to(shift_logits.device) lm_loss = loss_fct(shift_logits, shift_labels) else: # prediction_step _, lm_loss = self.pred_with_sliced_lm_head(ada_logits, hidden_states, input_ids, labels, min_logit=-100) # Loss from the first source ada_logits_flat = ada_logits.view(-1, self.config.vocab_size) # (batch_size * seq_len, vocab_size) ada_probs = torch.sigmoid(ada_logits_flat) # (batch_size * seq_len, vocab_size) topk_gt_mask = self.topk_mask(lm_head_logits) # (batch_size, seq_len, vocab_size) # TODO: Add weights from lm_head_logits topk_gt_mask = topk_gt_mask.view(-1, self.config.vocab_size) # (batch_size * seq_len, vocab_size) mask_loss_fct = BCEWithLogitsLoss() # BCE Loss including the sigmoid function mask_loss = mask_loss_fct(ada_logits_flat, topk_gt_mask) ada_ones = ada_probs.sum() # scalar # TODO: Handle pad token in no-packing case target_ones = batch_size * seq_len * self.config.ADA_TOPK # scalar target_ones = torch.tensor(target_ones, dtype=torch.float32).to(ada_ones.device) # We need to normalize this loss, make it agnostic to batch size, seq_len, topK topk_loss = F.l1_loss(ada_ones, target_ones) / target_ones loss = self.config.ADA_LOSS_WEIGHT * lm_loss + self.config.ADA_MASK_WEIGHT * mask_loss + self.config.ADA_TOPK_WEIGHT * topk_loss else: # For generation with torch.no_grad(): ada_logits, lm_head_logits = self.pred_with_sliced_lm_head_simple(ada_logits, hidden_states[:, -1:, :]) if not return_dict: output = (ada_logits,) + outputs[1:] return (loss,) + output if loss is not None else output return AdaCausalLMOutputWithPast( loss=loss, logits=ada_logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, # Added by AdaVocab lm_head_logits=lm_head_logits if lm_head_logits is not None else None, lm_loss=self.config.ADA_LOSS_WEIGHT * lm_loss if lm_loss is not None else None, mask_loss=self.config.ADA_MASK_WEIGHT * mask_loss if mask_loss is not None else None, topk_loss=self.config.ADA_TOPK_WEIGHT * topk_loss if topk_loss is not None else None, ) def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings # TODO: Add `get` and `set` methods for `adavocab_head` return AdaVocabCausalLM AdaVocabLlamaForCausalLM = create_AdaVocabCausalLM(LlamaForCausalLM) AdaVocabGemmaforCausalLM = create_AdaVocabCausalLM(GemmaForCausalLM) AdaVocabQwen2ForCausalLM = create_AdaVocabCausalLM(Qwen2ForCausalLM)