|
from transformers import TrainerCallback, Trainer |
|
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM |
|
from peft import PeftModel |
|
from datasets import Dataset |
|
from transformers.utils import is_sagemaker_mp_enabled, is_sagemaker_dp_enabled |
|
from typing import Any, Dict, Union, Optional, Tuple |
|
from torch.nn import MSELoss |
|
from transformers.utils import is_flash_attn_2_available, logging |
|
import inspect |
|
import warnings |
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import time |
|
import os |
|
import copy |
|
|
|
from transformers.models.mistral.modeling_mistral import ( |
|
MistralMLP, |
|
MistralAttention, |
|
MistralModel, |
|
MistralDecoderLayer, |
|
MistralConfig, |
|
MISTRAL_ATTENTION_CLASSES, |
|
MistralRMSNorm, |
|
MistralForCausalLM, |
|
MistralFlashAttention2, |
|
) |
|
from experiments.models.sparse_mistral.svd_router import ( |
|
low_rank_approximation, |
|
SparsePredictor, |
|
) |
|
from utils.utils import ( |
|
print_size_of_model, |
|
is_running_deepspeed, |
|
is_mainprocess, |
|
get_datetime, |
|
ds_print, |
|
) |
|
|
|
if is_flash_attn_2_available(): |
|
from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input |
|
|
|
_flash_supports_window_size = "window_size" in list( |
|
inspect.signature(flash_attn_func).parameters |
|
) |
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class SparseSFTTTrainer(SFTTrainer): |
|
def __init__(self, *args, **kwargs): |
|
self.regularization_coefficient = kwargs.pop("regularization_coefficient", 10) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", False) |
|
self.use_spm_loss = False |
|
self.freeze_original_weights = False |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 positive activation" |
|
) |
|
assert self.regularization_type in [ |
|
"L2 activation", |
|
"L1 positive activation", |
|
], f"Invalid regularization type: {self.regularization_type}" |
|
self.sparse_layers = [] |
|
self.sparse_decoder_layers = [] |
|
super(SparseSFTTTrainer, self).__init__(*args, **kwargs) |
|
|
|
def initialize_sparse_silu_layers(self, model): |
|
self.sparse_layers = [ |
|
m for m in model.modules() if isinstance(m, MistralSparseSiluMLP) |
|
] |
|
|
|
def initialize_sparse_decoder_layers(self, model): |
|
self.sparse_decoder_layers = [ |
|
m for m in model.modules() if isinstance(m, SparseMistralDecoderLayer) |
|
] |
|
|
|
def training_step( |
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] |
|
) -> torch.Tensor: |
|
""" |
|
        Override Hugging Face's `training_step` to add a regularization term.
        The regularization term is computed from intermediate activations, which are freed after `backward()`,
        so `retain_graph=True` must be passed to `backward` when those values are needed for an extra backward pass.
|
""" |
|
model.train() |
|
inputs = self._prepare_inputs(inputs) |
|
|
|
with self.compute_loss_context_manager(): |
|
loss = self.compute_loss(model, inputs) |
|
|
|
if self.args.n_gpu > 1: |
|
loss = loss.mean() |
|
if not self.freeze_original_weights: |
|
if loss is not None: |
|
self.accelerator.backward(loss, retain_graph=False) |
|
|
|
if self.use_sparse_regularization: |
|
regularization_loss = self.compute_regularization(model) |
|
if self.args.n_gpu > 1: |
|
regularization_loss = regularization_loss.mean() |
|
if regularization_loss is not None: |
|
self.accelerator.backward(regularization_loss, retain_graph=True) |
|
loss += regularization_loss |
|
|
|
if self.use_spm_loss: |
|
spm_loss = self.compute_spm_loss(model) |
|
if self.args.n_gpu > 1: |
|
spm_loss = spm_loss.mean() |
|
if spm_loss is not None: |
|
self.accelerator.backward(spm_loss, retain_graph=False) |
|
loss += spm_loss |
|
|
|
return loss.detach() / self.args.gradient_accumulation_steps |
|
|
|
def compute_regularization(self, model): |
|
""" |
|
Compute a sparse regularization loss for SiLU |
|
""" |
|
loss = 0 |
|
if len(self.sparse_layers) == 0: |
|
self.initialize_sparse_silu_layers(model) |
|
num_layers = len(self.sparse_layers) |
|
|
|
for module in self.sparse_layers: |
|
if module.activation_norm is not None: |
|
loss += module.activation_norm |
|
|
|
loss /= num_layers |
|
loss *= self.regularization_coefficient |
|
|
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Negative relularizer loss: ", loss.item()) |
|
return loss |
|
|
|
def compute_spm_loss(self, model): |
|
loss = 0 |
|
if len(self.sparse_decoder_layers) == 0: |
|
self.initialize_sparse_decoder_layers(model) |
|
for module in self.sparse_decoder_layers: |
|
            if module.distill_loss is not None:
|
loss += module.distill_loss |
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Sparse Predictor Distillation loss: ", loss.item()) |
|
return loss |
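# Usage sketch (illustrative, not part of the library API): both SparseSFTTTrainer and
# SparseTrainer below accept the extra keyword arguments popped in __init__, e.g.
#
#     trainer = SparseTrainer(
#         model=model,
#         args=training_args,
#         train_dataset=train_dataset,
#         use_sparse_regularization=True,
#         regularization_coefficient=10,
#         regularization_type="L1 positive activation",
#     )
#
# The regularization term is read from each MistralSparseSiluMLP's `activation_norm`,
# so the model must already contain sparse MLP layers (see apply_mistral_sparse_silu_mlp
# below) for the extra loss to be non-zero. `model`, `training_args`, and `train_dataset`
# are placeholders for the usual Trainer inputs.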
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparseTrainer(Trainer): |
|
def __init__(self, *args, **kwargs): |
|
self.regularization_coefficient = kwargs.pop("regularization_coefficient", 10) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", False) |
|
self.use_spm_loss = False |
|
self.freeze_original_weights = False |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 positive activation" |
|
) |
|
assert self.regularization_type in [ |
|
"L2 activation", |
|
"L1 positive activation", |
|
], f"Invalid regularization type: {self.regularization_type}" |
|
self.sparse_layers = [] |
|
self.sparse_decoder_layers = [] |
|
super(SparseTrainer, self).__init__(*args, **kwargs) |
|
|
|
def initialize_sparse_silu_layers(self, model): |
|
self.sparse_layers = [ |
|
m for m in model.modules() if isinstance(m, MistralSparseSiluMLP) |
|
] |
|
|
|
def initialize_sparse_decoder_layers(self, model): |
|
self.sparse_decoder_layers = [ |
|
m for m in model.modules() if isinstance(m, SparseMistralDecoderLayer) |
|
] |
|
|
|
def training_step( |
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] |
|
) -> torch.Tensor: |
|
""" |
|
        Override Hugging Face's `training_step` to add a regularization term.
        The regularization term is computed from intermediate activations, which are freed after `backward()`,
        so `retain_graph=True` must be passed to `backward` when those values are needed for an extra backward pass.
|
""" |
|
model.train() |
|
inputs = self._prepare_inputs(inputs) |
|
|
|
with self.compute_loss_context_manager(): |
|
loss = self.compute_loss(model, inputs) |
|
|
|
if self.args.n_gpu > 1: |
|
loss = loss.mean() |
|
if not self.freeze_original_weights: |
|
if loss is not None: |
|
self.accelerator.backward(loss, retain_graph=False) |
|
|
|
if self.use_sparse_regularization: |
|
regularization_loss = self.compute_regularization(model) |
|
if self.args.n_gpu > 1: |
|
regularization_loss = regularization_loss.mean() |
|
if regularization_loss is not None: |
|
self.accelerator.backward(regularization_loss, retain_graph=True) |
|
loss += regularization_loss |
|
|
|
if self.use_spm_loss: |
|
spm_loss = self.compute_spm_loss(model) |
|
if self.args.n_gpu > 1: |
|
spm_loss = spm_loss.mean() |
|
if spm_loss is not None: |
|
self.accelerator.backward(spm_loss, retain_graph=False) |
|
loss += spm_loss |
|
|
|
return loss.detach() / self.args.gradient_accumulation_steps |
|
|
|
def compute_regularization(self, model): |
|
""" |
|
Compute a sparse regularization loss for SiLU |
|
""" |
|
loss = 0 |
|
if len(self.sparse_layers) == 0: |
|
self.initialize_sparse_silu_layers(model) |
|
num_layers = len(self.sparse_layers) |
|
|
|
for module in self.sparse_layers: |
|
if module.activation_norm is not None: |
|
loss += module.activation_norm |
|
|
|
loss /= num_layers |
|
loss *= self.regularization_coefficient |
|
|
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Negative relularizer loss: ", loss.item()) |
|
return loss |
|
|
|
def compute_spm_loss(self, model): |
|
loss = 0 |
|
if len(self.sparse_decoder_layers) == 0: |
|
self.initialize_sparse_decoder_layers(model) |
|
for module in self.sparse_decoder_layers: |
|
            if module.distill_loss is not None:
|
loss += module.distill_loss |
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Sparse Predictor Distillation loss: ", loss.item()) |
|
return loss |
|
|
|
|
|
class SparseSiLU(nn.SiLU): |
|
def __init__(self, threshold): |
|
super(SparseSiLU, self).__init__() |
|
self.threshold = threshold |
|
self.m = nn.Threshold(self.threshold, 0) |
|
|
|
def set_new_threshold(self, threshold): |
|
self.threshold = threshold |
|
self.m = nn.Threshold(threshold, 0) |
|
|
|
def forward(self, x): |
|
act = super(SparseSiLU, self).forward(x) |
|
return self.m(act) - self.m(-act) |
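# Note: `self.m(act) - self.m(-act)` keeps SiLU outputs whose magnitude exceeds the
# threshold (both positive and negative) and zeroes the rest. A tiny illustrative check,
# assuming threshold=0.5:
#
#     act_fn = SparseSiLU(threshold=0.5)
#     x = torch.tensor([-2.0, -0.2, 0.1, 2.0])
#     act_fn(x)  # SiLU outputs with |value| <= 0.5 are replaced by 0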
|
|
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): |
|
"""Applies Rotary Position Embedding to the query and key tensors. |
|
|
|
Args: |
|
q (`torch.Tensor`): The query tensor. |
|
k (`torch.Tensor`): The key tensor. |
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
position_ids (`torch.Tensor`): |
|
The position indices of the tokens corresponding to the query and key tensors. For example, this can be |
|
used to pass offsetted position ids when working with a KV-cache. |
|
unsqueeze_dim (`int`, *optional*, defaults to 1): |
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
|
Returns: |
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
|
""" |
|
cos = cos[position_ids].unsqueeze(unsqueeze_dim) |
|
sin = sin[position_ids].unsqueeze(unsqueeze_dim) |
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
return q_embed, k_embed |
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
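# Shape example (illustrative): with 8 key/value heads and n_rep=4, repeat_kv expands a
# (batch, 8, seq_len, head_dim) tensor to (batch, 32, seq_len, head_dim) so grouped-query
# key/value states line up with the 32 query heads.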
|
|
|
|
|
def _get_unpad_data(attention_mask): |
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
|
max_seqlen_in_batch = seqlens_in_batch.max().item() |
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
|
return ( |
|
indices, |
|
cu_seqlens, |
|
max_seqlen_in_batch, |
|
) |
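# _get_unpad_data summarizes a (batch, seq_len) padding mask for the varlen flash-attn
# kernels: `indices` are the flat positions of non-padding tokens, `cu_seqlens` are the
# cumulative sequence lengths (prefixed with 0) expected by flash_attn_varlen_func, and
# `max_seqlen_in_batch` is the longest unpadded sequence in the batch.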
|
|
|
|
|
class SparseMistralFlashAttention(MistralFlashAttention2): |
|
""" |
|
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer |
|
and "Generating Long Sequences with Sparse Transformers". |
|
""" |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.counts = 0 |
|
self.pre_attn_sparsity = 0 |
|
self.visit_counts = 0 |
|
self.is_stats = False |
|
self.cutoff_value = 0 |
|
|
|
def activate_stats(self): |
|
self.is_stats = True |
|
self.visit_counts = 0 |
|
self.pre_attn_sparsity = 0 |
|
self.cutoff_value = 0 |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
): |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
|
|
|
|
attention_mask = kwargs.pop("padding_mask") |
|
bsz, q_len, _ = hidden_states.size() |
|
mask = abs(hidden_states - hidden_states.mean()) < 0.6 * hidden_states.std() |
|
hidden_states[mask] = 0 |
|
self.counts += 1 |
|
|
|
if self.is_stats: |
|
self.pre_attn_sparsity = ( |
|
self.pre_attn_sparsity * self.visit_counts + (hidden_states == 0).float().mean() |
|
) / (self.visit_counts + 1) |
|
self.cutoff_value += hidden_states.std() |
|
self.visit_counts += 1 |
|
|
|
if 10 <= self.counts <= 11: |
|
print(f"Attention {self.layer_idx}: ", (hidden_states == 0).float().mean()) |
|
print( |
|
mask.shape, |
|
) |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = query_states.view( |
|
bsz, q_len, self.num_heads, self.head_dim |
|
).transpose(1, 2) |
|
key_states = key_states.view( |
|
bsz, q_len, self.num_key_value_heads, self.head_dim |
|
).transpose(1, 2) |
|
value_states = value_states.view( |
|
bsz, q_len, self.num_key_value_heads, self.head_dim |
|
).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
if self.layer_idx is None: |
|
raise ValueError( |
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
|
"with a layer index." |
|
) |
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
|
|
|
|
|
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 |
|
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) |
|
|
|
query_states, key_states = apply_rotary_pos_emb( |
|
query_states, key_states, cos, sin, position_ids |
|
) |
|
|
|
use_sliding_windows = ( |
|
_flash_supports_window_size |
|
and getattr(self.config, "sliding_window", None) is not None |
|
and kv_seq_len > self.config.sliding_window |
|
) |
|
|
|
if not _flash_supports_window_size: |
|
logger.warning_once( |
|
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation" |
|
" make sure to upgrade flash-attn library." |
|
) |
|
|
|
if past_key_value is not None: |
|
|
|
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 |
|
if ( |
|
getattr(self.config, "sliding_window", None) is not None |
|
and kv_seq_len > self.config.sliding_window |
|
and cache_has_contents |
|
): |
|
slicing_tokens = 1 - self.config.sliding_window |
|
|
|
past_key = past_key_value[self.layer_idx][0] |
|
past_value = past_key_value[self.layer_idx][1] |
|
|
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous() |
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous() |
|
|
|
if past_key.shape[-2] != self.config.sliding_window - 1: |
|
raise ValueError( |
|
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" |
|
f" {past_key.shape}" |
|
) |
|
|
|
if attention_mask is not None: |
|
attention_mask = attention_mask[:, slicing_tokens:] |
|
attention_mask = torch.cat( |
|
[attention_mask, torch.ones_like(attention_mask[:, -1:])], |
|
dim=-1, |
|
) |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos} |
|
key_states, value_states = past_key_value.update( |
|
key_states, value_states, self.layer_idx, cache_kwargs |
|
) |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
dropout_rate = 0.0 if not self.training else self.attention_dropout |
|
|
|
|
|
|
|
|
|
input_dtype = query_states.dtype |
|
if input_dtype == torch.float32: |
|
if torch.is_autocast_enabled(): |
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
|
elif hasattr(self.config, "_pre_quantization_dtype"): |
|
target_dtype = self.config._pre_quantization_dtype |
|
else: |
|
target_dtype = self.q_proj.weight.dtype |
|
|
|
logger.warning_once( |
|
f"The input hidden states seems to be silently casted in float32, this might be related to" |
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" |
|
f" {target_dtype}." |
|
) |
|
|
|
query_states = query_states.to(target_dtype) |
|
key_states = key_states.to(target_dtype) |
|
value_states = value_states.to(target_dtype) |
|
|
|
|
|
query_states = query_states.transpose(1, 2) |
|
key_states = key_states.transpose(1, 2) |
|
value_states = value_states.transpose(1, 2) |
|
|
|
attn_output = self._flash_attention_forward( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
q_len, |
|
dropout=dropout_rate, |
|
use_sliding_windows=use_sliding_windows, |
|
) |
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
def _flash_attention_forward( |
|
self, |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
query_length, |
|
dropout=0.0, |
|
softmax_scale=None, |
|
use_sliding_windows=False, |
|
): |
|
""" |
|
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores, and finally pad the attention scores back.
|
|
|
Args: |
|
query_states (`torch.Tensor`): |
|
Input query states to be passed to Flash Attention API |
|
key_states (`torch.Tensor`): |
|
Input key states to be passed to Flash Attention API |
|
value_states (`torch.Tensor`): |
|
Input value states to be passed to Flash Attention API |
|
attention_mask (`torch.Tensor`): |
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the |
|
position of padding tokens and 1 for the position of non-padding tokens. |
|
dropout (`float`): |
|
Attention dropout |
|
softmax_scale (`float`, *optional*): |
|
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) |
|
use_sliding_windows (`bool`, *optional*): |
|
Whether to activate sliding window attention. |
|
""" |
|
if not self._flash_attn_uses_top_left_mask: |
|
causal = self.is_causal |
|
else: |
|
|
|
causal = self.is_causal and query_length != 1 |
|
|
|
|
|
if attention_mask is not None: |
|
batch_size = query_states.shape[0] |
|
( |
|
query_states, |
|
key_states, |
|
value_states, |
|
indices_q, |
|
cu_seq_lens, |
|
max_seq_lens, |
|
) = self._upad_input( |
|
query_states, key_states, value_states, attention_mask, query_length |
|
) |
|
|
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens |
|
|
|
if not use_sliding_windows: |
|
attn_output_unpad = flash_attn_varlen_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_in_batch_q, |
|
max_seqlen_k=max_seqlen_in_batch_k, |
|
dropout_p=dropout, |
|
softmax_scale=softmax_scale, |
|
causal=causal, |
|
) |
|
else: |
|
attn_output_unpad = flash_attn_varlen_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_k=cu_seqlens_k, |
|
max_seqlen_q=max_seqlen_in_batch_q, |
|
max_seqlen_k=max_seqlen_in_batch_k, |
|
dropout_p=dropout, |
|
softmax_scale=softmax_scale, |
|
causal=causal, |
|
window_size=( |
|
self.config.sliding_window, |
|
self.config.sliding_window, |
|
), |
|
) |
|
|
|
attn_output = pad_input( |
|
attn_output_unpad, indices_q, batch_size, query_length |
|
) |
|
else: |
|
if not use_sliding_windows: |
|
attn_output = flash_attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
dropout, |
|
softmax_scale=softmax_scale, |
|
causal=causal, |
|
) |
|
else: |
|
attn_output = flash_attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
dropout, |
|
softmax_scale=softmax_scale, |
|
causal=causal, |
|
window_size=( |
|
self.config.sliding_window, |
|
self.config.sliding_window, |
|
), |
|
) |
|
|
|
return attn_output |
|
|
|
def _upad_input( |
|
self, query_layer, key_layer, value_layer, attention_mask, query_length |
|
): |
|
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape |
|
|
|
|
|
|
|
if kv_seq_len != attention_mask.shape[-1]: |
|
attention_mask_num_tokens = attention_mask.shape[-1] |
|
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] |
|
|
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) |
|
|
|
key_layer = index_first_axis( |
|
key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k |
|
) |
|
value_layer = index_first_axis( |
|
value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k |
|
) |
|
|
|
if query_length == kv_seq_len: |
|
query_layer = index_first_axis( |
|
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), |
|
indices_k, |
|
) |
|
cu_seqlens_q = cu_seqlens_k |
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k |
|
indices_q = indices_k |
|
elif query_length == 1: |
|
max_seqlen_in_batch_q = 1 |
|
cu_seqlens_q = torch.arange( |
|
batch_size + 1, dtype=torch.int32, device=query_layer.device |
|
) |
|
indices_q = cu_seqlens_q[:-1] |
|
query_layer = query_layer.squeeze(1) |
|
else: |
|
|
|
attention_mask = attention_mask[:, -query_length:] |
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( |
|
query_layer, attention_mask |
|
) |
|
|
|
return ( |
|
query_layer, |
|
key_layer, |
|
value_layer, |
|
indices_q, |
|
(cu_seqlens_q, cu_seqlens_k), |
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k), |
|
) |
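# Calibration sketch (illustrative): SparseMistralFlashAttention.forward zeroes hidden
# states that lie within 0.6 standard deviations of their mean before the Q/K/V
# projections, and activate_stats() makes it track the resulting pre-attention sparsity
# and accumulate the hidden-state standard deviation used for the cutoff, e.g.
#
#     for layer in model.model.layers:
#         if isinstance(layer.self_attn, SparseMistralFlashAttention):
#             layer.self_attn.activate_stats()
#     # ... run a few evaluation batches, then inspect layer.self_attn.pre_attn_sparsity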
|
|
|
|
|
class SparseMistralAttention(MistralAttention): |
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
self.counts = 0 |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
bsz, q_len, _ = hidden_states.size() |
|
mask = abs(hidden_states - hidden_states.mean()) < 0.6 * hidden_states.std() |
|
hidden_states[mask] = 0 |
|
self.counts += 1 |
|
if 10 <= self.counts <= 11: |
|
print(f"Attention {self.layer_idx}: ", (hidden_states == 0).float().mean()) |
|
self.counts += 1 |
|
|
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
query_states = query_states.view( |
|
bsz, q_len, self.num_heads, self.head_dim |
|
).transpose(1, 2) |
|
key_states = key_states.view( |
|
bsz, q_len, self.num_key_value_heads, self.head_dim |
|
).transpose(1, 2) |
|
value_states = value_states.view( |
|
bsz, q_len, self.num_key_value_heads, self.head_dim |
|
).transpose(1, 2) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
if past_key_value is not None: |
|
if self.layer_idx is None: |
|
raise ValueError( |
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " |
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " |
|
"with a layer index." |
|
) |
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
|
query_states, key_states = apply_rotary_pos_emb( |
|
query_states, key_states, cos, sin, position_ids |
|
) |
|
|
|
if past_key_value is not None: |
|
cache_kwargs = {"sin": sin, "cos": cos} |
|
key_states, value_states = past_key_value.update( |
|
key_states, value_states, self.layer_idx, cache_kwargs |
|
) |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
attn_weights = torch.matmul( |
|
query_states, key_states.transpose(2, 3) |
|
) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): |
|
raise ValueError( |
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" |
|
) |
|
|
|
attn_weights = attn_weights + attention_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax( |
|
attn_weights, dim=-1, dtype=torch.float32 |
|
).to(query_states.dtype) |
|
attn_weights = nn.functional.dropout( |
|
attn_weights, p=self.attention_dropout, training=self.training |
|
) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
|
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
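# SparseMistralAttention is the eager (non flash-attn) counterpart of
# SparseMistralFlashAttention above: it applies the same 0.6-standard-deviation cutoff
# to the hidden states before the Q/K/V projections and then computes standard softmax
# attention instead of calling the flash-attn kernels.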
|
|
|
|
|
class MistralSparseSiluMLP(MistralMLP): |
|
def __init__(self, config, *args, **kwargs): |
|
super().__init__(config) |
|
self.swish_outputs = None |
|
self.relu = nn.ReLU() |
|
self.resilu = nn.Sequential(nn.SiLU()) |
|
|
|
self.kill_sparse_swish_outputs = False |
|
self.dead_percentage = 0 |
|
self.is_stats = False |
|
self.visit_counts = 0 |
|
self.pre_mlp_sparsity = 0 |
|
|
|
|
|
self.dead_threshold = kwargs.pop("dead_threshold", 0) |
|
self.pre_mlp_threshold = kwargs.pop("pre_mlp_threshold", 0) |
|
self.pre_mlp_dead_threshold = kwargs.pop("pre_mlp_dead_threshold", 0) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", True) |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 regularization" |
|
) |
|
self.regularization_threshold = kwargs.pop("regularization_threshold", 0.5) |
|
self.use_relu = kwargs.pop("use_relu", False) |
|
self.use_resilu = kwargs.pop("use_resilu", False) |
|
self.activation_norm = None |
|
|
|
|
|
self.is_collect_histogram = False |
|
num_bins = 1000 |
|
self.histogram_bins = torch.linspace(-1, 1, num_bins - 2) |
|
self.histogram_bins = torch.cat( |
|
[torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])] |
|
) |
|
self.pre_mlp_hist_counts = torch.zeros(num_bins - 1) |
|
self.pre_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.post_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.t = 0 |
|
self.count = 0 |
|
self.agg_sparsity = 0 |
|
|
|
|
|
self.sparse_act_fn = SparseSiLU(threshold=self.dead_threshold) |
|
|
|
def activate_stats(self, is_collect_histogram: bool = True): |
|
self.is_stats = True |
|
self.dead_percentage = 0 |
|
self.visit_counts = 0 |
|
self.is_collect_histogram = is_collect_histogram |
|
self.histogram_counts = torch.zeros(2000) |
|
|
|
def deactivate_stats(self): |
|
self.is_stats = False |
|
|
|
def collect_stats( |
|
self, |
|
pre_mlp, |
|
pre_activation, |
|
post_activation, |
|
): |
|
start_time = time.time() |
|
        pre_mlp = pre_mlp.float().cpu().detach()
        pre_activation = pre_activation.float().cpu().detach()
        post_activation = post_activation.float().cpu().detach()

        self.pre_mlp_hist_counts += torch.histogram(pre_mlp, bins=self.histogram_bins)[0]
|
self.pre_act_hist_counts += torch.histogram( |
|
pre_activation, bins=self.histogram_bins |
|
)[0] |
|
self.post_act_hist_counts += torch.histogram( |
|
torch.abs(post_activation), bins=self.histogram_bins |
|
)[0] |
|
self.t += time.time() - start_time |
|
if self.visit_counts % 30 == 0: |
|
print(f"Time taken to collect stats: {self.t}s.") |
|
|
|
def forward( |
|
self, |
|
x, |
|
sp_mask: torch.tensor = None, |
|
): |
|
""" |
|
If kill_sparse_swish_outputs is set to False, this layer functions exactly like a normal MLP layer. |
|
""" |
|
        if sp_mask is not None:
|
return self.down_proj( |
|
self.sparse_act_fn(self.gate_proj(x) * sp_mask) * self.up_proj(x) |
|
) |
|
|
|
elif self.use_relu or self.use_resilu: |
|
if self.use_relu: |
|
post_act = self.relu(self.gate_proj(x)) |
|
else: |
|
post_act = self.resilu(self.gate_proj(x)) |
|
self.count += 1 |
|
if self.count <= 1: |
|
print("USING RELU or ReSiLU!!!!") |
|
|
|
if self.is_stats: |
|
dead_neurons = post_act == 0 |
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
return self.down_proj(post_act * self.up_proj(x)) |
|
|
|
else: |
|
self.count += 1 |
|
|
|
if self.count <= 1: |
|
print("USING SparseSILU!!!!") |
|
|
|
|
|
pre_act = self.gate_proj(x) |
|
post_act = self.act_fn(pre_act) |
|
if self.kill_sparse_swish_outputs: |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
|
|
|
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
if self.is_stats: |
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.pre_mlp_sparsity = ( |
|
self.pre_mlp_sparsity * self.visit_counts |
|
+ (x == 0).float().mean() |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
self.a = dead_percentage |
|
|
|
|
|
|
|
|
|
if ( |
|
self.is_collect_histogram |
|
and pre_act.eq(0).float().mean() < 0.99 |
|
): |
|
self.collect_stats(x, pre_act, post_act) |
|
|
|
post_act[dead_neurons] = 0 |
|
if 10 <= self.count <= 11: |
|
print( |
|
f"sparsity: {dead_percentage}/ pre-activation sparsity: {(x==0).float().mean()}" |
|
) |
|
|
|
out = self.down_proj(post_act * self.up_proj(x)) |
|
if self.use_sparse_regularization: |
|
if self.regularization_type == "L1 regularization": |
|
self.activation_norm = torch.abs(post_act)[ |
|
post_act < self.regularization_threshold |
|
].mean() |
|
elif self.regularization_type == "L2 regularization": |
|
self.activation_norm = torch.sqrt( |
|
torch.square(post_act)[post_act < self.regularization_threshold] |
|
).mean() |
|
|
|
return out |
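# Thresholding sketch (illustrative): once `kill_sparse_swish_outputs` is enabled and a
# `dead_threshold` has been chosen, gate activations with |SiLU(gate_proj(x))| <=
# dead_threshold are zeroed before being multiplied with up_proj(x), e.g.
#
#     mlp.kill_sparse_swish_outputs = True
#     mlp.dead_threshold = 0.1
#     mlp.sparse_act_fn.set_new_threshold(mlp.dead_threshold)
#     out = mlp(x)  # neurons at or below the threshold contribute exactly zero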
|
|
|
|
|
class SparseMistralDecoderLayer(MistralDecoderLayer): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
layer_idx: int, |
|
decoder_layer: MistralDecoderLayer, |
|
init_svd: bool = True, |
|
*args, |
|
**kwargs, |
|
): |
|
assert isinstance( |
|
decoder_layer.mlp, MistralSparseSiluMLP |
|
), f"{type(decoder_layer.mlp)} should MistralSparseSiluMLP." |
|
|
|
super().__init__(config, layer_idx) |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
|
|
self.init_svd = init_svd |
|
self.self_attn = decoder_layer.self_attn |
|
|
|
self.mlp = decoder_layer.mlp |
|
self.input_layernorm = decoder_layer.input_layernorm |
|
self.post_attention_layernorm = decoder_layer.post_attention_layernorm |
|
|
|
|
|
self.low_rank = kwargs.pop("low_rank", 64) |
|
self.sparse_act_func = decoder_layer.mlp.sparse_act_fn |
|
|
|
print( |
|
f"Setting {layer_idx}th mlp layer's sparse predictor... svd init: {init_svd}" |
|
) |
|
self.sp_mlp = low_rank_approximation( |
|
decoder_layer.mlp.gate_proj, |
|
act_func=self.sparse_act_func, |
|
init_svd=init_svd, |
|
) |
|
self.use_async = kwargs.pop("use_async", False) |
|
self.use_sparse_predictor = False |
|
self.distill_loss = None |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
**kwargs, |
|
) -> Tuple[ |
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] |
|
]: |
|
print("hidden_states shape: ", hidden_states.shape) |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
|
|
residual = hidden_states |
|
sp_mask = None |
|
|
|
if self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
|
if not self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
|
|
gating_output = self.mlp.sparse_act_fn(self.mlp.gate_proj(hidden_states)) |
|
loss_func = MSELoss() |
|
self.distill_loss = loss_func(sp_mask, gating_output) |
|
|
|
|
|
sp_mask = sp_mask > 0 |
|
|
|
if self.training: |
|
sp_mask = None |
|
|
|
|
|
|
|
hidden_states = self.mlp(hidden_states, sp_mask) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
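# Note on SparseMistralDecoderLayer: `sp_mlp` is a low-rank predictor of the gating
# output; its MSE against the true gate activations is stored in `distill_loss`, which
# compute_spm_loss() in the trainers above sums across layers. During training the
# predicted mask is discarded (sp_mask is reset to None) so the MLP still sees dense
# activations, while at inference the boolean mask `sp_mask > 0` can gate the MLP.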
|
|
|
|
|
class SparseMistralConfig(MistralConfig): |
|
model_type = "sparse_mistral" |
|
|
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class SparseMistralforCausalLM(MistralForCausalLM): |
|
config_class = SparseMistralConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.config = config |
|
if config.use_sparse_model: |
|
self.apply_sparse_mlp() |
|
if config.thresholds is not None: |
|
for idx, m in enumerate(self.model.layers): |
|
if isinstance(m.mlp, MistralSparseSiluMLP): |
|
m.mlp.dead_threshold = config.thresholds[idx] |
|
m.mlp.sparse_act_fn.set_new_threshold(m.mlp.dead_threshold) |
|
m.mlp.kill_sparse_swish_outputs = True |
|
m.mlp.use_relu = getattr(config, "use_relu", False) |
|
m.mlp.use_resilu = getattr(config, "use_resilu", False) |
|
if config.use_sparse_predictor: |
|
self.apply_sparse_predictor(init_svd=config.init_svd) |
|
|
|
def apply_sparse_mlp(self): |
|
apply_mistral_sparse_silu_mlp( |
|
self, |
|
config=self.config, |
|
use_sparse_regularization=self.config.use_sparse_regularization, |
|
) |
|
|
|
def apply_sparse_predictor(self, init_svd: bool = True): |
|
apply_mistral_sparse_decoder_layer(self, config=self.config, init_svd=init_svd) |
|
|
|
|
|
class GracefulRegularizationScheduler(TrainerCallback): |
|
def __init__( |
|
self, |
|
num_warmup_steps=40, |
|
is_enabled: bool = False, |
|
model_name: str = "mistral", |
|
test_dataset: Dataset = None, |
|
targeted_sparsity: float = 0.5, |
|
keep_regularization_with_kill: bool = False, |
|
): |
|
"""Scheduler for regularizing the model first before applying the dead threshold. |
|
|
|
:param num_warmup_steps: number of training steps required to reach the dead threshold, defaults to 40 |
|
:param increment_ratio: by how much to increase the dead threshold. |
|
For example, 0.5 means "increase the threshold by 0.5 * desired threshold |
|
""" |
|
self.num_warmup_steps = num_warmup_steps |
|
self.is_enabled = is_enabled |
|
self.model_name = model_name |
|
self.test_dataset = test_dataset |
|
self.targeted_sparsity = targeted_sparsity |
|
self.keep_regularization_with_kill = keep_regularization_with_kill |
|
self.act_hist_path = ( |
|
f"/matx/u/vxbrando/histograms/warm_up_reg_{targeted_sparsity}/act_hist.pt" |
|
) |
|
if self.is_enabled: |
|
print("GracefulRegularizationScheduler is enabled.") |
|
self.trainer = None |
|
|
|
def set_trainer(self, trainer): |
|
self.trainer = trainer |
|
|
|
def on_step_end(self, args, state, control, **kwargs): |
|
if not self.is_enabled: |
|
return |
|
|
|
model = kwargs["model"] |
|
if isinstance(model, PeftModel): |
|
base_model = model.get_base_model() |
|
else: |
|
base_model = model |
|
|
|
if state.global_step == 1: |
|
ds_print("Setting an initial reg threshold to 0.1") |
|
set_regularization_threshold(base_model, 0.1) |
|
|
|
|
|
if state.global_step == self.num_warmup_steps: |
|
activate_stats(base_model) |
|
enable_sparse_silu(base_model) |
|
self.trainer.evaluate() |
|
save_act_hist(base_model, self.act_hist_path) |
|
set_sparse_threshold(base_model, self.targeted_sparsity, True) |
|
deactivate_stats(base_model) |
|
self.trainer.use_sparse_regularization = self.keep_regularization_with_kill |
|
|
|
print_dead_neuron_stats(model.get_base_model()) |
|
|
|
if state.global_step % 2000 == 0: |
|
if is_mainprocess(): |
|
ds_print( |
|
f"Saving to /scr/lukeai/{self.model_name}_{state.global_step}.pt", |
|
) |
|
torch.save( |
|
model.state_dict(), |
|
f"/scr/lukeai/{self.model_name}_{state.global_step}.pt", |
|
) |
|
|
|
|
|
class GradualSparsificationScheduler(TrainerCallback): |
|
def __init__( |
|
self, |
|
num_warmup_steps=40, |
|
increment_ratio=0.5, |
|
is_enabled: bool = False, |
|
model_name: str = "mistral", |
|
): |
|
"""Scheduler for gradually increasing a dead threshold until it reaches the desired threshold. |
|
|
|
:param num_warmup_steps: number of training steps required to reach the dead threshold, defaults to 40 |
|
:param increment_ratio: by how much to increase the dead threshold. |
|
            For example, 0.5 means "increase the threshold by 0.5 * desired threshold".
|
""" |
|
self.num_warmup_steps = num_warmup_steps |
|
self.increment_ratio = increment_ratio |
|
self.step_size = int(num_warmup_steps * increment_ratio) |
|
self.is_enabled = is_enabled |
|
self.model_name = model_name |
|
|
|
def on_step_end(self, args, state, control, **kwargs): |
|
model = kwargs["model"] |
|
|
|
if not self.is_enabled: |
|
if state.global_step <= 10: |
|
for module in model.modules(): |
|
if isinstance(module, MistralSparseSiluMLP): |
|
module.current_dead_threshold = module.dead_threshold |
|
return |
|
|
|
current_dead_threshold = 0 |
|
desired_dead_threshold = 0 |
|
|
|
if is_mainprocess(): |
|
ds_print(state.global_step) |
|
|
|
if state.global_step % self.step_size == 2: |
|
for module in model.modules(): |
|
if isinstance(module, MistralSparseSiluMLP): |
|
desired_dead_threshold = copy.deepcopy(module.dead_threshold) |
|
current_dead_threshold = module.current_dead_threshold |
|
current_dead_threshold += ( |
|
self.increment_ratio * desired_dead_threshold |
|
) |
|
module.current_dead_threshold = min( |
|
desired_dead_threshold, current_dead_threshold |
|
) |
|
|
|
        if is_running_deepspeed() and is_mainprocess():
|
ds_print( |
|
state.global_step, |
|
current_dead_threshold, |
|
desired_dead_threshold, |
|
) |
|
|
|
if state.global_step % 2000 == 0: |
|
            if is_running_deepspeed() and is_mainprocess():
|
ds_print( |
|
f"Saving to /matx/u/lukeai/{self.model_name}_{state.global_step - 2}.pt", |
|
) |
|
torch.save( |
|
model.state_dict(), |
|
f"/matx/u/lukeai/{self.model_name}_{state.global_step - 2}.pt", |
|
) |
|
|
|
|
|
def get_sparse_mistral_config( |
|
config: MistralConfig, |
|
use_sparse_model=False, |
|
use_sparse_predictor=False, |
|
use_sparse_regularization=False, |
|
thresholds=None, |
|
): |
|
new_config = SparseMistralConfig() |
|
new_config.__dict__.update(config.__dict__) |
|
config = new_config |
|
config.use_sparse_model = use_sparse_model |
|
config.use_sparse_predictor = use_sparse_predictor |
|
config.use_sparse_regularization = use_sparse_regularization |
|
config.thresholds = thresholds |
|
|
|
return config |
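# Usage sketch (illustrative): convert a vanilla MistralConfig before instantiating the
# sparse model. `base_config` and `model_name` are placeholders for a config loaded via
# MistralConfig.from_pretrained and the corresponding checkpoint name:
#
#     sparse_config = get_sparse_mistral_config(base_config, use_sparse_model=True)
#     model = SparseMistralforCausalLM.from_pretrained(model_name, config=sparse_config)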
|
|
|
|
|
def apply_mistral_sparse_silu_mlp( |
|
model, config, use_sparse_regularization: bool = False, use_flash_attn: bool = False |
|
): |
|
for layer in model.model.layers: |
|
|
|
|
|
|
|
original_mlp = layer.mlp |
|
new_mlp = MistralSparseSiluMLP( |
|
config, use_sparse_regularization=use_sparse_regularization |
|
) |
|
new_mlp.gate_proj = original_mlp.gate_proj |
|
new_mlp.up_proj = original_mlp.up_proj |
|
new_mlp.down_proj = original_mlp.down_proj |
|
layer.mlp = new_mlp |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_mistral_sparse_attention( |
|
model, |
|
config, |
|
): |
|
    for layer in model.model.layers:
        layer.self_attn = layer.self_attn
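# apply_mistral_sparse_attention is currently a placeholder: the loop leaves each
# layer's existing attention module untouched instead of swapping in
# SparseMistralAttention / SparseMistralFlashAttention.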
|
|
|
|
|
def apply_mistral_sparse_decoder_layer( |
|
model, |
|
config, |
|
init_svd: bool = True, |
|
): |
|
assert isinstance(model.model, MistralModel), "model.model must be a MistralModel." |
|
new_layers = [] |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, MistralSparseSiluMLP): |
|
new_layers.append( |
|
SparseMistralDecoderLayer( |
|
config=config, |
|
layer_idx=layer_idx, |
|
decoder_layer=layer, |
|
init_svd=init_svd, |
|
) |
|
) |
|
print(f"{layer_idx}th mlp layer activation: {layer.mlp.sparse_act_fn}") |
|
else: |
|
new_layers.append(layer) |
|
model.model.layers = nn.ModuleList(new_layers) |
|
|
|
|
|
def enable_sparse_predictor( |
|
model, |
|
): |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer, MistralDecoderLayer): |
|
layer.use_sparse_predictor = True |
|
|
|
|
|
def disable_sparse_predictor( |
|
model, |
|
): |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer, MistralDecoderLayer): |
|
layer.use_sparse_predictor = False |
|
|
|
|
|
def activate_stats(model, is_collect_histogram: bool = True): |
|
for layer in model.model.layers: |
|
if isinstance(layer.mlp, MistralSparseSiluMLP): |
|
layer.mlp.activate_stats(is_collect_histogram=is_collect_histogram) |
|
        if isinstance(layer.self_attn, (SparseMistralAttention, SparseMistralFlashAttention)):
|
layer.self_attn.activate_stats() |
|
|
|
|
|
def deactivate_stats(model): |
|
for layer in model.model.layers: |
|
if isinstance(layer.mlp, MistralSparseSiluMLP): |
|
layer.mlp.deactivate_stats() |
|
|
|
|
|
def enable_sparse_silu(model): |
|
print("Enabling SparseSilu") |
|
for i, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, MistralSparseSiluMLP): |
|
layer.mlp.kill_sparse_swish_outputs = True |
|
|
|
|
|
def print_dead_neuron_stats(model): |
|
total_sparsity = 0 |
|
counts = 0 |
|
for i, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, MistralSparseSiluMLP): |
|
dead_percentage = layer.mlp.dead_percentage * 100 |
|
agg_sparsity = layer.mlp.agg_sparsity * 100 |
|
pre_mlp_sparsity = layer.mlp.pre_mlp_sparsity * 100 |
|
print(f"layer {i} sparsity: {dead_percentage:.3f}%") |
|
print(f"layer {i} agg sparsity: {agg_sparsity:.3f}%") |
|
print(f"layer {i} pre_mlp_sparsity: {pre_mlp_sparsity:.3f}%") |
|
|
|
total_sparsity += dead_percentage |
|
counts += 1 |
|
        if isinstance(layer.self_attn, (SparseMistralAttention, SparseMistralFlashAttention)):
|
print( |
|
f"Attention layer {i} sparsity: {layer.self_attn.pre_attn_sparsity * 100: .3f}%" |
|
) |
|
|
|
print(f"Total sparsity: {total_sparsity/counts: .3f}%") |
|
return total_sparsity / counts |
|
|
|
|
|
def get_sparse_layers(model: MistralModel): |
|
sparse_layers = [ |
|
        m.mlp for m in model.layers if isinstance(m.mlp, MistralSparseSiluMLP)
|
] |
|
return sparse_layers |
|
|
|
|
|
def get_threshold( |
|
bin_edges: torch.tensor, histogram_counts: torch.tensor, sparsity_level: float |
|
): |
|
assert ( |
|
len(bin_edges.shape) == len(histogram_counts.shape) == 1 |
|
), "bin_edges and histogram are expected to be 1-dimensional." |
|
histogram_counts /= histogram_counts.sum() |
|
threshold_idx = torch.searchsorted( |
|
histogram_counts.cumsum(0), sparsity_level, side="right" |
|
) |
|
|
|
return bin_edges[threshold_idx] |
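# Worked example for get_threshold (illustrative): with bin_edges = [-inf, 0.0, 0.1, 0.2, inf]
# and normalized counts [0.1, 0.4, 0.3, 0.2], the cumulative mass [0.1, 0.5, 0.8, 1.0]
# reaches a 0.5 sparsity target at the second bin, and searchsorted(..., side="right")
# therefore returns index 2, so the threshold bin_edges[2] = 0.1 is selected: zeroing
# activations below 0.1 removes roughly half of them.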
|
|
|
|
|
def set_regularization_threshold(model, threshold: float = 0.1): |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
|
): |
|
layer.mlp.regularization_threshold = threshold |
|
|
|
|
|
def set_sparse_threshold( |
|
model, sparsity_level: float, use_relu: bool = False, use_resilu: bool = False |
|
): |
|
assert not (use_relu and use_resilu), "It's not allowed to use both relu and resilu" |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
|
): |
|
if use_relu: |
|
layer.mlp.sparse_act_fn = nn.ReLU() |
|
layer.mlp.use_relu = True |
|
layer.mlp.use_resilu = False |
|
elif use_resilu: |
|
layer.mlp.sparse_act_fn = nn.Sequential(nn.ReLU(), nn.SiLU()) |
|
layer.mlp.use_resilu = True |
|
layer.mlp.use_relu = False |
|
else: |
|
layer.mlp.dead_threshold = get_threshold( |
|
layer.mlp.histogram_bins, |
|
layer.mlp.post_act_hist_counts, |
|
sparsity_level, |
|
) |
|
layer.mlp.sparse_act_fn.set_new_threshold(layer.mlp.dead_threshold) |
|
layer.mlp.regularization_threshold = ( |
|
layer.mlp.dead_threshold * 1.2 |
|
) |
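# Typical calibration flow (illustrative), similar to what GracefulRegularizationScheduler
# does at the end of its warm-up; `trainer` is a placeholder for any loop that pushes
# calibration data through the model:
#
#     activate_stats(model)                 # start accumulating activation histograms
#     trainer.evaluate()                    # run calibration data through the model
#     save_act_hist(model, "act_hist.pt")   # optional: persist the histograms
#     set_sparse_threshold(model, 0.8)      # pick dead thresholds targeting 80% sparsity
#     deactivate_stats(model)
#     enable_sparse_silu(model)             # zero activations below the chosen thresholds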
|
|
|
|
|
def plot_histogram( |
|
bin_edges, |
|
histogram_counts: torch.tensor, |
|
title: str = "Activation Distribution", |
|
fig_dir: str = "figures", |
|
): |
|
plt.bar( |
|
bin_edges[:-1], histogram_counts, width=np.diff(bin_edges), edgecolor="black" |
|
) |
|
plt.title(title) |
|
plt.xlabel("Activation Value") |
|
plt.ylabel("Frequency") |
|
os.makedirs(fig_dir, exist_ok=True) |
|
plt.savefig(f"{fig_dir}/{title}.png") |
|
|
|
plt.clf() |
|
|
|
|
|
def plot_act(model, fig_dir: str = "figures"): |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
|
): |
|
plot_title = f"Layer: {i} Pre-Activation Distribution" |
|
plot_histogram( |
|
layer.mlp.histogram_bins, layer.mlp.pre_act_hist_counts, plot_title |
|
) |
|
|
|
plot_title = f"Layer: {i} Post-Activation Absolute Distribution" |
|
plot_histogram( |
|
layer.mlp.histogram_bins, layer.mlp.post_act_hist_counts, plot_title |
|
) |
|
|
|
|
|
def save_act_hist( |
|
model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
|
): |
|
os.makedirs(os.path.dirname(filename), exist_ok=True) |
|
act_dict = {} |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
|
): |
|
act_dict[i] = ( |
|
layer.mlp.histogram_bins, |
|
|
|
layer.mlp.pre_act_hist_counts, |
|
layer.mlp.post_act_hist_counts, |
|
) |
|
print("Saving activation histograms...\n\n\n") |
|
torch.save(act_dict, filename) |
|
|
|
|
|
def load_act_hist( |
|
model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
|
): |
|
assert os.path.exists( |
|
filename |
|
), f"{filename} does not exist when loading pre/post-activation histogram of SparseMistralSiluMLP." |
|
print("Loading activation histograms...\n\n\n") |
|
|
|
act_dict = torch.load(filename) |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, MistralSparseSiluMLP) and layer.mlp.is_stats |
|
): |
|
( |
|
layer.mlp.histogram_bins, |
|
|
|
layer.mlp.pre_act_hist_counts, |
|
layer.mlp.post_act_hist_counts, |
|
) = act_dict[i] |
|
|
|
|
|
def enable_last_k_modules(model, start_module_idx: int): |
|
assert 32 > start_module_idx >= 0 |
|
new_modules = [] |
|
new_idx = 0 |
|
for idx in range(start_module_idx, len(model.model.original_layers)): |
|
module = model.model.original_layers[idx] |
|
module.layer_idx = new_idx |
|
module.self_attn.layer_idx = new_idx |
|
new_modules.append(module) |
|
new_idx += 1 |
|
print(module.layer_idx) |
|
|
|
model.model.layers = nn.ModuleList(new_modules) |
|
|
|
|
|
def enable_first_k_modules(model, end_module_idx: int): |
|
assert 32 > end_module_idx >= 0 |
|
new_modules = [] |
|
new_idx = 0 |
|
for idx in range(0, end_module_idx + 1): |
|
module = model.model.original_layers[idx] |
|
module.layer_idx = new_idx |
|
module.self_attn.layer_idx = new_idx |
|
new_modules.append(module) |
|
new_idx += 1 |
|
print(module.layer_idx) |
|
|
|
model.model.layers = nn.ModuleList(new_modules) |
|
|