Clarification on Output Neuron Pruning Method in "Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time"

#2
by adamgallas - opened


Hello,

I am attempting to replicate the findings of "Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time" and have some questions about the methodology for predicting and applying sparsity in the MLP layers of LLMs, specifically for models like Llama 2 7B.

Both the "Deja Vu" paper and subsequent works, such as "ProSparse: Introducing and Enhancing Intrinsic Activation Sparsity within Large Language Models," propose using a small low-rank MLP (two linear layers) to predict which neurons of a large MLP layer will be active. These approaches rely on ReLU activations (replacing typical activation functions such as SiLU or GeLU with ReLU where needed) and then apply the Deja Vu method for sparsity prediction.

However, it is unclear how the predictor output is used to decide which output neurons to prune. Should the top-k indices of the predictor output be treated as the significant (active) neurons, or should a threshold be applied, so that any predictor output below zero means the corresponding output of the real MLP + ReLU is zero?

Code Snippet from Deja Vu's Training:

...
x, y = batch
y = y.float().to(device)
logits = model(x.to(device))
probs = logits.sigmoid()
preds = probs >= 0.5

dif = y.int() - preds.int()
miss = dif > 0.0  # classifier didn't activate target neuron

weight = (y.sum() / y.numel()) + 0.005
loss_weight = y * (1 - weight) + weight
eval["Loss Weight"] += [weight.item()]
eval["Loss"] += [
    torch.nn.functional.binary_cross_entropy(probs, y, weight=loss_weight).item()
]
...

This code suggests a threshold-based approach to neuron pruning: a neuron is predicted active when its sigmoid probability is at least 0.5.
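Thresholding the sigmoid probability at 0.5 is equivalent to thresholding the raw predictor logit at zero, because the sigmoid is monotonic and sigmoid(0) = 0.5. A minimal plain-Python check (the helper names are illustrative, not taken from the Deja Vu code):

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def active_by_prob(logits, p=0.5):
    # Training-loop style rule: active if sigmoid(logit) >= p
    return [sigmoid(z) >= p for z in logits]

def active_by_logit(logits):
    # Equivalent zero-threshold rule applied to the raw logit
    return [z >= 0.0 for z in logits]

logits = [-2.3, -0.01, 0.0, 0.4, 5.0]
print(active_by_prob(logits))   # [False, False, True, True, True]
print(active_by_logit(logits))  # [False, False, True, True, True]
```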

Deja Vu's Python Modeling:

...
def prepare_fc_weights(self, hidden_states: torch.Tensor):
    with torch.no_grad():
        self.predictor = self.predictor.float()

        _logit = self.predictor(hidden_states.reshape(-1, self.embed_dim).float())
        _, _top_indices = _logit.topk(self.topk, dim=1)
        _top_k_indices = _top_indices[:, :self.topk]
        self._mask = torch.zeros_like(_logit)
        self._mask = self._mask.scatter(1, _top_k_indices, 1).bool().half()
...
    hidden_states = self.fc1(hidden_states)
    if self.predictor != None:
        hidden_states = hidden_states * self._mask
...

In contrast, this snippet utilizes a top-k function for identifying active neurons.
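For comparison, the topk + scatter pattern above reduces, for a single token, to marking the k largest predictor logits. A small plain-Python sketch of that selection (`topk_mask` is a hypothetical helper, and ties are broken by lower index here, which torch.topk does not guarantee):

```python
def topk_mask(logits, k):
    """Boolean mask marking the k largest logits of one token's predictor output."""
    k = min(k, len(logits))
    # Stable sort by value, descending; keep the first k indices
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    mask = [False] * len(logits)
    for i in top:
        mask[i] = True
    return mask

row = [0.3, -1.2, 2.5, 0.9, -0.4]
print(topk_mask(row, 2))  # [False, False, True, True, False]
```

Note that this mask always keeps exactly k neurons, even when fewer than k logits are positive; that is the key behavioral difference from a zero threshold.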

PowerInfer's Implementation Snippet:

...
float *ffdata = (float *)dst->src[2]->data;
int *gid = (int *)dst->src[3]->data;
float *predictor_data = (float *)dst->src[2]->data;
const size_t predictor_row_size = dst->src[2]->ne[0]*ggml_type_size(GGML_TYPE_F32)/ggml_blck_size(GGML_TYPE_F32);
...
    ffdata = (float *)((char *)predictor_data + (i11 + i12*ne11 + i13*ne12*ne11)*predictor_row_size);
    float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));

    if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
        dst_col[ir0] = 0;
        continue;
    }
    vec_dot(ne00, &dst_col[ir0], src0_row + ir0 * nb01, src1_col);
...

This implementation appears to adopt a threshold-based method, yet it's unclear how it aligns with the methods described in Deja Vu or ProSparse.
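The PowerInfer branch can be read as a row-sparse matrix-vector product: the dot product for a row is skipped, and the output set to 0, whenever the predictor score is below `threshold`. A plain-Python sketch of that control flow, assuming (this is not confirmed by the snippet itself) that `gid[r] == 1` marks a row computed elsewhere, e.g. on the GPU; `sparse_matvec` and its arguments are hypothetical names:

```python
def sparse_matvec(weight_rows, x, predictor_scores, threshold=0.0, skip=None):
    """Compute out[r] = dot(weight_rows[r], x) only for rows predicted active.

    Rows with predictor_scores[r] < threshold, or flagged in `skip`
    (assumed to play the role of `gid`), are written as 0 with no dot product.
    """
    out = []
    for r, row in enumerate(weight_rows):
        if (skip is not None and skip[r]) or predictor_scores[r] < threshold:
            out.append(0.0)
            continue
        out.append(sum(w * xi for w, xi in zip(row, x)))
    return out

W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
x = [1.0, 1.0]
scores = [0.7, -0.2, 1.5]
print(sparse_matvec(W, x, scores))                            # [3.0, 0.0, 11.0]
print(sparse_matvec(W, x, scores, skip=[False, False, True]))  # [3.0, 0.0, 0.0]
```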

So how does the predictor work?

Given these observations and my own experiments, in which the top-k method proved effective but I had no clear guidance on choosing "k" (ProSparse does not give details on this), I would like clarification on two points:

1. What is the recommended method for determining which neurons should be pruned: top-k or threshold-based?

2. How is the "k" value for the top-k method determined in practice, especially considering the variable sparsity levels across different models and tasks?

Any insights or clarifications on these points would be greatly appreciated, as they could significantly enhance the practical application and exploration of these promising sparsity techniques.

Thank you.

SparseLLMs org

Thank you for your attention to our work! To start with the conclusion: unless otherwise stated, inactive neurons are determined using a zero threshold. (ProSparse does the same.)

In the issue above, you provide code snippets from three settings: (1) predictor training in Deja Vu; (2) collection of training data for predictors in Deja Vu; (3) the predictor implementation in PowerInfer. Both (1) and (3) use zero-threshold-based determination. For snippet (2), I think the top-k implementation was used to obtain Figure 6 in the Deja Vu paper, which reports accuracies under different sparsity levels. In practice, for downstream inference acceleration we are unlikely to use the top-k operation, mainly because the top-k operator is an efficiency bottleneck. By contrast, the threshold-based operator is far more efficient.
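One consequence of the zero-threshold rule is that no global k has to be chosen: the number of kept neurons adapts per token, and only an elementwise comparison is needed rather than a sort or selection. A small plain-Python illustration (helper names are hypothetical):

```python
def zero_threshold_mask(logits):
    # A neuron is predicted active iff its predictor logit is non-negative
    return [z >= 0.0 for z in logits]

def effective_k(logits_per_token):
    # Number of neurons kept for each token under the zero threshold
    return [sum(zero_threshold_mask(row)) for row in logits_per_token]

tokens = [
    [0.3, -1.2, 2.5, 0.9, -0.4],    # 3 non-negative logits
    [-0.5, -1.0, 0.1, -2.0, -0.3],  # 1 non-negative logit
]
print(effective_k(tokens))  # [3, 1]
```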

If you are interested in how we obtained the predictors for ProSparse, we have already released the predictor training code in this repo, which clearly adopts the zero-threshold-based paradigm.

Thank you for your prompt response; I truly appreciate it. I agree that the threshold-based method appears to be more efficient than the top-k method, and it makes sense that ProSparse and ReluLLaMA also use it. However, while writing Python modeling code to validate the predictor behavior, I ran into some difficulties: surprisingly, the threshold method does not seem to work as expected, whereas the top-k method does. That is why I raised this issue.

Would you be willing to help me identify potential bugs in my Python verification code? Your help would greatly benefit my upcoming work. Once again, thank you for your patient and considerate reply.

Code for loading the model

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "/mnt/gallas/prosparse-llama-2-7b/"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

model.half()
model.to("cuda")

Code for loading the predictors and inserting them into ProSparse's MLP (threshold-based method)

import functools
import os
import types

import torch
import torch.nn as nn

class PredictorThreshold(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4096, 1024, bias=False)
        self.fc2 = nn.Linear(1024, 11008, bias=False)

    def forward(self, x):
        t = self.fc2(nn.functional.relu(self.fc1(x)))
        return t >= 0  # zero threshold: True where the neuron is predicted active

predictor_list = []
predictor_path = "/mnt/gallas/prosparse-llama-2-7b-predictor"

for i in range(32):
    pred = PredictorThreshold()
    state_dict = torch.load(os.path.join(predictor_path, f"model_{i}.pt"))
    pred.load_state_dict(state_dict)
    pred.to("cuda")
    pred = pred.half()
    predictor_list.append(pred)

def mlp_with_predict(predictor, self, x):
    # The 0/1 mask zeroes gate_proj(x) for neurons predicted inactive
    down_proj = self.down_proj(self.act_fn(predictor(x).to(x.dtype) * self.gate_proj(x)) * self.up_proj(x))
    return down_proj

for i in range(32):
    # NOTE: this indexes layers[0] on every iteration, so only layer 0 is patched
    # (ending up with the last predictor); layer i would need layers[i]
    model.model.layers[0].mlp.forward = types.MethodType(
        functools.partial(mlp_with_predict, predictor_list[i]),
        model.model.layers[0].mlp
    )
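In `mlp_with_predict`, the 0/1 predictor mask multiplies `gate_proj(x)` before `act_fn`. With a ReLU-style `act_fn` that maps 0 to 0 (as in ProSparse), this is the same as masking the activation afterwards, since relu(m * g) = m * relu(g) for m in {0, 1}. A quick scalar check in plain Python (illustrative only, with plain ReLU standing in for `act_fn`):

```python
def relu(v: float) -> float:
    return max(0.0, v)

gates = [-1.5, -0.2, 0.0, 0.7, 3.0]
for m in (0.0, 1.0):  # mask values produced by predictor(x).to(x.dtype)
    for g in gates:
        # Masking before the activation equals masking after it
        assert relu(m * g) == m * relu(g)
print("relu(m * g) == m * relu(g) for all tested m in {0, 1}")
```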

Code for loading the predictors and inserting them into ProSparse's MLP (top-k-based method)

import functools
import os
import types

import torch
import torch.nn as nn

class PredictorTopk(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4096, 1024, bias=False)
        self.fc2 = nn.Linear(1024, 11008, bias=False)

    def forward(self, x):
        t = self.fc2(nn.functional.relu(self.fc1(x)))
        # Select the top 1800 neurons along the last (neuron) dimension; inside the MLP,
        # x is (batch, seq, hidden), so dim=1 would select sequence positions instead
        _, top_indices = t.topk(1800, dim=-1)
        mask = torch.zeros_like(t).scatter(-1, top_indices, 1).bool()
        return mask

predictor_list = []
predictor_path = "/mnt/gallas/prosparse-llama-2-7b-predictor"

for i in range(32):
    pred = PredictorTopk()
    state_dict = torch.load(os.path.join(predictor_path, f"model_{i}.pt"))
    pred.load_state_dict(state_dict)
    pred.to("cuda")
    pred = pred.half()
    predictor_list.append(pred)

def mlp_with_predict(predictor, self, x):
    # The 0/1 mask zeroes gate_proj(x) for neurons outside the top-k
    down_proj = self.down_proj(self.act_fn(predictor(x).to(x.dtype) * self.gate_proj(x)) * self.up_proj(x))
    return down_proj

for i in range(32):
    # NOTE: this indexes layers[0] on every iteration, so only layer 0 is patched
    # (ending up with the last predictor); layer i would need layers[i]
    model.model.layers[0].mlp.forward = types.MethodType(
        functools.partial(mlp_with_predict, predictor_list[i]),
        model.model.layers[0].mlp
    )

The code for perplexity evaluation

import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

def evaluate_perplexity(model, tokenizer):
    def _perplexity(nlls, n_samples, seqlen):
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    # load and prepare dataset
    data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    data = tokenizer("\n\n".join(data["text"]), return_tensors="pt")
    data = data.input_ids.to(model.device)

    seqlen = 2048
    model = model.eval()
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch = data[:, start_index:end_index].to(model.device)

            # bug-fix, Add <s> token to the beginning of each context if not present
            for context_idx in range(batch.size(0)):
                if batch[context_idx, 0] != 1:  # Check if <s> token is present
                    batch[context_idx, 1:] = batch[context_idx, :-1].clone()
                    batch[context_idx, 0] = 1  # Add <s> token at the beginning

            with torch.no_grad():
                logits = model(batch).logits
            shift_logits = logits[:, :-1, :].contiguous().float()
            # batch is a view of data, so its (possibly BOS-shifted) tokens are the labels
            shift_labels = batch[:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)

    return ppl.item()
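The `_perplexity` helper is the exponential of the average per-token negative log-likelihood: each window's mean cross-entropy is scaled by `seqlen`, the results are summed, and the sum is divided by the total token count. A plain-Python check of that arithmetic with made-up per-window losses (not real results):

```python
import math

def perplexity_from_window_losses(mean_losses, seqlen):
    # Each window contributes mean_loss * seqlen nats, as in neg_log_likelihood
    nlls = [loss * seqlen for loss in mean_losses]
    return math.exp(sum(nlls) / (len(mean_losses) * seqlen))

# With a constant mean loss, perplexity reduces to exp(loss)
print(perplexity_from_window_losses([2.0, 2.0, 2.0], seqlen=2048))
```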

I evaluated perplexity for the original ProSparse model and for the top-k-based and threshold-based variants, obtaining 8.2, 11.2, and 33.7, respectively. The threshold-based method clearly is not performing as expected, so I'm reaching out for help.

I'm unsure whether the issue lies in my Python code for verifying the predictor behavior or in the perplexity evaluation code. I used the perplexity evaluation code from lm-eval, incorporating the bug fix mentioned in your HuggingFace repository.

Thank you for taking the time to address my concern. I hope to hear from you soon.

SparseLLMs org
edited Mar 18

I found a critical bug in your code. The predictor output is not a "mask" applied to the intermediate outputs. Instead, as shown in the following code snippet from our predictor training code, the predictor outputs are the intermediate outputs themselves.

[Image: code snippet from the ProSparse predictor training code]

Therefore, maybe you can try to modify your mlp_with_predict as follows:

def mlp_with_predict(predictor, self, x):
    down_proj = self.down_proj(predictor(x).to(x.dtype))
    return down_proj

Also, it is fine for PredictorThreshold (as well as other activation predictors) to simply return t.

Thank you for your prompt reply :) and for your patience!

I found my stupid bug: when replacing the MLP layers in my Python modeling code, I got the layer index wrong. After fixing it, I found that the threshold-based method actually works. Thanks for the help!
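The fix amounts to indexing the layer with the loop variable so that layer i gets predictor i. Running the real model needs the checkpoints, so here is a self-contained sketch with hypothetical toy classes (`ToyModel`, `ToyLayer`, `ToyMLP`; the toy uses `model.layers` where the real model has `model.model.layers`) that applies the same `types.MethodType` + `functools.partial` patching:

```python
import functools
import types

class ToyMLP:
    def forward(self, x):
        return x

class ToyLayer:
    def __init__(self):
        self.mlp = ToyMLP()

class ToyModel:
    def __init__(self, n_layers):
        self.layers = [ToyLayer() for _ in range(n_layers)]

def mlp_with_predict(predictor, self, x):
    # Stand-in for the real masked MLP: report which predictor was used
    return (predictor, x)

model = ToyModel(4)
predictor_list = [f"pred_{i}" for i in range(4)]

for i in range(len(model.layers)):
    # Bind predictor i to layer i (indexing layers[0] here would patch only layer 0)
    model.layers[i].mlp.forward = types.MethodType(
        functools.partial(mlp_with_predict, predictor_list[i]),
        model.layers[i].mlp,
    )

print([layer.mlp.forward("x")[0] for layer in model.layers])
# ['pred_0', 'pred_1', 'pred_2', 'pred_3']
```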

However, I don't think the predictor's outputs are the intermediate results themselves. Although the training code collects the intermediate results to train the predictor, in train_mlp.py the predictor is trained as a binary classifier (labels from generate_label, predictions from sigmoid >= 0.5, misses counted via dif > 0), not on the difference between the predictor's output and the up_proj output. Intuitively, too, it is hard for two linear layers of size 4096x1024 and 1024x11008 to approximate the two 4096x11008 gated linear layers.

x, y = batch
y = y.float().to(device)
y = generate_label(y)
logits = model(x.to(device))
probs = logits.sigmoid()
preds = probs >= 0.5

dif = y.int() - preds.int()
miss = dif > 0.0  # classifier didn't activate target neuron
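To put the size argument in numbers, counting weights with the shapes used in this thread shows that the predictor has roughly 17% of the parameters of gate_proj and up_proj combined, and its hidden width is limited to 1024:

```python
hidden, inter, rank = 4096, 11008, 1024

predictor_params = hidden * rank + rank * inter  # predictor fc1 + fc2
gated_params = 2 * hidden * inter                # gate_proj + up_proj

print(predictor_params)                           # 15466496
print(gated_params)                               # 90177536
print(round(predictor_params / gated_params, 3))  # 0.172
```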

But still, thank you for your help! I really appreciate it :) Wish you a happy day.

SparseLLMs org

I see, it appears to be my mistake in understanding the actual predictor outputs. Still, I'm glad you have solved the problem!

Raincleared changed discussion status to closed
