---
library_name: transformers
tags: []
---

# yujiepan/Meta-Llama-3-8B-awq-w4g64-v2

This model was created by applying AutoAWQ to [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B).

- 4-bit asymmetric weight-only quantization
- group_size=64
- the FFN of the last decoder layer is kept unquantized
- calibration set: pileval

## Accuracy

| model | precision | wikitext ppl (↓) |
|---|---|---|
| meta-llama/Meta-Llama-3-8B | FP16 | 9.179 |
| yujiepan/Meta-Llama-3-8B-awq-w4g64 | w4g64 | 9.219 |
| yujiepan/Meta-Llama-3-8B-awq-w4g64-v2 | w4g64, skip last layer's FFN | 9.278 |

Notes:
- Evaluated with the lm-evaluation-harness "wikitext" task.
- Wikitext perplexity does not guarantee downstream accuracy, but it is a quick check of the distortion introduced by quantization.

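The card does not list the exact evaluation command. A minimal sketch using the lm-evaluation-harness Python API is shown below; it assumes lm-eval >= 0.4 (where `simple_evaluate` is exposed), and the result keys may differ between versions.

```python
# Hypothetical reproduction sketch, not taken from the original card.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",  # transformers backend
    model_args="pretrained=yujiepan/Meta-Llama-3-8B-awq-w4g64-v2,dtype=float16",
    tasks=["wikitext"],
)
print(results["results"]["wikitext"])
```
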
## Usage

```python
model = AutoModelForCausalLM.from_pretrained('yujiepan/Meta-Llama-3-8B-awq-w4g64-v2', torch_dtype=torch.float16)
```

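For an end-to-end check, a minimal sketch is below. It assumes a recent `transformers` release with `autoawq` installed; the prompt and generation settings are illustrative and not part of the original card.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "yujiepan/Meta-Llama-3-8B-awq-w4g64-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Illustrative prompt; any text works.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
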
## Quantization code

```python
from unittest.mock import patch

import torch

from awq import AutoAWQForCausalLM
from awq.models.llama import LlamaAWQForCausalLM
from transformers import AutoTokenizer

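# Maps each module object to its full dotted name; populated after the model
# is loaded so that the patched filter below can match on full names.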
module2fullname = {}


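# Replacement for awq.quantize.quantizer.exclude_layers_to_not_quantize: it looks up
# each linear layer's full dotted name via module2fullname, so layer-index-specific
# patterns such as 'layers.31.mlp' can be excluded from quantization.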
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
    if modules_to_not_convert is None:
        return linear_layers
    filtered_layers = {}
    for name, linear_layer in linear_layers.items():
        full_name = module2fullname[linear_layer]
        if not any(key in full_name for key in modules_to_not_convert):
            filtered_layers[name] = linear_layer
        else:
            print('Skipping', full_name)
    return filtered_layers


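# The scaling groups mirror AutoAWQ's Llama implementation, but each group is guarded
# by an input_feat membership check, so modules excluded from quantization (which have
# no calibration features) simply do not produce a scaling group.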
class PatchedLlamaAWQForCausalLM(LlamaAWQForCausalLM):
    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        print(input_feat.keys())
        layers = []
        # attention input
        if 'self_attn.q_proj' in input_feat:
            layers.append(
                dict(
                    prev_op=module.input_layernorm,
                    layers=[
                        module.self_attn.q_proj,
                        module.self_attn.k_proj,
                        module.self_attn.v_proj,
                    ],
                    inp=input_feat["self_attn.q_proj"],
                    module2inspect=module.self_attn,
                    kwargs=module_kwargs,
                )
            )
        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if 'self_attn.o_proj' in input_feat:
            if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
                layers.append(
                    dict(
                        prev_op=module.self_attn.v_proj,
                        layers=[module.self_attn.o_proj],
                        inp=input_feat["self_attn.o_proj"],
                    )
                )

        if 'mlp.gate_proj' in input_feat:
            # linear 1
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[module.mlp.gate_proj, module.mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=module.mlp,
                )
            )

        if 'mlp.down_proj' in input_feat:
            # linear 2
            layers.append(
                dict(
                    prev_op=module.mlp.up_proj,
                    layers=[module.mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )
        return layers


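# w4, group size 64, asymmetric (zero_point=True). 'layers.31.mlp' excludes the FFN
# of the last decoder layer (Llama-3-8B has 32 decoder layers, indexed 0..31).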
quant_config = {
    "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM",
    "modules_to_not_convert": [
        'layers.31.mlp',
    ],
}
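
# Swap the full-name-aware filter into AutoAWQ's quantizer while quantization runs.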
with patch('awq.quantize.quantizer.exclude_layers_to_not_quantize', exclude_layers_to_not_quantize):
    model_path = "meta-llama/Meta-Llama-3-8B"
    # model_path = 'yujiepan/meta-llama-3-tiny-random'
    model = PatchedLlamaAWQForCausalLM.from_pretrained(model_path, model_type='llama', device_map='cuda')
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    module2fullname = {module: name for name, module in model.named_modules()}
    model.quantize(tokenizer, quant_config=quant_config)
```
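
The export step is not shown above. With AutoAWQ, the quantized checkpoint is typically written out with `save_quantized`; the output path below is illustrative.

```python
# Illustrative export step (not part of the original script); the path is arbitrary.
quant_path = "Meta-Llama-3-8B-awq-w4g64-v2"
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```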