
Unable to quantize layers one at a time

#8
by abhinavkulkarni - opened

Hi,

I am trying to apply AWQ quantization to this new architecture one layer at a time and am running into a problem.

The way it works is as follows:

  1. Pass a sample input through the model and catch the input to the first layer
  2. Pass that input through each layer in turn while determining the optimal quantization parameters
  3. The output of one layer becomes the input to the next one

I have omitted the quantization logic, but the main scaffold is as follows.

import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextStreamer
from accelerate import init_empty_weights, infer_auto_device_map
from datasets import load_dataset
import torch.nn as nn
import gc

model_id = "togethercomputer/StripedHyena-Nous-7B"
# model_id = "meta-llama/Llama-2-7b-hf"

# Config
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# Load model on CPU
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
    model_id, config=config, trust_remote_code=True, **kwargs
)

model.eval()

# Tokenizer
try:
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token

# Load sample dataset
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")

texts = [dataset[i]['text'] for i in range(10)]
samples = [tokenizer.encode(text, max_length=512, truncation=True, padding='max_length') for text in texts]
samples = torch.LongTensor(samples) # Shape = (10, 512)

# Catch the input to the first layer
inps = []
layer_kwargs = {}


class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        inps.append(inp)
        layer_kwargs.update(kwargs)
        raise ValueError  # early exit to break later inference

# patch layer 0 to catch input and kwargs
layers = model.backbone.blocks # For StripedHyena
# layers = model.model.layers # For Llama-2
layers[0] = Catcher(layers[0])
try:
    model(samples.to(next(model.parameters()).device))
except ValueError:  # work with early exit
    pass

layers[0] = layers[0].module  # restore
inps = inps[0]

layers[0] = layers[0].cpu()

# Now pass the input successively through each layer, collecting the output
# which becomes input for the next layer
for i in range(len(layers)):
    print(i)
    layer = layers[i]
    layer = layer.cuda()
    inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
    layer_kwargs = {k: (v.to(inps.device) if isinstance(v, torch.Tensor) else v) for k, v in layer_kwargs.items()}
    # get output as next layer's input
    inps = layer(inps, **layer_kwargs)[0]
    # Clear GPU memory
    torch.cuda.empty_cache()
    layer = layer.cpu()
    gc.collect()
    torch.cuda.empty_cache()

I get the following error when the input is passed through an AttentionBlock layer:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 75
     73 layer_kwargs = {k:(v.to(inps.device) if isinstance(v, torch.Tensor) else v) for k,v in layer_kwargs.items()}
     74 # get output as next layer's input
---> 75 inps = layer(inps, **layer_kwargs)[0]
     76 # Clear GPU memory
     77 torch.cuda.empty_cache()

File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/togethercomputer/StripedHyena-Nous-7B/42777970d603597dadb768705896533eb9556a07/model.py:71, in AttentionBlock.forward(self, u, inference_params, padding_mask, *args, **kwargs)
     64     u = u * padding_mask[..., None]
     66 # for attr in ['lengths_per_sample', 'max_seqlen', 'key_value_memory_dict']:
     67 #     if not hasattr(inference_params, attr):
     68 #         setattr(inference_params, attr, None)
     69 # inference_params.key_value_memory_dict = inference_params.key_value_memory_dict or {}
     70 u = (
---> 71     self.inner_mha_cls(
     72         self.pre_norm(u),
     73         inference_params=inference_params,
     74     )
     75     + u
     76 )
     77 if type(padding_mask) == torch.Tensor:  # guard against bias
     78     u = u * padding_mask[..., None]

File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/envs/llm-awq/lib/python3.10/site-packages/flash_attn/modules/mha.py:563, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, **kwargs)
    551     assert not self.dwconv
    553 kwargs = (
    554     {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
    555     if self.use_flash_attn
    556     else {"key_padding_mask": key_padding_mask, **kwargs}
    557 )
    558 seqlen_offset = (
    559     0
    560     if inference_params is None
    561     else (
    562         inference_params.lengths_per_sample
--> 563         if inference_params.lengths_per_sample is not None
    564         else inference_params.seqlen_offset
    565     )
    566 )
    567 rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
    568 batch, seqlen = x.shape[:2]

AttributeError: 'RecurrentInferenceParams' object has no attribute 'lengths_per_sample'

Please note that the same code works for meta-llama/Llama-2-7b-hf.

All of these quantization methods (GPTQ, AWQ, etc.) work layer by layer. Can you please help?

Thanks!

Together org

RecurrentInferenceParams handles cache management for the Hyena layers only. Since these layers use a constant-size cache (there is no KV cache), RecurrentInferenceParams does not have a .lengths_per_sample attribute.
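
For illustration, here is a minimal, runnable sketch of the mismatch. KVCacheParams and RecurrentParams below are hypothetical stand-ins, not the real flash_attn or StripedHyena classes; only the attribute-access pattern from the traceback is reproduced:

# Hypothetical stand-ins: only the attribute-access pattern matters here.
class KVCacheParams:
    # kv-cache-style inference params, as flash_attn's MHA.forward expects
    lengths_per_sample = None
    seqlen_offset = 0

class RecurrentParams:
    # Hyena-style constant cache: defines neither attribute
    pass

def get_seqlen_offset(inference_params):
    # Mirrors the expression that fails in flash_attn/modules/mha.py
    if inference_params is None:
        return 0
    return (inference_params.lengths_per_sample
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset)

print(get_seqlen_offset(KVCacheParams()))    # 0 -- kv-cache params work
print(get_seqlen_offset(RecurrentParams()))  # AttributeError, as in the traceback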

Can you try setting use_cache to False on the config before loading the model:

config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(
    model_id, config=config, trust_remote_code=True, **kwargs
)
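
If the kwargs captured by your Catcher still contain a cache object after this change, a further hedged sketch (not a confirmed fix) is to neutralize it before the per-layer loop. The "inference_params" key name is an assumption based on AttentionBlock.forward's signature in the traceback, and MHA.forward falls back to a zero sequence offset when it receives None:

# Hedged workaround sketch, applied to the layer_kwargs captured above.
# The key name "inference_params" is assumed from the traceback, not confirmed.
if "inference_params" in layer_kwargs:
    layer_kwargs["inference_params"] = None  # mha.py treats None as seqlen_offset 0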
