---
base_model: meta-llama/Llama-3.2-1B-Instruct
datasets:
- JunxiongWang/sftdatasetv3
model-index:
- name: Zebra-Llama-1B-8MLA-8Mamba-SFT
  results: []
tags:
- alignment-handbook
- generated_from_trainer
license: apache-2.0
---

# Zebra-Llama: Towards Extremely Efficient Hybrid Models

Zebra-Llama is a family of hybrid large language models (LLMs) proposed by AMD that composes Multi-head Latent Attention (MLA) and Mamba2 layers for KV cache compression and computational efficiency.
This combination achieves Transformer-level accuracy with near-State Space Model (SSM) efficiency. While standard Transformers are limited by the quadratic complexity of self-attention and the large memory footprint of their key-value (KV) cache, Zebra-Llama offers a practical and scalable solution.

This model, `Zebra-Llama-1B-8MLA-8Mamba-SFT`, was created by efficiently adapting the pre-trained `Llama-3.2-1B-Instruct` model through post-training conducted on AMD Instinct&trade; MI300X GPUs. This approach bypasses the need for costly pre-training from scratch.

<div align="center">
<img src="comparison.png" width="570" height="380" style="object-fit: contain;"/>
<em><b>Figure 1:</b> Comparing 8B-scale models on average LM Harness score vs. KV cache size. Zebra-Llama (green) matches or exceeds baselines with smaller KV cache and fewer training tokens. Circle and square sizes indicate training tokens (billions for post-training, trillions for pre-training).</em>
</div>

## Key Takeaways
- Announcing Zebra-Llama, a family of highly efficient 1B, 3B, and 8B hybrid models created by post-training adaptation of existing state-of-the-art Transformers.
- Extreme KV Cache Compression: Zebra-Llama dramatically reduces the KV cache size to 2%-4% of the original Llama model while preserving 100% of its average zero-shot performance on LM Harness tasks.
- Efficient Hybrid Architecture: Zebra-Llama strategically combines Multi-head Latent Attention (MLA) layers, which compress the KV cache, and Mamba2 (SSM) layers, which eliminate the KV cache entirely, to balance memory usage and performance.
- Novel Post-Training Pipeline: Zebra-Llama employs an efficient post-training pipeline featuring refined weight initialization, Intermediate Layer Distillation (ILD) for knowledge transfer, and a sensitivity-aware strategy (SMART) for optimal hybrid composition.

## Model Composition Pipeline

The Zebra-Llama models are not trained from scratch. Instead, they are composed from powerful pre-trained Transformers through a lightweight and efficient pipeline. The creation of this model followed these stages:

| Stage | Action | Description |
|-------|--------|-------------|
| 1. Base Model | Llama-3.2-1B-Instruct | The starting point is a high-quality, pre-trained Transformer model. |
| 2. Initialization | Structured Weight Mapping | Pure Mamba2 and MLA models are initialized from the base model's weights using structured mapping techniques (SVD for MLA, reinterpretation for Mamba2). |
| 3. Refinement | Intermediate Layer Distillation (ILD) | The internal representations of the Mamba2 and MLA models are aligned with the base model's layers on a small dataset to ensure a strong starting point. |
| 4. Composition | SMART Layer Selection | A hybrid architecture is composed using the SMART (Sensitivity Measure-Aware Replacement of Transformer layers) strategy to optimally place each layer type. |
| 5. SFT | End-to-End Knowledge Distillation | The composed hybrid model is fine-tuned via knowledge distillation, using an 8B model as a teacher to transfer rich, pre-trained knowledge. |
| 6. Alignment | Direct Preference Optimization (DPO) | In the final stage, DPO is used to align the model's preferences, with the distilled student model itself serving as the reference model for stability. |
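
The two distillation stages above can be pictured with a minimal PyTorch sketch: ILD (stage 3) aligns hidden states with the frozen base model, and the SFT stage (stage 5) distills the teacher's token distribution. The exact objectives, layer matching, and loss weighting used for Zebra-Llama are defined in the paper and the AMD-Hybrid-Models repository; the functions below are illustrative stand-ins, not the project's implementation.

```python
import torch
import torch.nn.functional as F

def ild_loss(student_hidden: torch.Tensor, teacher_hidden: torch.Tensor) -> torch.Tensor:
    # Stage 3 (illustrative): align the output of a freshly initialized MLA/Mamba2
    # layer with the output of the matching frozen Llama layer, here via plain MSE.
    return F.mse_loss(student_hidden, teacher_hidden)

def logit_kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor,
                  temperature: float = 1.0) -> torch.Tensor:
    # Stage 5 (illustrative): end-to-end distillation as a temperature-scaled KL
    # divergence between the teacher's and the hybrid student's token distributions.
    # Both logit tensors are expected flattened to (num_tokens, vocab_size).
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.log_softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean", log_target=True) * temperature ** 2
```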

## Getting Started

### Installation

```bash
git clone https://github.com/AMD-AIG-AIMA/AMD-Hybrid-Models.git
```
Then follow the installation instructions in the `AMD-AIG-AIMA/AMD-Hybrid-Models` repository.

### Example Usage
Once the installation is complete, you can run the following code for a quick test:
```python
import torch
from transformers import AutoTokenizer
from hybrid_model.hybrid_model_wrapper import HybridTransformerHybridModelWrapper

checkpoint = "amd/Zebra-Llama-1B-8MLA-8Mamba-SFT"

# Load the hybrid checkpoint and its tokenizer
model = HybridTransformerHybridModelWrapper.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.eval()

# Format the prompt using the chat template
prompt = [{"role": "user", "content": "What are the benefits of hybrid language models?"}]
input_ids = tokenizer.apply_chat_template(
    prompt,
    add_generation_prompt=True,
    return_tensors='pt'
).cuda()

# Generate a response
tokens = model.generate(
    input_ids,
    max_new_tokens=256,
    temperature=0.7,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id
)

print(tokenizer.decode(tokens[0], skip_special_tokens=False))
```

### Model details

| Model | KV Size | Params | Index of MLA layers | r<sub>kv</sub> | r<sub>q</sub> | d<sub>rope</sub> | d<sub>nope</sub> |
|-------|--------:|-------:|--------------------:|------:|------:|---------:|---------:|
| Llama-3.2-1B-Instruct | 100% | 1.24B | - | - | - | - | - |
| Zebra-Llama-1B-8MLA-8M2 | 7.81% | 1.27B | [0,2,4,6,8,10,12,14] | 128 | 1344 | 32 | 32 |
| Zebra-Llama-1B-4MLA-12M2 | 3.91% | 1.28B | [0,5,10,14] | 128 | 1344 | 32 | 32 |
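
As a rough sanity check on the KV Size column, the per-token cache of the hybrid models can be compared against the Llama baseline. The sketch below assumes the public Llama-3.2-1B attention configuration (16 decoder layers, 8 KV heads of dimension 64) and counts cached values per token: each MLA layer stores a compressed latent of size r<sub>kv</sub> plus a decoupled RoPE key of size d<sub>rope</sub>, while Mamba2 layers keep only a constant-size state. This is illustrative arithmetic, not code from the repository.

```python
# Back-of-the-envelope check of the "KV Size" column (assumed Llama-3.2-1B config:
# 16 decoder layers, 8 KV heads x head dim 64 -> 2 * 8 * 64 = 1024 values/token/layer).
NUM_LAYERS = 16
BASELINE_PER_LAYER = 2 * 8 * 64  # cached keys + values for one GQA attention layer

def kv_cache_ratio(num_mla_layers: int, r_kv: int = 128, d_rope: int = 32) -> float:
    """Per-token KV cache of a hybrid model relative to the Llama baseline."""
    hybrid = num_mla_layers * (r_kv + d_rope)   # Mamba2 layers contribute nothing
    return hybrid / (NUM_LAYERS * BASELINE_PER_LAYER)

print(f"Zebra-Llama-1B-8MLA-8M2:  {kv_cache_ratio(8):.2%}")   # ~7.81%
print(f"Zebra-Llama-1B-4MLA-12M2: {kv_cache_ratio(4):.2%}")   # ~3.91%
```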

### Benchmark results
Zebra-Llama was evaluated on zero-shot tasks from the LM Evaluation Harness and compared against its base model and other post-training methods. The results demonstrate that Zebra-Llama provides a superior balance of performance and efficiency.

| Tasks | Metric | Llama-3.2-1B-Instruct | Zebra-Llama-1B-4MLA-12M2-SFT | Zebra-Llama-1B-4MLA-12M2-DPO | Zebra-Llama-1B-8MLA-8M2-SFT | Zebra-Llama-1B-8MLA-8M2-DPO |
|-------|--------|----------------------:|-----------------------------:|-----------------------------:|----------------------------:|----------------------------:|
| arc_challenge | acc | 0.3575 (±0.0140) | 0.3507 (±0.0139) | 0.3976 (±0.0143) | 0.3456 (±0.0139) | 0.3951 (±0.0143) |
| | acc_norm | 0.3797 (±0.0142) | 0.3908 (±0.0143) | 0.4232 (±0.0144) | 0.3797 (±0.0142) | 0.4249 (±0.0144) |
| arc_easy | acc | 0.6843 (±0.0095) | 0.7054 (±0.0094) | 0.7226 (±0.0092) | 0.7092 (±0.0093) | 0.7239 (±0.0092) |
| | acc_norm | 0.6351 (±0.0099) | 0.6536 (±0.0098) | 0.6696 (±0.0097) | 0.6641 (±0.0097) | 0.6726 (±0.0096) |
| hellaswag | acc | 0.4506 (±0.005) | 0.4272 (±0.0049) | 0.4399 (±0.005) | 0.4366 (±0.0049) | 0.4527 (±0.0050) |
| | acc_norm | 0.6077 (±0.0049) | 0.5691 (±0.0049) | 0.5893 (±0.0049) | 0.5816 (±0.0049) | 0.6061 (±0.0049) |
| mmlu | acc | 0.4609 (±0.0918) | 0.3739 (±0.0736) | 0.3791 (±0.0742) | 0.3940 (±0.0779) | 0.3909 (±0.0756) |
| - humanities | acc | 0.4397 (±0.0763) | 0.3456 (±0.0583) | 0.3443 (±0.0634) | 0.3694 (±0.0709) | 0.3700 (±0.0684) |
| - other | acc | 0.5204 (±0.0868) | 0.4184 (±0.0746) | 0.4081 (±0.0707) | 0.4300 (±0.0747) | 0.4258 (±0.0737) |
| - social_sciences | acc | 0.5109 (±0.0843) | 0.4098 (±0.0758) | 0.4303 (±0.0709) | 0.4348 (±0.0749) | 0.4283 (±0.0727) |
| - stem | acc | 0.3850 (±0.09) | 0.3375 (±0.0730) | 0.3527 (±0.077) | 0.3555 (±0.0776) | 0.3511 (±0.0746) |
| openbookqa | acc | 0.244 (±0.0192) | 0.2800 (±0.0201) | 0.302 (±0.0206) | 0.2480 (±0.0193) | 0.3000 (±0.0205) |
| | acc_norm | 0.35 (±0.0214) | 0.3700 (±0.0216) | 0.406 (±0.022) | 0.3800 (±0.0217) | 0.4180 (±0.0221) |
| piqa | acc | 0.7405 (±0.0102) | 0.7214 (±0.0105) | 0.7252 (±0.0104) | 0.7252 (±0.0104) | 0.7280 (±0.0104) |
| | acc_norm | 0.7437 (±0.0102) | 0.7225 (±0.0104) | 0.7296 (±0.0104) | 0.7269 (±0.0104) | 0.7296 (±0.0104) |
| pubmedqa | acc | 0.602 (±0.0219) | 0.5760 (±0.0221) | 0.566 (±0.0222) | 0.5940 (±0.0220) | 0.5860 (±0.0220) |
| race | acc | 0.3809 (±0.015) | 0.3445 (±0.0147) | 0.377 (±0.015) | 0.3694 (±0.0149) | 0.3866 (±0.0151) |
| winogrande | acc | 0.5967 (±0.0138) | 0.5785 (±0.0139) | 0.5888 (±0.0138) | 0.6125 (±0.0137) | 0.6133 (±0.0137) |
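
For reference, a minimal reproduction sketch using the LM Evaluation Harness (v0.4+ Python API) is shown below. It assumes `model` and `tokenizer` are the objects loaded in the usage example above and that the hybrid wrapper can be handed to the harness's Hugging Face interface as a preloaded model; the custom architecture may require additional adapter code, so treat this as a starting point rather than the exact evaluation setup used here.

```python
# Hypothetical zero-shot evaluation sketch with lm-evaluation-harness (v0.4+).
# Assumes `model` and `tokenizer` from the usage example; wrapping a custom hybrid
# model this way may need extra glue code for the harness.
from lm_eval import simple_evaluate
from lm_eval.models.huggingface import HFLM

lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
results = simple_evaluate(
    model=lm,
    tasks=["arc_challenge", "arc_easy", "hellaswag", "mmlu", "openbookqa",
           "piqa", "pubmedqa", "race", "winogrande"],
    num_fewshot=0,  # zero-shot, as reported in the table above
)
print(results["results"])
```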

## Conclusion
Zebra-Llama demonstrates a practical and scalable framework for composing highly efficient hybrid models from existing pre-trained Transformers. By intelligently combining MLA and Mamba2 layers, this approach drastically reduces memory requirements and improves inference throughput while preserving the strong capabilities of the original model. This work highlights the viability of post-training hybridization as a cost-effective and environmentally sustainable alternative to full retraining, paving the way for the deployment of powerful LLMs in resource-constrained environments.

## Bias, Risks, and Limitations
- This model is a research artifact and has not been evaluated for safety in production use cases.
- The model's performance depends on the quality of its pre-trained base model and the teacher model used during distillation; its capabilities and biases are inherited from these sources.
- The model may generate content that is factually inaccurate, biased, or otherwise objectionable. Users should be aware of these risks and implement appropriate safeguards for their applications.
- One limitation of this work is the reliance on a strong teacher model for knowledge transfer, which may not always be available. Distillation from a teacher also adds to the resource requirements during the post-training phase.

## Citation
If you find this model useful, please consider citing the following papers:
```
@article{yang2025zebra,
  title={Zebra-Llama: Towards Extremely Efficient Hybrid Models},
  author={Yang, Mingyu and Rezagholizadeh, Mehdi and Li, Guihong and Appia, Vikram and Barsoum, Emad},
  journal={arXiv preprint arXiv:2505.17272},
  year={2025}
}
@article{li2025x,
  title={X-EcoMLA: Upcycling Pre-Trained Attention into MLA for Efficient and Extreme KV Compression},
  author={Li, Guihong and Rezagholizadeh, Mehdi and Yang, Mingyu and Appia, Vikram and Barsoum, Emad},
  journal={arXiv preprint arXiv:2503.11132},
  year={2025}
}
```