chhao AlexGeek committed
Commit 1bfdcc4 · 1 Parent(s): 8d84170

Create README.md (#1)

- Create README.md (9ef08e8527ad12803487816436ffb4adba5e6e40)

Co-authored-by: Gongxun Li <AlexGeek@users.noreply.huggingface.co>

Files changed (1)
  1. README.md +353 -0

README.md ADDED

---
library_name: transformers
license: mit
pipeline_tag: text-generation
tags:
- weak-driven-learning
- post-training
- mathematical-reasoning
- code-generation
- qwen3
arxiv: 2602.08222
---

# Weak-Driven Learning

## Highlights

**Weak-Driven Learning** introduces a post-training paradigm that challenges the conventional assumption that learning with weaker models necessarily degrades performance. Key features include:

- **Novel Learning Paradigm**: Leverages weak agents (historical model checkpoints) as informative error signals that drive continued improvement beyond standard supervision saturation.
- **No Additional Inference Cost**: The enhanced model keeps the same architecture as the base model, adding no computational overhead at inference time.
- **Consistent Performance Gains**: Improves over standard SFT baselines on challenging benchmarks, including mathematical reasoning and code generation.
- **Practical Training Framework**: Jointly optimizes weak and strong models through logit mixing, preventing gradient vanishing and maintaining effective learning pressure.

<p align="center">
  <img src="https://raw.githubusercontent.com/chenzehao82/Weak-Driven-Learning/main/pics/weak-drivenlearning.png" alt="Weak-Driven Learning Framework" width="800"/>
</p>

## Model Overview

This repository contains models trained with the **Weak-Driven Learning** framework, which operationalizes the principle that **weak agents can make strong agents stronger** (WMSS). Unlike knowledge distillation, which requires access to a stronger teacher, weak-driven learning relies on easily obtainable weak reference models such as historical checkpoints.

### Key Contributions

- **Learning Paradigm**: Introduces a post-training approach that highlights the overlooked role of weak agents as driving signals for improving strong agents.
- **Training Framework**: Proposes joint optimization through logit mixing that compels the strong model to refine its decision boundary and sustain meaningful gradients in saturated regimes.
- **Theoretical Foundation**: Provides gradient-level analysis demonstrating how incorporating weak-model logits reshapes the optimization landscape and prevents gradient vanishing.
- **Empirical Validation**: Shows consistent improvements on mathematical reasoning and code generation benchmarks.

### Training Methodology

The framework consists of three phases:

1. **Phase 1: Initialization**
   - Prepare the base model and compute initial entropy on the training data
   - The base model serves as the "weak agent" in subsequent training

2. **Phase 2: Curriculum Learning with Entropy-Weighted Sampling**
   - Train the first-stage model using entropy-based weighted sampling (BrownBoost-style)
   - Focus on challenging samples where entropy differences are significant
   - This model becomes the "strong agent" for joint training

3. **Phase 3: Joint Training**
   - Jointly train the weak and strong models through logit mixing (see the sketch below)
   - The mixing mechanism prevents gradient vanishing on non-target tokens
   - Extract the enhanced sub-model with improved capabilities
   - **No additional inference cost**: the extracted model has the same architecture as the base model
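
To make the logit-mixing step concrete, the following is a minimal sketch of what such an objective can look like. It is our own illustration under stated assumptions (hypothetical `strong_logits`/`weak_logits` tensors, a token-level cross-entropy loss, and the mixing coefficient λ from the hyperparameters listed below); the repository's actual implementation lives in `utils/fuse_models.py` and `ensemble/ensemble_train.py`.

```python
import torch.nn.functional as F

def logit_mixing_loss(strong_logits, weak_logits, labels, lam=0.5):
    """Illustrative logit-mixing objective (a sketch, not the repository's exact code).

    strong_logits, weak_logits: [batch, seq_len, vocab] logits from the strong model
        and the weak reference model (e.g. a historical checkpoint) on the same batch.
    labels: [batch, seq_len] target token ids, with -100 marking ignored positions.
    lam: mixing coefficient (lambda = 0.5 in the reported hyperparameters).
    """
    # Mix at the logit level. Because the weak model spreads probability mass over
    # non-target tokens, the mixed prediction stays imperfect even where the strong
    # model is already confident, so the cross-entropy gradient does not vanish and
    # the strong model is pushed to keep refining its decision boundary.
    mixed_logits = lam * strong_logits + (1.0 - lam) * weak_logits
    return F.cross_entropy(
        mixed_logits.view(-1, mixed_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
```

During Phase 3 both models are optimized jointly, as described above; after training, only the enhanced sub-model is extracted and served.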

<p align="center">
  <img src="https://raw.githubusercontent.com/chenzehao82/Weak-Driven-Learning/main/pics/framework.png" alt="Training Framework" width="1000"/>
</p>

### Model Specifications

This model is trained using the Weak-Driven Learning framework with the following specifications:

- **Base Model**: [Qwen3-4B-Base](https://huggingface.co/Qwen/Qwen3-4B-Base)
- **Type**: Causal Language Model
- **Number of Parameters**: 4.0B total (3.6B non-embedding)
- **Architecture**: Qwen3 (Transformer-based)
- **Number of Layers**: 36
- **Attention Heads**: 32 for Q, 8 for KV (Grouped Query Attention)
- **Context Length**: 32,768 tokens (trained with a maximum sequence length of 8,192)
- **Training Data**: AM-1.4M dataset (AM-DeepSeek-R1-Distilled, filtered and processed)
- **Training Hardware**: 8× NVIDIA H800 GPUs
- **Training Framework**: TRL + Hugging Face Transformers + DeepSpeed

**Training Hyperparameters**:
- Learning rate: 1×10⁻⁵
- Maximum sequence length: 8,192
- Weak-Driven Learning parameters: α=0.1, β=0.8, γ=0.1
- Logit mixing coefficient: λ=0.5

**Key Dependencies**:
- `transformers>=4.57.1`
- `trl>=0.25.1`
- `torch>=2.8.0`
- `vllm>=0.11.0` (for inference)

## Model Variants

We provide models trained with Weak-Driven Learning on different base models:

| Model | Base Model | Parameters | Context Length | Recommended Use |
|-------|-----------|------------|----------------|-----------------|
| **Weak-Driven-Learning-4B** | Qwen3-4B-Base | 4.0B | 32K | Mathematical reasoning, code generation, resource-constrained environments |
| **Weak-Driven-Learning-8B** | Qwen3-8B-Base | 8.0B | 32K | Complex reasoning tasks, advanced code generation |

All models are trained using the same three-phase Weak-Driven Learning framework with identical hyperparameters.

## Hardware Requirements

### Inference

| Model Size | Minimum VRAM | Recommended VRAM | Precision |
|------------|--------------|------------------|-----------|
| 4B | 8GB | 16GB | FP16/BF16 |
| 8B | 16GB | 24GB | FP16/BF16 |

For longer context lengths (>8K tokens), additional memory may be required.
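
As a rough sanity check on these figures (our own back-of-the-envelope estimate, not an official measurement), the weights alone dominate the minimum: in FP16/BF16 each parameter takes 2 bytes, so a 4B-parameter model needs roughly 7.5 GiB just for weights, before the KV cache and activations.

```python
# Back-of-the-envelope VRAM estimate (illustrative only).
params = 4.0e9            # 4B parameters
bytes_per_param = 2       # FP16 / BF16
weights_gib = params * bytes_per_param / 1024**3
print(f"Weights alone: ~{weights_gib:.1f} GiB")  # ~7.5 GiB; KV cache and activations come on top
```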

### Training

- **Recommended**: 8× NVIDIA H800 (80GB) or A100 (80GB) GPUs
- **Minimum**: 4× NVIDIA A100 (40GB) GPUs with gradient accumulation
- DeepSpeed ZeRO-3 optimization recommended for memory efficiency

## Quickstart

### Installation

```bash
# Clone the repository
git clone https://github.com/chenzehao82/Weak-Driven-Learning.git
cd Weak-Driven-Learning

# Install dependencies
pip install -r requirements.txt
```

### Inference Example

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "chhao/Weak-Driven-Learning"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

# Prepare the model input
prompt = "Solve the following math problem: If x + 2 = 5, what is x?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate the response (enable sampling so temperature/top_p/top_k take effect)
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=2048,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=40
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

response = tokenizer.decode(output_ids, skip_special_tokens=True)
print("Response:", response)
```

## Deployment

### Using vLLM

[vLLM](https://github.com/vllm-project/vllm) provides high-throughput and memory-efficient inference for LLMs.

```bash
# Install vLLM
pip install 'vllm>=0.11.0'

# Launch an OpenAI-compatible API server
vllm serve chhao/Weak-Driven-Learning --port 8000 --tensor-parallel-size 2
```

### Using the API

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY"
)

messages = [{"role": "user", "content": "Solve: 2x + 3 = 11"}]

completion = client.chat.completions.create(
    messages=messages,
    model="chhao/Weak-Driven-Learning",
    max_tokens=2048,
    temperature=1.0,
    top_p=0.95
)

print(completion.choices[0].message.content)
```

## Training Your Own Model

To train your own model using the Weak-Driven Learning framework:

### 1. Prepare Training Data

```bash
cd dataprocess
python am_deepseek_r1_distilled.py
```

This generates two files (a quick way to inspect them is sketched below):
- `am_deepseek_r1_filtered_ad.jsonl`: main training data
- `am_deepseek_r1_filtered_ad_test_1000.jsonl`: test subset
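
Before launching training, it can help to sanity-check the generated files. The snippet below is a generic check (our own, not part of the repository); it only counts records and prints the first record's keys, since the exact field layout is determined by `am_deepseek_r1_distilled.py`.

```python
import json

path = "am_deepseek_r1_filtered_ad.jsonl"

count, first = 0, None
with open(path, encoding="utf-8") as f:
    for line in f:
        if line.strip():
            count += 1
            if first is None:
                # Peek at the first record's keys without assuming a schema.
                first = json.loads(line)

print(f"{count} records in {path}")
print("first record keys:", sorted(first) if first else None)
```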

### 2. Configure Training Parameters

Edit `scripts/run_ensemble.sh`:
- `GPU_USE`: GPU device IDs
- `base_model`: Base model path (e.g., `Qwen/Qwen3-4B-Base` or `Qwen/Qwen3-8B-Base`)
- `outdir`: Output directory for checkpoints
- Training hyperparameters (learning rate: 1×10⁻⁵, max sequence length: 8,192, etc.)
- Weak-Driven Learning parameters (α, β, γ, λ)

### 3. Run the Complete Pipeline

```bash
cd Weak-Driven-Learning
bash scripts/run_ensemble.sh
```

The script automatically executes the three-phase training pipeline:
1. Initialize the base model and compute initial entropy
2. Train the first-stage model with curriculum learning (entropy-weighted sampling; see the sketch below)
3. Jointly train the weak and strong models, then extract the enhanced sub-model
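
For intuition on step 2, here is a minimal sketch of entropy-weighted sampling under our own simplifying assumptions: per-sample difficulty is measured as mean token-level predictive entropy, and harder (higher-entropy) samples are drawn more often. The repository's actual BrownBoost-style weighting (including how α, β, and γ enter) lives in `utils/compute_entropy.py` and `utils/weight_datasets.py` and is not reproduced here.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

@torch.no_grad()
def mean_token_entropy(model, input_ids, attention_mask):
    """Mean predictive entropy per sequence (illustrative difficulty score)."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    log_probs = F.log_softmax(logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)   # [batch, seq]
    mask = attention_mask.float()
    return (token_entropy * mask).sum(dim=-1) / mask.sum(dim=-1) # [batch]

def entropy_weighted_sampler(entropies, temperature=1.0):
    """Upweight high-entropy (harder) samples; a simplified stand-in for the
    BrownBoost-style scheme described in the paper."""
    weights = torch.softmax(torch.as_tensor(entropies) / temperature, dim=0)
    return WeightedRandomSampler(weights, num_samples=len(entropies), replacement=True)
```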

## Evaluation Results

Our method consistently improves performance on challenging benchmarks compared to standard SFT baselines. These gains arise purely from improved optimization dynamics during training and incur **no additional inference cost**.

<p align="center">
  <img src="https://raw.githubusercontent.com/chenzehao82/Weak-Driven-Learning/main/pics/results.png" alt="Evaluation Results" width="600"/>
</p>

## Best Practices

### Inference Parameters

For optimal performance, we recommend the following sampling parameters (see the offline vLLM example below):
- `temperature=1.0`
- `top_p=0.95`
- `top_k=40`
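
As a usage example with these settings, here is a minimal offline-inference snippet using the vLLM Python API (`LLM.chat` applies the model's chat template); this is our own illustration rather than a script from the repository.

```python
from vllm import LLM, SamplingParams

# Load the model and configure the recommended sampling parameters.
llm = LLM(model="chhao/Weak-Driven-Learning")
sampling = SamplingParams(temperature=1.0, top_p=0.95, top_k=40, max_tokens=2048)

messages = [{"role": "user", "content": "Solve the following math problem: If x + 2 = 5, what is x?"}]
outputs = llm.chat(messages, sampling)
print(outputs[0].outputs[0].text)
```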

### Task-Specific Recommendations

**Mathematical Reasoning**:
- Use the model's chat template for structured input
- Allow sufficient `max_new_tokens` (2048-4096) for detailed reasoning chains
- The model benefits from step-by-step problem decomposition

**Code Generation**:
- Provide clear problem specifications and constraints
- Use appropriate context length for complex codebases
- The model can handle multi-file code generation tasks

## Limitations

While Weak-Driven Learning demonstrates consistent improvements, users should be aware of:

- **Training Data Dependency**: Performance is influenced by the quality and diversity of the AM-1.4M training dataset
- **Domain Specificity**: The model is optimized for mathematical reasoning and code generation; performance on other tasks may vary
- **Computational Requirements**: Training requires significant GPU resources (8× H800 GPUs recommended)
- **Base Model Constraints**: Inherits limitations from the base Qwen3 model architecture

## Project Structure

```
Weak-Driven-Learning/
├── scripts/                       # Training pipeline scripts
│   └── run_ensemble.sh            # Complete three-phase training pipeline
├── ensemble/                      # Core training and evaluation
│   ├── ensemble_train.py          # Joint training implementation
│   ├── run_entropy.py             # Entropy computation
│   ├── extract_submodel.py        # Extract enhanced sub-model
│   └── eval_vllm_thinking_math.py # Evaluation script
├── utils/                         # Model fusion, entropy, and data processing
│   ├── fuse_models.py             # Logit mixing and model fusion (WMSS)
│   ├── compute_entropy.py         # Entropy computation algorithms
│   └── weight_datasets.py         # Entropy-based weighted sampling
├── EnsembleQwen3/                 # Qwen3 ensemble model definitions
│   ├── configuration_qwen3.py     # Model configuration
│   └── modeling_qwen3.py          # Model architecture with logit mixing
└── dataprocess/                   # Data processing scripts
```

## Citation

If you find our work helpful, please cite our paper:

```bibtex
@misc{chen2026weakdrivenlearningweakagents,
  title={Weak-Driven Learning: How Weak Agents make Strong Agents Stronger},
  author={Zehao Chen and Gongxun Li and Tianxiang Ai and Yifei Li and Zixuan Huang and Wang Zhou and Fuzhen Zhuang and Xianglong Liu and Jianxin Li and Deqing Wang and Yikun Ban},
  year={2026},
  eprint={2602.08222},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2602.08222}
}
```

## Links

- **Paper**: [arXiv:2602.08222](https://arxiv.org/abs/2602.08222)
- **Hugging Face Paper Page**: [Weak-Driven Learning](https://huggingface.co/papers/2602.08222)
- **GitHub Repository**: [Weak-Driven-Learning](https://github.com/chenzehao82/Weak-Driven-Learning)
- **Model Weights**: [chhao/Weak-Driven-Learning](https://huggingface.co/chhao/Weak-Driven-Learning)

## Frequently Asked Questions

**Q: What makes Weak-Driven Learning different from knowledge distillation?**

A: Knowledge distillation requires a stronger teacher model; Weak-Driven Learning instead uses weaker models (such as historical checkpoints) as reference points. By explicitly identifying and moving away from the weak model's failure modes, the strong model continues to improve beyond standard supervision saturation.

**Q: Does the model have additional inference overhead?**

A: No. After training, we extract the enhanced sub-model, which has the same architecture as the base model, so there is zero additional inference cost compared to standard fine-tuned models.

**Q: Can I use this framework with other base models?**

A: Yes. The Weak-Driven Learning framework is model-agnostic. While we provide an implementation for Qwen3, the methodology can be adapted to other transformer-based architectures; see the [GitHub repository](https://github.com/chenzehao82/Weak-Driven-Learning) for implementation details.

**Q: What is the AM-1.4M dataset?**

A: AM-1.4M is a high-quality dataset derived from AM-DeepSeek-R1-Distilled, containing 1.4 million samples focused on mathematical reasoning and problem solving. The dataset is filtered and processed to ensure quality and diversity.

## Acknowledgments

- Model architecture based on [Qwen models](https://github.com/QwenLM/Qwen3)
- Training framework built on [TRL](https://github.com/huggingface/trl) and [Hugging Face Transformers](https://github.com/huggingface/transformers)
- Training data derived from the [AM-DeepSeek-R1-Distilled dataset](https://huggingface.co/datasets/AM-DeepSeek-R1-Distilled)

## License

This project is licensed under the MIT License; see the [LICENSE](https://github.com/chenzehao82/Weak-Driven-Learning/blob/main/LICENSE) file for details.