Edit model card

yujiepan/Meta-Llama-3-8B-Instruct-awq-w4g64-v2

This model is meta-llama/Meta-Llama-3-8B-Instruct quantized with AutoAWQ.

  • 4-bit asymmetric (zero-point), weight-only quantization
  • group_size=64
  • the FFN (MLP) of the last decoder layer (`layers.31.mlp`) is left unquantized
  • calibration set: pileval

Accuracy

| model | precision | wikitext ppl (↓) |
|---|---|---|
| meta-llama/Meta-Llama-3-8B-Instruct | FP16 | 10.842 |
| yujiepan/Meta-Llama-3-8B-Instruct-awq-w4g64 | w4g64 | 10.943 |
| yujiepan/Meta-Llama-3-8B-Instruct-awq-w4g64-v2 | w4g64, skip last layer's FFN | 10.928 |

Note:

  • Evaluated with the lm-evaluation-harness `wikitext` task.
  • Wikitext perplexity does not guarantee downstream accuracy, but it is a useful check of how much quantization distorts the model.

Usage

import torch
from transformers import AutoModelForCausalLM

# Loading an AWQ checkpoint through transformers requires the `autoawq` package.
model = AutoModelForCausalLM.from_pretrained(
    'yujiepan/Meta-Llama-3-8B-Instruct-awq-w4g64-v2', torch_dtype=torch.float16,
)

Quantization code

from unittest.mock import patch

import torch

from awq import AutoAWQForCausalLM
from awq.models.llama import LlamaAWQForCausalLM
from transformers import AutoTokenizer

# Maps each submodule of the model being quantized to its fully qualified name
# (as yielded by ``named_modules()``). The quantization script fills it right
# before calling ``model.quantize(...)``.
module2fullname = {}


def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    """Drop the linear layers whose *full* module name matches a skip pattern.

    Patched in place of AutoAWQ's helper of the same name. The incoming layer
    names are presumably local to one decoder layer (e.g. ``"mlp.down_proj"``,
    matching the ``input_feat`` keys used elsewhere in this file), so they
    cannot single out one layer such as ``"layers.31.mlp"``. Each layer is
    therefore resolved to its full name via ``module2fullname`` before matching.

    Args:
        linear_layers: Mapping of local layer name -> linear module.
        modules_to_not_convert: Substring patterns; a layer is skipped when any
            of them occurs in its full name. ``None`` disables filtering.

    Returns:
        ``linear_layers`` itself when ``modules_to_not_convert`` is ``None``,
        otherwise a new dict holding only the layers that should be quantized.

    Raises:
        KeyError: if a layer was not registered in ``module2fullname``.
    """
    if modules_to_not_convert is None:
        return linear_layers
    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        try:
            full_name = module2fullname[linear_layer]
        except KeyError:
            # The raw KeyError would only show the module's repr; explain the
            # actual cause (the name map was never populated) instead.
            raise KeyError(
                f'{name!r} is not registered in module2fullname; populate it from '
                'model.named_modules() before calling quantize()'
            ) from None
        if any(key in full_name for key in modules_to_not_convert):
            print('Skipping', full_name)
        else:
            filtered_layers[name] = linear_layer
    return filtered_layers


class PatchedLlamaAWQForCausalLM(LlamaAWQForCausalLM):
    """Llama AWQ wrapper whose scaling groups tolerate unquantized layers.

    When some linear layers of a decoder layer are excluded from quantization
    (``modules_to_not_convert``), their activations are presumably not captured
    in ``input_feat``. Each scaling group below is therefore only emitted when
    the key it reads from ``input_feat`` is present.
    """

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        """Build the AWQ scaling groups for one Llama decoder layer.

        Args:
            module: The decoder layer; its ``self_attn``/``mlp`` linears and its
                layernorms are referenced directly.
            input_feat: Captured input activations keyed by layer names local
                to ``module`` (e.g. ``"mlp.down_proj"``).
            module_kwargs: Extra forward kwargs, attached to the attention-input
                group only.

        Returns:
            A list of dicts, one per scaling group, each with ``prev_op``,
            ``layers`` and ``inp``, plus ``module2inspect``/``kwargs`` where set.
        """
        layers = []

        # Attention input: q/k/v share input_layernorm as their prev_op.
        if "self_attn.q_proj" in input_feat:
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.self_attn.q_proj,
                        module.self_attn.k_proj,
                        module.self_attn.v_proj,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=module.self_attn,
                    kwargs=module_kwargs,
                )
            )

        # Attention output: scale o_proj against v_proj only when both weights
        # have the same shape (they differ e.g. under grouped-query attention).
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if (
            "self_attn.o_proj" in input_feat
            and module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape
        ):
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # MLP input (linear 1): gate/up share post_attention_layernorm as prev_op.
        if "mlp.gate_proj" in input_feat:
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[module.mlp.gate_proj, module.mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=module.mlp,
                )
            )

        # MLP output (linear 2): down_proj is scaled against up_proj.
        if "mlp.down_proj" in input_feat:
            layers.append(
                dict(
                    prev_op=module.mlp.up_proj,
                    layers=[module.mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )
        return layers


# AutoAWQ settings: 4-bit asymmetric (zero-point) weights in groups of 64,
# packed for the GEMM kernels.
quant_config = {
    "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM",
    # Substring patterns matched against full module names by the patched
    # exclude_layers_to_not_quantize: leave the last decoder layer's MLP unquantized.
    "modules_to_not_convert": [
        'layers.31.mlp',
    ],
}
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
# For a quick smoke test, swap in a tiny random model:
# model_path = 'yujiepan/meta-llama-3-tiny-random'
quant_path = "Meta-Llama-3-8B-Instruct-awq-w4g64-v2"

# Route AutoAWQ's layer filtering through the full-name-aware version above.
with patch('awq.quantize.quantizer.exclude_layers_to_not_quantize', exclude_layers_to_not_quantize):
    model = PatchedLlamaAWQForCausalLM.from_pretrained(model_path, model_type='llama', device_map='cuda')
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Fill the shared module-level dict in place, so any existing reference to
    # it sees the names.
    module2fullname.update({module: name for name, module in model.named_modules()})
    model.quantize(tokenizer, quant_config=quant_config)

# Persist the result; otherwise the quantized weights are lost when the script exits.
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
Downloads last month: 61

Safetensors
  • Model size: 2.2B params
  • Tensor type: I32 · FP16