---
library_name: transformers
tags: []
---
# yujiepan/Meta-Llama-3-8B-awq-w4g64-v2
This model is [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) quantized with AutoAWQ:

- 4-bit asymmetric weight-only quantization
- group_size=64
- the last decoder layer's FFN is left unquantized
- calibration set: pileval
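These settings correspond to the AutoAWQ `quant_config` used in the quantization code further below (the comments are added here for explanation):

```python
quant_config = {
    "zero_point": True,   # asymmetric quantization
    "q_group_size": 64,   # group size
    "w_bit": 4,           # 4-bit weights
    "version": "GEMM",
    "modules_to_not_convert": ["layers.31.mlp"],  # skip the last decoder layer's FFN
}
```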
## Accuracy
| Model | Precision | WikiText PPL (↓) |
|---|---|---|
| meta-llama/Meta-Llama-3-8B | FP16 | 9.179 |
| yujiepan/Meta-Llama-3-8B-awq-w4g64 | w4g64 | 9.219 |
| yujiepan/Meta-Llama-3-8B-awq-w4g64-v2 | w4g64, last layer's FFN skipped | 9.278 |
Note:
- Evaluated with [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) on the "wikitext" task (a reproduction sketch is shown below).
- WikiText perplexity does not guarantee downstream accuracy, but it is a quick way to gauge the distortion introduced by quantization.
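A minimal sketch of how the perplexity could be reproduced, assuming a recent lm-evaluation-harness (`lm_eval` >= 0.4); the exact API and reported metric names may differ between versions:

```python
# Sketch only: assumes lm-evaluation-harness (lm_eval) >= 0.4 is installed.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=yujiepan/Meta-Llama-3-8B-awq-w4g64-v2,dtype=float16",
    tasks=["wikitext"],
)
print(results["results"]["wikitext"])
```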
## Usage
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('yujiepan/Meta-Llama-3-8B-awq-w4g64-v2', torch_dtype=torch.float16)
```
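For completeness, a minimal generation sketch; the prompt and generation arguments are illustrative, and loading this checkpoint through `transformers` requires `autoawq` to be installed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "yujiepan/Meta-Llama-3-8B-awq-w4g64-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Illustrative prompt; adjust generation arguments as needed.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```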
## Quantization code
```python
from unittest.mock import patch

import torch
from awq import AutoAWQForCausalLM
from awq.models.llama import LlamaAWQForCausalLM
from transformers import AutoTokenizer

# Maps each module object to its full dotted name, so that entries in
# `modules_to_not_convert` can be matched against full names such as "layers.31.mlp".
module2fullname = {}


def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    """Patched version of awq.quantize.quantizer.exclude_layers_to_not_quantize
    that filters by the layer's full name instead of its name within the block."""
    if modules_to_not_convert is None:
        return linear_layers
    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        full_name = module2fullname[linear_layer]
        if not any(key in full_name for key in modules_to_not_convert):
            filtered_layers[name] = linear_layer
        else:
            print('Skipping', full_name)
    return filtered_layers


class PatchedLlamaAWQForCausalLM(LlamaAWQForCausalLM):
    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        # Only build scaling groups for projections whose calibration inputs were
        # collected, so that skipped modules (e.g. the last layer's FFN) are ignored.
        print(input_feat.keys())
        layers = []
        # attention input
        if 'self_attn.q_proj' in input_feat:
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.self_attn.q_proj,
                        module.self_attn.k_proj,
                        module.self_attn.v_proj,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=module.self_attn,
                    kwargs=module_kwargs,
                )
            )
        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if 'self_attn.o_proj' in input_feat:
            if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
                layers.append(
                    dict(
                        prev_op=module.self_attn.v_proj,
                        layers=[module.self_attn.o_proj],
                        inp=input_feat["self_attn.o_proj"],
                    )
                )
        # linear 1
        if 'mlp.gate_proj' in input_feat:
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[module.mlp.gate_proj, module.mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=module.mlp,
                )
            )
        # linear 2
        if 'mlp.down_proj' in input_feat:
            layers.append(
                dict(
                    prev_op=module.mlp.up_proj,
                    layers=[module.mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )
        return layers


quant_config = {
    "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM",
    "modules_to_not_convert": [
        'layers.31.mlp',  # keep the last decoder layer's FFN unquantized
    ],
}

with patch('awq.quantize.quantizer.exclude_layers_to_not_quantize', exclude_layers_to_not_quantize):
    model_path = "meta-llama/Meta-Llama-3-8B"
    # model_path = 'yujiepan/meta-llama-3-tiny-random'
    model = PatchedLlamaAWQForCausalLM.from_pretrained(model_path, model_type='llama', device_map='cuda')
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    module2fullname = {module: name for name, module in model.named_modules()}
    # calib_data defaults to the 'pileval' dataset
    model.quantize(tokenizer, quant_config=quant_config)
```
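The script above quantizes the model in memory but does not persist it; a typical way to save the result with AutoAWQ would be the following (the output directory is illustrative):

```python
# Illustrative output directory; run after model.quantize(...) above.
save_dir = "Meta-Llama-3-8B-awq-w4g64-v2"
model.save_quantized(save_dir)
tokenizer.save_pretrained(save_dir)
```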