|
from typing import Optional, Tuple |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import MSELoss |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import os |
|
import time |
|
import os |
|
import copy |
|
import warnings |
|
from datasets import Dataset |
|
from peft import PeftModel |
|
from transformers import TrainerCallback |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import time |
|
import os |
|
import copy |
|
from transformers import Trainer |
|
from typing import Any, Dict, Union |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import flash_gemv |
|
import seaborn as sns |
|
|
|
|
|
|
|
|
|
from utils.utils import ( |
|
is_running_deepspeed, |
|
is_mainprocess, |
|
ds_print, |
|
get_model_type, |
|
get_model_type_from_name, |
|
) |
|
from utils.constants import MISTRAL |
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
from transformers.models.mistral.modeling_mistral import ( |
|
MistralMLP, |
|
MistralDecoderLayer, |
|
MistralConfig, |
|
MistralForCausalLM, |
|
MistralModel, |
|
) |
|
from experiments.models.sparse_mistral.svd_router import ( |
|
low_rank_approximation, |
|
) |
|
|
|
|
|
from transformers.models.llama.modeling_llama import ( |
|
LlamaModel, |
|
LlamaMLP, |
|
LlamaDecoderLayer, |
|
LlamaConfig, |
|
LlamaForCausalLM, |
|
) |
|
|
|
|
|
def get_mlp_class(model): |
|
model_type = get_model_type(model) |
|
return MistralSparseSiluMLP if model_type == MISTRAL else LlamaSparseSiluMLP |
|
|
|
|
|
def get_decoder_class(model): |
|
model_type = get_model_type(model) |
|
return ( |
|
SparseMistralDecoderLayer if model_type == MISTRAL else LlamaSparseDecoderLayer |
|
) |
|
|
|
|
|
def get_model_class(model): |
|
model_type = get_model_type(model) |
|
return MistralModel if model_type == MISTRAL else LlamaModel |
|
|
|
|
|
class SparseSiLU(nn.SiLU): |
|
def __init__(self, threshold): |
|
super(SparseSiLU, self).__init__() |
|
self.threshold = threshold |
|
self.m = nn.Threshold(self.threshold, 0) |
|
|
|
def set_new_threshold(self, threshold): |
|
self.threshold = threshold |
|
self.m = nn.Threshold(threshold, 0) |
|
|
|
def forward(self, x): |
|
act = super(SparseSiLU, self).forward(x) |
|
return self.m(act) - self.m(-act) |
|
|
|
|
|
def get_sparse_config( |
|
config: PretrainedConfig, |
|
model_type: str = None, |
|
use_sparse_model=False, |
|
use_sparse_predictor=False, |
|
use_sparse_regularization=False, |
|
use_graceful_regularization=False, |
|
thresholds=None, |
|
): |
|
if model_type == MISTRAL: |
|
new_config = SparseMistralConfig() |
|
else: |
|
new_config = SparseLlamaConfig() |
|
new_config.__dict__.update(config.__dict__) |
|
config = new_config |
|
config.use_sparse_model = use_sparse_model |
|
config.use_sparse_predictor = use_sparse_predictor |
|
config.use_sparse_regularization = use_sparse_regularization |
|
config.use_graceful_regularization = use_graceful_regularization |
|
config.thresholds = thresholds |
|
|
|
return config |
|
|
|
|
|
def apply_sparse_silu_mlp( |
|
model, |
|
config, |
|
use_sparse_regularization: bool = False, |
|
): |
|
SparseMLP = get_mlp_class(model) |
|
for layer in model.model.layers: |
|
original_mlp = layer.mlp |
|
new_mlp = SparseMLP(config, use_sparse_regularization=use_sparse_regularization) |
|
new_mlp.gate_proj = original_mlp.gate_proj |
|
new_mlp.up_proj = original_mlp.up_proj |
|
new_mlp.down_proj = original_mlp.down_proj |
|
layer.mlp = new_mlp |
|
|
|
|
|
def apply_sparse_decoder_layer( |
|
model, |
|
config, |
|
init_svd: bool = True, |
|
): |
|
Model = get_model_type(model) |
|
SparseMLP = get_mlp_class(model) |
|
DecoderLayer = get_decoder_class(model) |
|
|
|
assert isinstance(model.model, Model), "model.model must be a MistralModel." |
|
new_layers = [] |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, SparseMLP): |
|
new_layers.append( |
|
DecoderLayer( |
|
config=config, |
|
layer_idx=layer_idx, |
|
decoder_layer=layer, |
|
init_svd=init_svd, |
|
) |
|
) |
|
print(f"{layer_idx}th mlp layer activation: {layer.mlp.sparse_act_fn}") |
|
else: |
|
new_layers.append(layer) |
|
model.model.layers = nn.ModuleList(new_layers) |
|
|
|
|
|
def enable_sparse_predictor( |
|
model, |
|
): |
|
DecoderLayer = get_decoder_class(model) |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer, DecoderLayer): |
|
layer.use_sparse_predictor = True |
|
|
|
|
|
def disable_sparse_predictor( |
|
model, |
|
): |
|
DecoderLayer = get_decoder_class(model) |
|
for layer_idx, layer in enumerate(model.model.layers): |
|
if isinstance(layer, DecoderLayer): |
|
layer.use_sparse_predictor = False |
|
|
|
|
|
def activate_stats(model, is_collect_histogram: bool = True): |
|
SparseMLP = get_mlp_class(model) |
|
for layer in model.model.layers: |
|
if isinstance(layer.mlp, SparseMLP): |
|
layer.mlp.activate_stats(is_collect_histogram=is_collect_histogram) |
|
|
|
|
|
def deactivate_stats( |
|
model, |
|
): |
|
SparseMLP = get_mlp_class(model) |
|
for layer in model.model.layers: |
|
if isinstance(layer.mlp, SparseMLP): |
|
layer.mlp.deactivate_stats() |
|
|
|
|
|
def enable_sparse_silu(model): |
|
print("Enabling SparseSilu") |
|
SparseMLP = get_mlp_class(model) |
|
for i, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, SparseMLP): |
|
layer.mlp.kill_sparse_swish_outputs = True |
|
|
|
|
|
def disable_sparse_silu(model): |
|
print("Disabling SparseSilu") |
|
SparseMLP = get_mlp_class(model) |
|
for i, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, SparseMLP): |
|
layer.mlp.kill_sparse_swish_outputs = False |
|
|
|
|
|
def print_dead_neuron_stats(model): |
|
SparseMLP = get_mlp_class(model) |
|
total_sparsity = 0 |
|
counts = 0 |
|
for i, layer in enumerate(model.model.layers): |
|
if isinstance(layer.mlp, SparseMLP): |
|
dead_percentage = layer.mlp.dead_percentage * 100 |
|
agg_sparsity = layer.mlp.agg_sparsity * 100 |
|
ds_print(f"layer {i} sparsity: {dead_percentage:.3f}%") |
|
ds_print(f"layer {i} agg sparsity: {agg_sparsity:.3f}%") |
|
total_sparsity += dead_percentage |
|
counts += 1 |
|
|
|
ds_print(f"Total sparsity: {total_sparsity/counts: .3f}%") |
|
return total_sparsity / counts |
|
|
|
|
|
def get_sparse_layers(model): |
|
SparseMLP = get_mlp_class(model) |
|
sparse_layers = [m.mlp for m in model.layers() if isinstance(m.mlp, SparseMLP)] |
|
return sparse_layers |
|
|
|
|
|
def get_threshold( |
|
bin_edges: torch.tensor, histogram_counts: torch.tensor, sparsity_level: float |
|
): |
|
assert ( |
|
len(bin_edges.shape) == len(histogram_counts.shape) == 1 |
|
), "bin_edges and histogram are expected to be 1-dimensional." |
|
histogram_counts /= histogram_counts.sum() |
|
threshold_idx = torch.searchsorted( |
|
histogram_counts.cumsum(0), sparsity_level, side="right" |
|
) |
|
|
|
return bin_edges[threshold_idx] |
|
|
|
|
|
def set_regularization_threshold(model, threshold: float = 0.1): |
|
SparseMLP = get_mlp_class(model) |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, SparseMLP) and layer.mlp.is_stats |
|
): |
|
layer.mlp.regularization_threshold = threshold |
|
|
|
|
|
def set_sparse_threshold(model, sparsity_level: float, use_relu: bool = False): |
|
SparseMLP = get_mlp_class(model) |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, SparseMLP) and layer.mlp.is_stats |
|
): |
|
if use_relu: |
|
layer.mlp.sparse_act_fn = nn.ReLU() |
|
layer.mlp.use_relu = True |
|
else: |
|
layer.mlp.dead_threshold = get_threshold( |
|
layer.mlp.histogram_bins, |
|
layer.mlp.post_act_hist_counts, |
|
sparsity_level, |
|
) |
|
layer.mlp.sparse_act_fn.set_new_threshold(layer.mlp.dead_threshold) |
|
layer.mlp.regularization_threshold = ( |
|
layer.mlp.dead_threshold * 1.2 |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_histogram( |
|
bin_edges, |
|
histogram_counts: torch.tensor, |
|
threshold: float = 0.5, |
|
title: str = "Activation Distribution", |
|
fig_dir: str = "figures", |
|
layer_index: int = 0, |
|
): |
|
if is_mainprocess(): |
|
torch.save(bin_edges, f"{fig_dir}/bin_edges_{layer_index}.pt") |
|
torch.save(histogram_counts, f"{fig_dir}/histogram_counts_{layer_index}.pt") |
|
|
|
fig, ax = plt.subplots() |
|
|
|
|
|
within_threshold_mask = (bin_edges[:-1] >= -threshold) & (bin_edges[:-1] <= threshold) |
|
ax.bar( |
|
bin_edges[:-1][within_threshold_mask][:-1], |
|
histogram_counts[within_threshold_mask][:-1], |
|
width=np.diff(bin_edges[:-1][within_threshold_mask]), |
|
|
|
color="#227CF6", |
|
alpha=0.2, |
|
label="Within Threshold", |
|
) |
|
|
|
|
|
outside_threshold_mask = ~within_threshold_mask |
|
ax.bar( |
|
bin_edges[:-1][outside_threshold_mask][:-1], |
|
histogram_counts[outside_threshold_mask][:-1], |
|
width=np.diff(bin_edges[:-1][outside_threshold_mask]), |
|
|
|
color="#227CF6", |
|
alpha=1.0, |
|
label="Outside Threshold", |
|
clip_on=False, |
|
) |
|
|
|
|
|
ax.axvline( |
|
x=threshold, |
|
color="#227CF6", |
|
alpha=0.6, |
|
linestyle="--", |
|
label="Threshold", |
|
) |
|
|
|
ax.axvline(x=0, color="#227CF6", alpha=0.3, linestyle="--") |
|
|
|
|
|
|
|
ax.set_xlabel("Activation Value") |
|
ax.set_ylabel("Frequency") |
|
|
|
ax.set_xlim(-0.7, 0.7) |
|
|
|
|
|
ax.legend() |
|
|
|
|
|
os.makedirs(fig_dir, exist_ok=True) |
|
|
|
|
|
plt.savefig(f"{fig_dir}/{title}.png") |
|
|
|
|
|
|
|
plt.close(fig) |
|
|
|
|
|
def plot_act(model, fig_dir: str = "figures"): |
|
SparseMLP = get_mlp_class(model) |
|
|
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, SparseMLP) and layer.mlp.is_stats |
|
): |
|
|
|
|
|
|
|
plot_title = f"Layer: {i} Post-Activation Absolute Distribution" |
|
plot_histogram( |
|
layer.mlp.histogram_bins, |
|
layer.mlp.post_act_hist_counts, |
|
layer.mlp.dead_threshold, |
|
plot_title, |
|
fig_dir, |
|
layer_index=i, |
|
) |
|
|
|
|
|
def save_act_hist( |
|
model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
|
): |
|
SparseMLP = get_mlp_class(model) |
|
os.makedirs(os.path.dirname(filename), exist_ok=True) |
|
act_dict = {} |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, SparseMLP) and layer.mlp.is_stats |
|
): |
|
act_dict[i] = ( |
|
layer.mlp.histogram_bins, |
|
layer.mlp.pre_act_hist_counts, |
|
layer.mlp.post_act_hist_counts, |
|
) |
|
print("Saving activation histograms...\n\n\n") |
|
torch.save(act_dict, filename) |
|
|
|
|
|
def load_act_hist( |
|
model, filename="/scr/jay/models/mistral/pre_finetune/cola_act_hist.pt" |
|
): |
|
assert os.path.exists( |
|
filename |
|
), f"{filename} does not exist when loading pre/post-activation histogram of SparseMistralSiluMLP." |
|
SparseMLP = get_mlp_class(model) |
|
|
|
print("Loading activation histograms...\n\n\n") |
|
|
|
act_dict = torch.load(filename) |
|
for i, layer in enumerate(model.model.layers): |
|
if ( |
|
isinstance(layer.mlp, SparseMLP) and layer.mlp.is_stats |
|
): |
|
( |
|
layer.mlp.histogram_bins, |
|
layer.mlp.pre_act_hist_counts, |
|
layer.mlp.post_act_hist_counts, |
|
) = act_dict[i] |
|
|
|
|
|
def enable_last_k_modules(model, start_module_idx: int): |
|
assert 32 > start_module_idx >= 0 |
|
new_modules = [] |
|
new_idx = 0 |
|
for idx in range(start_module_idx, len(model.model.original_layers)): |
|
module = model.model.original_layers[idx] |
|
module.layer_idx = new_idx |
|
module.self_attn.layer_idx = new_idx |
|
new_modules.append(module) |
|
new_idx += 1 |
|
print(module.layer_idx) |
|
|
|
model.model.layers = nn.ModuleList(new_modules) |
|
|
|
|
|
def enable_first_k_modules(model, end_module_idx: int): |
|
assert 32 > end_module_idx >= 0 |
|
new_modules = [] |
|
new_idx = 0 |
|
for idx in range(0, end_module_idx + 1): |
|
module = model.model.original_layers[idx] |
|
module.layer_idx = new_idx |
|
module.self_attn.layer_idx = new_idx |
|
new_modules.append(module) |
|
new_idx += 1 |
|
print(module.layer_idx) |
|
|
|
model.model.layers = nn.ModuleList(new_modules) |
|
|
|
|
|
|
|
|
|
|
|
class MistralSparseSiluMLP(MistralMLP): |
|
def __init__(self, config, *args, **kwargs): |
|
super().__init__(config) |
|
self.swish_outputs = None |
|
self.relu = nn.ReLU() |
|
self.is_profile = False |
|
|
|
self.kill_sparse_swish_outputs = False |
|
self.dead_percentage = 0 |
|
self.is_stats = False |
|
self.visit_counts = 0 |
|
|
|
|
|
self.dead_threshold = kwargs.pop("dead_threshold", 0) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", True) |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 regularization" |
|
) |
|
self.regularization_threshold = kwargs.pop("regularization_threshold", 0.5) |
|
self.use_relu = kwargs.pop("use_relu", False) |
|
self.activation_norm = None |
|
|
|
|
|
self.is_collect_histogram = False |
|
num_bins = 1000 |
|
self.histogram_bins = torch.linspace(-1, 1, num_bins - 2) |
|
self.histogram_bins = torch.cat( |
|
[torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])] |
|
) |
|
self.pre_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.abs_post_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.post_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.t = 0 |
|
self.count = 0 |
|
self.agg_sparsity = 0 |
|
|
|
|
|
self.sparse_act_fn = SparseSiLU(threshold=self.dead_threshold) |
|
|
|
def activate_stats(self, is_collect_histogram: bool = True): |
|
self.is_stats = True |
|
self.dead_percentage = 0 |
|
self.visit_counts = 0 |
|
self.is_collect_histogram = is_collect_histogram |
|
self.histogram_counts = torch.zeros(2000) |
|
|
|
def deactivate_stats(self): |
|
self.is_stats = False |
|
|
|
def collect_stats(self, pre_activation, post_activation): |
|
start_time = time.time() |
|
pre_activation = pre_activation.float().cpu().detach() |
|
post_activation = post_activation.float().cpu().detach() |
|
|
|
self.pre_act_hist_counts += torch.histogram( |
|
pre_activation, bins=self.histogram_bins |
|
)[0] |
|
self.post_act_hist_counts += torch.histogram( |
|
torch.abs(post_activation), bins=self.histogram_bins |
|
)[0] |
|
|
|
self.t += time.time() - start_time |
|
|
|
|
|
|
|
def forward( |
|
self, |
|
x, |
|
sp_mask: torch.tensor = None, |
|
): |
|
""" |
|
If kill_sparse_swish_outputs is set to False, this layer functions exactly like a normal MLP layer. |
|
""" |
|
if sp_mask != None: |
|
return self.down_proj( |
|
self.sparse_act_fn(self.gate_proj(x) * sp_mask) * self.up_proj(x) |
|
) |
|
|
|
if self.is_profile: |
|
if x.shape[1] == 1: |
|
if self.sp_method == 1: |
|
return flash_gemv.flag_gemv_gemv_inner_bf16( |
|
x, |
|
self.gate_proj.weight, |
|
self.up_proj.weight, |
|
self.down_proj.weight, |
|
self.dead_threshold, |
|
) |
|
elif self.sp_method == 2: |
|
return flash_gemv.gemv_gemv_triton( |
|
x, |
|
self.act_fn(self.gate_proj(x)), |
|
self.up_proj.weight, |
|
self.wdown_t, |
|
self.dead_threshold, |
|
) |
|
else: |
|
post_act = self.act_fn(self.gate_proj(x)) |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
post_act[dead_neurons] = 0 |
|
return self.down_proj(post_act * self.up_proj(x)) |
|
else: |
|
post_act = self.act_fn(self.gate_proj(x)) |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
post_act[dead_neurons] = 0 |
|
return self.down_proj(post_act * self.up_proj(x)) |
|
|
|
elif self.use_relu: |
|
post_act = self.relu(self.gate_proj(x)) |
|
self.count += 1 |
|
if self.count <= 1: |
|
print("USING RELU!!!!") |
|
|
|
if self.is_stats: |
|
dead_neurons = post_act == 0 |
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
return self.down_proj(post_act * self.up_proj(x)) |
|
|
|
else: |
|
self.count += 1 |
|
if self.count <= 1: |
|
ds_print("USING SparseSILU!!!!") |
|
pre_act = self.gate_proj(x) |
|
post_act = self.act_fn(pre_act) |
|
if self.kill_sparse_swish_outputs: |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
|
|
|
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
if self.is_stats: |
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
self.a = dead_percentage |
|
|
|
|
|
if ( |
|
self.is_collect_histogram |
|
and pre_act.eq(0).float().mean() < 0.99 |
|
): |
|
self.collect_stats(pre_act, post_act) |
|
|
|
if self.count <= 1: |
|
ds_print("KILL!") |
|
post_act[dead_neurons] = 0 |
|
|
|
out = self.down_proj(post_act * self.up_proj(x)) |
|
if self.use_sparse_regularization: |
|
if self.regularization_type == "L1 regularization": |
|
self.activation_norm = torch.abs(post_act)[ |
|
torch.abs(post_act) < self.regularization_threshold |
|
].mean() |
|
elif self.regularization_type == "L2 regularization": |
|
self.activation_norm = torch.sqrt( |
|
torch.square(post_act)[ |
|
torch.abs(post_act) < self.regularization_threshold |
|
] |
|
).mean() |
|
|
|
return out |
|
|
|
|
|
class SparseMistralDecoderLayer(MistralDecoderLayer): |
|
def __init__( |
|
self, |
|
config: MistralConfig, |
|
layer_idx: int, |
|
decoder_layer: MistralDecoderLayer, |
|
init_svd: bool = True, |
|
*args, |
|
**kwargs, |
|
): |
|
assert isinstance( |
|
decoder_layer.mlp, MistralSparseSiluMLP |
|
), f"{type(decoder_layer.mlp)} should MistralSparseSiluMLP." |
|
|
|
super().__init__(config, layer_idx) |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
|
|
self.init_svd = init_svd |
|
self.self_attn = decoder_layer.self_attn |
|
|
|
self.mlp = decoder_layer.mlp |
|
self.input_layernorm = decoder_layer.input_layernorm |
|
self.post_attention_layernorm = decoder_layer.post_attention_layernorm |
|
|
|
|
|
self.low_rank = kwargs.pop("low_rank", 64) |
|
self.sparse_act_func = decoder_layer.mlp.sparse_act_fn |
|
|
|
print( |
|
f"Setting {layer_idx}th mlp layer's sparse predictor... svd init: {init_svd}" |
|
) |
|
self.sp_mlp = low_rank_approximation( |
|
decoder_layer.mlp.gate_proj, |
|
act_func=self.sparse_act_func, |
|
init_svd=init_svd, |
|
) |
|
self.use_async = kwargs.pop("use_async", False) |
|
self.use_sparse_predictor = False |
|
self.distill_loss = None |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
**kwargs, |
|
) -> Tuple[ |
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] |
|
]: |
|
print("hidden_states shape: ", hidden_states.shape) |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
|
|
residual = hidden_states |
|
sp_mask = None |
|
|
|
if self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
|
if not self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
|
|
gating_output = self.mlp.sparse_act_fn(self.mlp.gate_proj(hidden_states)) |
|
loss_func = MSELoss() |
|
self.distill_loss = loss_func(sp_mask, gating_output) |
|
|
|
|
|
sp_mask = sp_mask > 0 |
|
|
|
if self.training: |
|
sp_mask = None |
|
|
|
|
|
|
|
hidden_states = self.mlp(hidden_states, sp_mask) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
|
|
class SparseMistralConfig(MistralConfig): |
|
model_type = "sparse_mistral" |
|
|
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class SparseMistralforCausalLM(MistralForCausalLM): |
|
config_class = SparseMistralConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.config = config |
|
if config.use_sparse_model: |
|
self.apply_sparse_mlp() |
|
if config.thresholds is not None: |
|
for idx, m in enumerate(self.model.layers): |
|
if isinstance(m.mlp, MistralSparseSiluMLP): |
|
m.mlp.dead_threshold = config.thresholds[idx] |
|
m.mlp.sparse_act_fn.set_new_threshold(m.mlp.dead_threshold) |
|
m.mlp.kill_sparse_swish_outputs = True |
|
m.mlp.use_relu = config.use_relu |
|
if config.use_sparse_predictor: |
|
self.apply_sparse_predictor(init_svd=config.init_svd) |
|
|
|
def apply_sparse_mlp(self): |
|
apply_sparse_silu_mlp( |
|
self, |
|
config=self.config, |
|
use_sparse_regularization=self.config.use_sparse_regularization, |
|
) |
|
|
|
def apply_sparse_predictor(self, init_svd: bool = True): |
|
apply_sparse_decoder_layer(self, config=self.config, init_svd=init_svd) |
|
|
|
|
|
|
|
|
|
|
|
class SparseLlamaConfig(LlamaConfig): |
|
model_type = "sparse_llama" |
|
|
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
|
|
|
|
class SparseLlamaForCausalLM(LlamaForCausalLM): |
|
config_class = SparseLlamaConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.config = config |
|
if config.use_sparse_model: |
|
self.apply_sparse_mlp() |
|
if config.thresholds is not None: |
|
for idx, m in enumerate(self.model.layers): |
|
if isinstance(m.mlp, LlamaSparseSiluMLP): |
|
m.mlp.dead_threshold = config.thresholds[idx] |
|
m.mlp.sparse_act_fn.set_new_threshold(m.mlp.dead_threshold) |
|
m.mlp.kill_sparse_swish_outputs = True |
|
m.mlp.use_relu = config.use_relu |
|
if config.use_sparse_predictor: |
|
self.apply_sparse_predictor(init_svd=config.init_svd) |
|
|
|
def apply_sparse_mlp(self): |
|
apply_sparse_silu_mlp( |
|
self, |
|
config=self.config, |
|
use_sparse_regularization=self.config.use_sparse_regularization, |
|
) |
|
|
|
def apply_sparse_predictor(self, init_svd: bool = True): |
|
apply_sparse_decoder_layer(self, config=self.config, init_svd=init_svd) |
|
|
|
|
|
class LlamaSparseSiluMLP(LlamaMLP): |
|
def __init__(self, config, *args, **kwargs): |
|
super().__init__(config) |
|
self.swish_outputs = None |
|
self.relu = nn.ReLU() |
|
self.is_profile = False |
|
|
|
self.kill_sparse_swish_outputs = False |
|
self.dead_percentage = 0 |
|
self.is_stats = False |
|
self.visit_counts = 0 |
|
|
|
|
|
self.dead_threshold = kwargs.pop("dead_threshold", 0) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", True) |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 regularization" |
|
) |
|
self.regularization_threshold = kwargs.pop("regularization_threshold", 0.5) |
|
self.use_relu = kwargs.pop("use_relu", False) |
|
self.activation_norm = None |
|
|
|
|
|
self.is_collect_histogram = False |
|
num_bins = 1000 |
|
self.histogram_bins = torch.linspace(-1, 1, num_bins - 2) |
|
self.histogram_bins = torch.cat( |
|
[torch.tensor([-torch.inf]), self.histogram_bins, torch.tensor([torch.inf])] |
|
) |
|
self.pre_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.abs_post_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.post_act_hist_counts = torch.zeros(num_bins - 1) |
|
self.t = 0 |
|
self.count = 0 |
|
self.agg_sparsity = 0 |
|
|
|
|
|
self.sparse_act_fn = SparseSiLU(threshold=self.dead_threshold) |
|
|
|
def activate_stats(self, is_collect_histogram: bool = True): |
|
self.is_stats = True |
|
self.dead_percentage = 0 |
|
self.visit_counts = 0 |
|
self.is_collect_histogram = is_collect_histogram |
|
self.histogram_counts = torch.zeros(2000) |
|
|
|
def deactivate_stats(self): |
|
self.is_stats = False |
|
|
|
def collect_stats(self, pre_activation, post_activation): |
|
start_time = time.time() |
|
pre_activation = pre_activation.float().cpu().detach() |
|
post_activation = post_activation.float().cpu().detach() |
|
|
|
self.pre_act_hist_counts += torch.histogram( |
|
pre_activation, bins=self.histogram_bins |
|
)[0] |
|
self.post_act_hist_counts += torch.histogram( |
|
torch.abs(post_activation), bins=self.histogram_bins |
|
)[0] |
|
|
|
self.t += time.time() - start_time |
|
|
|
|
|
|
|
def forward( |
|
self, |
|
x, |
|
sp_mask: torch.tensor = None, |
|
): |
|
""" |
|
If kill_sparse_swish_outputs is set to False, this layer functions exactly like a normal MLP layer. |
|
""" |
|
if sp_mask != None: |
|
return self.down_proj( |
|
self.sparse_act_fn(self.gate_proj(x) * sp_mask) * self.up_proj(x) |
|
) |
|
|
|
if self.is_profile: |
|
if x.shape[1] == 1: |
|
if self.sp_method == 1: |
|
return flash_gemv.flag_gemv_gemv_inner_bf16( |
|
x, |
|
self.gate_proj.weight, |
|
self.up_proj.weight, |
|
self.down_proj.weight, |
|
self.dead_threshold, |
|
) |
|
elif self.sp_method == 2: |
|
return flash_gemv.gemv_gemv_triton( |
|
x, |
|
self.act_fn(self.gate_proj(x)), |
|
self.up_proj.weight, |
|
self.wdown_t, |
|
self.dead_threshold, |
|
) |
|
else: |
|
post_act = self.act_fn(self.gate_proj(x)) |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
post_act[dead_neurons] = 0 |
|
return self.down_proj(post_act * self.up_proj(x)) |
|
else: |
|
post_act = self.act_fn(self.gate_proj(x)) |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
post_act[dead_neurons] = 0 |
|
return self.down_proj(post_act * self.up_proj(x)) |
|
|
|
elif self.use_relu: |
|
post_act = self.relu(self.gate_proj(x)) |
|
self.count += 1 |
|
|
|
if self.is_stats: |
|
dead_neurons = post_act == 0 |
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
return self.down_proj(post_act * self.up_proj(x)) |
|
|
|
else: |
|
self.count += 1 |
|
pre_act = self.gate_proj(x) |
|
post_act = self.act_fn(pre_act) |
|
if self.kill_sparse_swish_outputs: |
|
dead_neurons = post_act.abs() <= self.dead_threshold |
|
dead_percentage = dead_neurons.float().mean() |
|
agg_sparsity = dead_neurons.all(dim=0).float().mean() |
|
|
|
if self.is_stats: |
|
self.dead_percentage = ( |
|
self.dead_percentage * self.visit_counts + dead_percentage |
|
) / (self.visit_counts + 1) |
|
self.agg_sparsity = ( |
|
self.agg_sparsity * self.visit_counts + agg_sparsity |
|
) / (self.visit_counts + 1) |
|
self.visit_counts += 1 |
|
|
|
self.a = dead_percentage |
|
|
|
|
|
|
|
if self.is_collect_histogram: |
|
self.collect_stats(pre_act, post_act) |
|
|
|
if self.count <= 1: |
|
ds_print("KILL!") |
|
post_act[dead_neurons] = 0 |
|
|
|
out = self.down_proj(post_act * self.up_proj(x)) |
|
if self.use_sparse_regularization: |
|
if self.regularization_type == "L1 regularization": |
|
self.activation_norm = torch.abs(post_act)[ |
|
torch.abs(post_act) < self.regularization_threshold |
|
].mean() |
|
elif self.regularization_type == "L2 regularization": |
|
self.activation_norm = torch.sqrt( |
|
torch.square(post_act)[ |
|
torch.abs(post_act) < self.regularization_threshold |
|
] |
|
).mean() |
|
|
|
return out |
|
|
|
|
|
class LlamaSparseDecoderLayer(LlamaDecoderLayer): |
|
def __init__( |
|
self, |
|
config: LlamaConfig, |
|
layer_idx: int, |
|
decoder_layer: LlamaDecoderLayer, |
|
init_svd: bool = True, |
|
*args, |
|
**kwargs, |
|
): |
|
assert isinstance( |
|
decoder_layer.mlp, LlamaSparseSiluMLP |
|
), f"{type(decoder_layer.mlp)} should be LlamaSparseSiluMLP." |
|
|
|
super().__init__(config, layer_idx) |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
|
|
self.init_svd = init_svd |
|
self.self_attn = decoder_layer.self_attn |
|
|
|
self.mlp = decoder_layer.mlp |
|
self.input_layernorm = decoder_layer.input_layernorm |
|
self.post_attention_layernorm = decoder_layer.post_attention_layernorm |
|
|
|
|
|
self.low_rank = kwargs.pop("low_rank", 64) |
|
self.sparse_act_func = decoder_layer.mlp.sparse_act_fn |
|
|
|
print( |
|
f"Setting {layer_idx}th mlp layer's sparse predictor... svd init: {init_svd}" |
|
) |
|
self.sp_mlp = low_rank_approximation( |
|
decoder_layer.mlp.gate_proj, |
|
act_func=self.sparse_act_func, |
|
init_svd=init_svd, |
|
) |
|
self.use_async = kwargs.pop("use_async", False) |
|
self.use_sparse_predictor = False |
|
self.distill_loss = None |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
**kwargs, |
|
) -> Tuple[ |
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] |
|
]: |
|
print("hidden_states shape: ", hidden_states.shape) |
|
if "padding_mask" in kwargs: |
|
warnings.warn( |
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" |
|
) |
|
|
|
residual = hidden_states |
|
sp_mask = None |
|
|
|
if self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
**kwargs, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
|
if not self.use_async: |
|
sp_mask = self.sp_mlp(hidden_states) |
|
|
|
|
|
gating_output = self.mlp.sparse_act_fn(self.mlp.gate_proj(hidden_states)) |
|
loss_func = MSELoss() |
|
self.distill_loss = loss_func(sp_mask, gating_output) |
|
|
|
|
|
sp_mask = sp_mask > 0 |
|
|
|
if self.training: |
|
sp_mask = None |
|
|
|
|
|
|
|
hidden_states = self.mlp(hidden_states, sp_mask) |
|
hidden_states = residual + hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
class GracefulRegularizationScheduler(TrainerCallback): |
|
def __init__( |
|
self, |
|
num_warmup_steps=40, |
|
is_enabled: bool = False, |
|
model_name: str = "mistral", |
|
test_dataset: Dataset = None, |
|
targeted_sparsity: float = 0.5, |
|
keep_regularization_with_kill: bool = False, |
|
): |
|
"""Scheduler for regularizing the model first before applying the dead threshold. |
|
|
|
:param num_warmup_steps: number of training steps required to reach the dead threshold, defaults to 40 |
|
:param increment_ratio: by how much to increase the dead threshold. |
|
For example, 0.5 means "increase the threshold by 0.5 * desired threshold |
|
""" |
|
self.num_warmup_steps = num_warmup_steps |
|
self.is_enabled = is_enabled |
|
self.model_name = model_name |
|
self.test_dataset = test_dataset |
|
self.targeted_sparsity = targeted_sparsity |
|
self.keep_regularization_with_kill = keep_regularization_with_kill |
|
self.act_hist_path = ( |
|
f"/scr/lukeai/histograms/warm_up_reg_{targeted_sparsity}/act_hist.pt" |
|
) |
|
if self.is_enabled: |
|
print("GracefulRegularizationScheduler is enabled.") |
|
self.trainer = None |
|
|
|
def set_trainer(self, trainer): |
|
self.trainer = trainer |
|
|
|
def on_step_end(self, args, state, control, **kwargs): |
|
if not self.is_enabled: |
|
return |
|
|
|
model = kwargs["model"] |
|
if isinstance(model, PeftModel): |
|
base_model = model.get_base_model() |
|
else: |
|
base_model = model |
|
|
|
if state.global_step == 1: |
|
ds_print("Setting an initial reg threshold to 0.1") |
|
set_regularization_threshold(base_model, 0.1) |
|
disable_sparse_silu(base_model) |
|
|
|
if state.global_step == self.num_warmup_steps: |
|
activate_stats(base_model) |
|
enable_sparse_silu(base_model) |
|
self.trainer.evaluate() |
|
save_act_hist(base_model, self.act_hist_path) |
|
set_sparse_threshold(base_model, self.targeted_sparsity, False) |
|
deactivate_stats(base_model) |
|
self.trainer.use_sparse_regularization = self.keep_regularization_with_kill |
|
print_dead_neuron_stats(model.get_base_model()) |
|
|
|
|
|
class GradualSparsificationScheduler(TrainerCallback): |
|
def __init__( |
|
self, |
|
num_warmup_steps=40, |
|
increment_ratio=0.5, |
|
is_enabled: bool = False, |
|
model_name: str = "mistral", |
|
): |
|
"""Scheduler for gradually increasing a dead threshold until it reaches the desired threshold. |
|
|
|
:param num_warmup_steps: number of training steps required to reach the dead threshold, defaults to 40 |
|
:param increment_ratio: by how much to increase the dead threshold. |
|
For example, 0.5 means "increase the threshold by 0.5 * desired threshold |
|
""" |
|
self.num_warmup_steps = num_warmup_steps |
|
self.increment_ratio = increment_ratio |
|
self.step_size = int(num_warmup_steps * increment_ratio) |
|
self.is_enabled = is_enabled |
|
self.model_name = model_name |
|
self.model_type = get_model_type(model_name) |
|
self.mlp_type = ( |
|
MistralSparseSiluMLP if self.model_type == MISTRAL else LlamaSparseSiluMLP |
|
) |
|
|
|
def on_step_end(self, args, state, control, **kwargs): |
|
model = kwargs["model"] |
|
|
|
if not self.is_enabled: |
|
if state.global_step <= 10: |
|
for module in model.modules(): |
|
if isinstance(module, self.mlp_type): |
|
module.current_dead_threshold = module.dead_threshold |
|
return |
|
|
|
current_dead_threshold = 0 |
|
desired_dead_threshold = 0 |
|
|
|
if is_mainprocess(): |
|
ds_print(state.global_step) |
|
|
|
if state.global_step % self.step_size == 2: |
|
for module in model.modules(): |
|
if isinstance(module, self.mlp_type): |
|
desired_dead_threshold = copy.deepcopy(module.dead_threshold) |
|
current_dead_threshold = module.current_dead_threshold |
|
current_dead_threshold += ( |
|
self.increment_ratio * desired_dead_threshold |
|
) |
|
module.current_dead_threshold = min( |
|
desired_dead_threshold, current_dead_threshold |
|
) |
|
|
|
if is_running_deepspeed and is_mainprocess(): |
|
ds_print( |
|
state.global_step, |
|
current_dead_threshold, |
|
desired_dead_threshold, |
|
) |
|
|
|
if state.global_step % 2000 == 0: |
|
if is_running_deepspeed and is_mainprocess(): |
|
ds_print( |
|
f"Saving to /matx/u/lukeai/{self.model_name}_{state.global_step - 2}.pt", |
|
) |
|
torch.save( |
|
model.state_dict(), |
|
f"/matx/u/lukeai/{self.model_name}_{state.global_step - 2}.pt", |
|
) |
|
|
|
|
|
|
|
|
|
|
|
class SparseTrainer(Trainer): |
|
def __init__(self, *args, **kwargs): |
|
self.regularization_coefficient = kwargs.pop("regularization_coefficient", 10) |
|
self.use_sparse_regularization = kwargs.pop("use_sparse_regularization", False) |
|
self.use_spm_loss = False |
|
self.freeze_original_weights = False |
|
self.regularization_type = kwargs.pop( |
|
"regularization_type", "L1 positive activation" |
|
) |
|
assert self.regularization_type in [ |
|
"L2 activation", |
|
"L1 positive activation", |
|
], f"Invalid regularization type: {self.regularization_type}" |
|
self.sparse_layers = [] |
|
self.sparse_decoder_layers = [] |
|
super(SparseTrainer, self).__init__(*args, **kwargs) |
|
|
|
def initialize_sparse_silu_layers(self, model): |
|
SparseMLP = get_mlp_class(model) |
|
self.sparse_layers = [m for m in model.modules() if isinstance(m, SparseMLP)] |
|
|
|
def initialize_sparse_decoder_layers(self, model): |
|
SparseDecoder = get_decoder_class(model) |
|
self.sparse_decoder_layers = [ |
|
m for m in model.modules() if isinstance(m, SparseDecoder) |
|
] |
|
|
|
def training_step( |
|
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] |
|
) -> torch.Tensor: |
|
""" |
|
Override the huggingface's training_step function to add a regularization term. |
|
A regularization term is computed with intermediate values, which are freed after "backward()." |
|
You need to set `retain_graph=True` inside `backward` function to keep the values. |
|
""" |
|
model.train() |
|
inputs = self._prepare_inputs(inputs) |
|
|
|
with self.compute_loss_context_manager(): |
|
loss = self.compute_loss(model, inputs) |
|
|
|
if self.args.n_gpu > 1: |
|
loss = loss.mean() |
|
|
|
if not self.freeze_original_weights: |
|
if loss is not None: |
|
self.accelerator.backward(loss, retain_graph=True) |
|
|
|
if self.use_sparse_regularization: |
|
regularization_loss = self.compute_regularization(model) |
|
if self.args.n_gpu > 1: |
|
regularization_loss = regularization_loss.mean() |
|
if regularization_loss is not None: |
|
self.accelerator.backward(regularization_loss, retain_graph=True) |
|
loss += regularization_loss |
|
|
|
if self.use_spm_loss: |
|
spm_loss = self.compute_spm_loss(model) |
|
if self.args.n_gpu > 1: |
|
spm_loss = spm_loss.mean() |
|
if spm_loss is not None: |
|
self.accelerator.backward(spm_loss, retain_graph=False) |
|
loss += spm_loss |
|
|
|
return loss.detach() / self.args.gradient_accumulation_steps |
|
|
|
def compute_regularization(self, model): |
|
""" |
|
Compute a sparse regularization loss for SiLU |
|
""" |
|
loss = 0 |
|
if len(self.sparse_layers) == 0: |
|
self.initialize_sparse_silu_layers(model) |
|
num_layers = len(self.sparse_layers) |
|
|
|
for module in self.sparse_layers: |
|
if module.activation_norm is not None: |
|
loss += module.activation_norm |
|
|
|
loss /= num_layers |
|
loss *= self.regularization_coefficient |
|
|
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Negative relularizer loss: ", loss.item()) |
|
return loss |
|
|
|
def compute_spm_loss(self, model): |
|
loss = 0 |
|
if len(self.sparse_decoder_layers) == 0: |
|
self.initialize_sparse_decoder_layers(model) |
|
for module in self.sparse_decoder_layers: |
|
if module.distill_loss != None: |
|
loss += module.distill_loss |
|
if self.state.global_step % 20 == 0 and loss != 0: |
|
print("Sparse Predictor Distillation loss: ", loss.item()) |
|
return loss |
|
|