# GRPO
Paper Links
[DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300)
[DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://arxiv.org/abs/2501.12948)
Environments
```bash
pip install math_verify # reward function
pip install -U trl
```
**Dev Log**
- **2025-05-11** — Implemented support for the **Generative Reward Model** and enabled customized reward model processing logic through the reward plugin. For more details, refer to the [Customized Reward Models](#customized-reward-models) section.
- **2025-04-30** — The startup command for the external vLLM server has been changed to `swift rollout`.
**FAQ**
1. It is normal for the loss to approach zero during training. Refer to this [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851) for more details.
2. How to calculate the training steps? Refer to this [issue](https://github.com/modelscope/ms-swift/issues/3912) for more details.
3. Why is the clip_ratio always 1? Refer to this [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851) for more details.
## Cluster Support

The GRPO training framework supports the integration of high-performance inference engines (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:
### 1. Internal Integration Mode
- Launch the inference service directly within the Trainer.
- Provides two resource allocation strategies:
  - **Colocate Mode**: Training and inference share GPU resources.
  - **Async Mode**: Training and inference use separate GPU resources.
### GRPO Training Resource Allocation Scheme
| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
|-------------------------|----------------|-------------------|---------------------------------------|
| **Colocate** | = Total GPUs | = Total GPUs | Training and inference share all GPU resources. |
| **Async** | = Training GPUs| = Inference GPUs | Must satisfy: Training GPUs + Inference GPUs = Total GPUs. |
**Note:**
1. In Colocate mode, it is recommended to set `sleep_level=1` to release the GPU memory occupied by vLLM during model training.
2. Total GPUs refers to the total number of visible GPU devices.
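The constraints in the table can be checked before launching a job. The following is a minimal sketch; `check_allocation` is a hypothetical helper written for illustration and is not part of ms-swift.

```python
def check_allocation(mode, total_gpus, nproc_per_node, num_infer_workers):
    """Return True if the settings satisfy the table above (hypothetical helper)."""
    if mode == "colocate":
        # Training and inference both use every visible GPU.
        return nproc_per_node == total_gpus and num_infer_workers == total_gpus
    if mode == "async":
        # Training and inference GPUs are disjoint and together use every GPU.
        return (nproc_per_node > 0 and num_infer_workers > 0
                and nproc_per_node + num_infer_workers == total_gpus)
    raise ValueError(f"unknown mode: {mode!r}")


colocate_ok = check_allocation("colocate", 8, 8, 8)
async_ok = check_allocation("async", 8, 7, 1)
async_bad = check_allocation("async", 8, 8, 1)
```

With 8 visible GPUs, 8/8 is valid for colocate and 7/1 is valid for async, while 8/1 in async mode would oversubscribe the devices.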
### 2. External Service Mode
Connect to an external vLLM inference server.
When using this mode, configure the external vLLM server with the following parameters:
```bash
--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout> \
```
Deploy the vLLM server using the `swift rollout` command. Currently, only the vLLM backend is supported.
```bash
CUDA_VISIBLE_DEVICES=2,3 \
swift rollout \
--model Qwen/Qwen2.5-VL-7B-Instruct \
--tensor_parallel_size 2
```
The complete script can be found [here](../../../examples/train/grpo/multi_node/Qwen2_5_32B_full.sh).
## Reward Functions
### Custom Reward Functions
A reward function takes the model-generated text `completions`, plus other dataset columns passed as keyword arguments (`kwargs`), and returns a score for each completion. The example below implements a simple length-based reward function that returns 1.0 if the length of the generated text exceeds 1024 and 0.0 otherwise.
```python
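from swift.plugin import ORM, orms


class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
        # One score per completion: 1.0 if longer than 1024, else 0.0.
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]


# Register under the name 'dummy' so it can be selected via reward_funcs.
orms['dummy'] = DummyLengthRewardFunction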
```
You can add this reward function in `swift/examples/train/grpo/plugin/plugin.py`, register it with the parameter `--external_plugins examples/train/grpo/plugin/plugin.py`, and then select it by its registered name with the `reward_funcs` parameter (here, `--reward_funcs dummy`).
For an example of how to execute the script, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/plugin/run_external_rm.sh).
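Dataset columns reach the reward function as keyword arguments. The sketch below shows that shape with a hypothetical `ContainsSolutionReward` class that reads a `solution` column; it is written as a plain callable so it runs without ms-swift installed. In practice it would presumably subclass `ORM` and be registered in `orms` like the example above.

```python
class ContainsSolutionReward:
    """Hypothetical reward: 1.0 if the completion contains the reference solution."""

    def __call__(self, completions, solution, **kwargs):
        # `solution` is assumed to be a dataset column, aligned one-to-one
        # with `completions`; any other columns arrive in **kwargs.
        return [1.0 if ref.strip() in text else 0.0
                for text, ref in zip(completions, solution)]


reward_fn = ContainsSolutionReward()
scores = reward_fn(["The answer is 42.", "I am not sure."], solution=["42", "7"])
# scores == [1.0, 0.0]
```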
### Built-in Reward Functions
Swift provides five built-in rule-based reward functions (the code can be found in `swift/plugin/orm.py`):
| Reward Function | Paper |
|----------------|----------------------------------------------------------------------------|
| accuracy | [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via RL](https://arxiv.org/abs/2501.12948) |
| format | Same as above |
| cosine | [Demystifying Long Chain-of-Thought Reasoning in LLMs](https://arxiv.org/abs/2502.03373) |
| repetition | Same as above |
| soft_overlong | [Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) |
#### 1. **accuracy**
This function compares the model's generated result with the solution column in the dataset to calculate an accuracy score. If the generated result matches the standard answer, the score is 1.0; otherwise, it is 0.0.
Note: This reward function uses the math_verify library to parse the generated results and the answers in the solution, and it may only be applicable to specific mathematical datasets.
#### 2. **format**
The paper uses the following system prompt to enforce a fixed format for model responses:
```
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think>
```
This function checks whether the model generates text in the format `<think>think content</think><answer>answer content</answer>`. If the generated text adheres to the format requirements, the score is 1.0; otherwise, it is 0.0.
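A format check of this kind can be sketched with a regular expression. The pattern below is an assumption made for illustration and may differ from the one used in `swift/plugin/orm.py`; for example, it tolerates whitespace around and between the two tagged blocks.

```python
import re

# Assumed pattern: a <think> block followed by an <answer> block, nothing else.
FORMAT_PATTERN = re.compile(
    r"\s*<think>.*?</think>\s*<answer>.*?</answer>\s*", re.DOTALL
)


def format_reward(completions):
    """Return 1.0 for completions that match the expected layout, else 0.0."""
    return [1.0 if FORMAT_PATTERN.fullmatch(text) else 0.0 for text in completions]


scores = format_reward([
    "<think>2 + 2 = 4</think><answer>4</answer>",
    "The answer is 4.",
])
# scores == [1.0, 0.0]
```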
#### 3. **cosine**
The paper found that training with only the accuracy reward function could lead to overly long generated sequences, affecting training performance. The cosine reward function optimizes the training process by controlling the length of the generated sequences:
- For text that generates the correct answer, the reward value decreases as the length increases, encouraging concise responses.
- For text that generates incorrect answers, the reward value increases as the length increases, encouraging deeper reasoning.
A cosine function is used to smoothly adjust the reward value, ensuring that the changes are within a reasonable range. The parameters for the cosine function include the length of the generated text, the maximum length limit, and the minimum and maximum reward values.
Parameters:
- cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect.
- cosine_max_len_value_wrong (default: 0.0): Reward value corresponding to the maximum length when the answer is incorrect.
- cosine_min_len_value_correct (default: 1.0): Reward value corresponding to the minimum length when the answer is correct.
- cosine_max_len_value_correct (default: 0.5): Reward value corresponding to the maximum length when the answer is correct.
- cosine_max_len (default value equal to the model's maximum generation capacity): Maximum length limit for generated text.
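The length-dependent schedule described above can be sketched as a cosine interpolation between the reward at the minimum length and the reward at the maximum length. This is an illustrative sketch using the documented default values; the function name and signature are assumptions, and the actual implementation in `swift/plugin/orm.py` may differ in detail.

```python
import math


def cosine_reward(length, max_len, is_correct,
                  min_len_value_correct=1.0, max_len_value_correct=0.5,
                  min_len_value_wrong=-0.5, max_len_value_wrong=0.0):
    """Interpolate the reward from the min-length value to the max-length value."""
    if is_correct:
        start, end = min_len_value_correct, max_len_value_correct
    else:
        start, end = min_len_value_wrong, max_len_value_wrong
    # Clamp the length to [0, max_len] so the cosine covers half a period.
    t = min(max(length, 0), max_len)
    # weight goes smoothly from 1 (length 0) to 0 (length max_len).
    weight = (1.0 + math.cos(math.pi * t / max_len)) / 2.0
    return end + (start - end) * weight


short_correct = cosine_reward(0, 2048, is_correct=True)      # highest reward
long_correct = cosine_reward(2048, 2048, is_correct=True)    # reduced reward
short_wrong = cosine_reward(0, 2048, is_correct=False)       # strongest penalty
long_wrong = cosine_reward(2048, 2048, is_correct=False)     # no penalty
```

With these defaults, a correct answer is worth between 0.5 and 1.0 and a wrong answer between -0.5 and 0.0, so correctness always dominates length.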
#### 4. **repetition**
This function penalizes repetition in generated text by detecting repeated n-gram patterns and assigning penalties based on the level of repetition.
The function splits the generated text into words and extracts n-grams of a specified size (default is 3-gram). It calculates the repetition ratio based on the proportion of unique n-grams to the total number of n-grams. If the proportion of repeated n-grams is high, a significant negative reward (penalty) is applied. The penalty value is computed based on the repetition ratio and a maximum penalty value (default: -1.0).
Parameters:
- repetition_n_grams (default: 3): Size of the n-gram used to detect repetition.
- repetition_max_penalty (default: -1.0): Maximum penalty value, which controls the intensity of the penalty.
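The n-gram computation described above can be sketched as follows. The function name and the whitespace tokenization are assumptions for illustration; the built-in function may tokenize or scale differently.

```python
def repetition_penalty(text, ngram_size=3, max_penalty=-1.0):
    """Scale max_penalty by the fraction of n-grams that are repeats."""
    words = text.lower().split()
    if len(words) < ngram_size:
        return 0.0  # too short to contain a single n-gram
    ngrams = [tuple(words[i:i + ngram_size])
              for i in range(len(words) - ngram_size + 1)]
    # 0.0 when every n-gram is unique, approaching 1.0 as repetition grows.
    repeated_ratio = 1.0 - len(set(ngrams)) / len(ngrams)
    return repeated_ratio * max_penalty


no_repeat = repetition_penalty("the cat sat on the mat")
heavy_repeat = repetition_penalty("a a a a a")
```

In the second call there are three 3-grams and only one is unique, so two thirds of the maximum penalty is applied.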
#### 5. **soft overlong punishment**
This function defines a length penalty interval. For completions whose length falls within this interval, a linear penalty ranging from 0 down to -1 is applied.
Parameters:
- soft_max_length: L_max in the paper, the maximum generation length of the model, default is equal to max_completion_length.
- soft_cache_length: L_cache in the paper, controls the length penalty interval, which is defined as [soft_max_length - soft_cache_length, soft_max_length].
Original text from the paper:
> a length-aware penalty mechanism designed to shape the reward for truncated samples. Specifically, when the response length exceeds the predefined maximum value, we define a punishment interval. Within this interval, the longer the response, the greater the punishment it receives. This penalty is added to the original rule-based correctness reward, thereby signaling to the model to avoid excessively long responses.
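The piecewise penalty can be sketched as below, following the interval definition given for `soft_max_length` and `soft_cache_length`. The function name is hypothetical, and treating lengths beyond `soft_max_length` as a flat -1 follows the DAPO paper rather than anything stated in this document.

```python
def soft_overlong_penalty(length, soft_max_length, soft_cache_length):
    """Linear penalty in [-1, 0] over the last soft_cache_length tokens."""
    threshold = soft_max_length - soft_cache_length
    if length <= threshold:
        return 0.0  # short enough: no penalty
    if length <= soft_max_length:
        # Falls linearly from 0 at the threshold to -1 at soft_max_length.
        return (threshold - length) / soft_cache_length
    return -1.0  # beyond the maximum length: full penalty


no_penalty = soft_overlong_penalty(3000, 4096, 819)
half_penalty = soft_overlong_penalty(90, 100, 20)
full_penalty = soft_overlong_penalty(4096, 4096, 819)
```

With `soft_max_length=4096` and `soft_cache_length=819`, penalties start once a completion exceeds 3277 tokens.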
#### 6. **Reward Models**
In addition to rule-based reward functions, this framework also supports using reward models as reward functions. To use a reward model, specify the `reward_model` parameter, which, like the `model` parameter, takes the path or name of the reward model. Note that at least one of `reward_model` and `reward_funcs` needs to be specified.
## Arguments and Execution Script
Arguments
- per_device_train_batch_size: The training batch size per device. In GRPO, this refers to the batch size of completions during training.
- per_device_eval_batch_size: The evaluation batch size per device. In GRPO, this refers to the batch size of completions during evaluation.
- num_generations: The number of samples generated for each prompt, referred to as the G value in the paper. per_device_batch_size * gradient_accumulation_steps * nproc_per_node must be divisible by this value. Default is 8.
- max_completion_length: The maximum length for sampling generation, default is 512.
- ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible with vLLM generation. The default is True.
- reward_funcs: Reward functions used to score the results generated by the model. The built-in rule-based functions are accuracy, format, cosine, repetition and soft_overlong, detailed in the swift/plugin/orm.py file.
- reward_weights: Weights for each reward function. The number should be equal to the sum of the number of reward functions and reward models. If `None`, all rewards are weighted equally with weight `1.0`.
- Note: If `--reward_model` is included in GRPO training, it is added to the end of the reward functions.
- reward_model: Same as the model, using a reward model as a reward function. At least one of reward_funcs and reward_model needs to be specified.
- reward_model_plugin: The logic for the reward model, which defaults to ORM logic. For more information, please refer to [Customized Reward Models](#customized-reward-models).
- dataset_shuffle: Whether to shuffle the dataset randomly. Default is True.
- loss_type: The type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo'], default is 'grpo'. For details, see this [pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)
- log_completions: Whether to log the model-generated content during training, to be used in conjunction with `--report_to wandb`, default is False.
- Note: If `--report_to wandb` is not set, a `completions.jsonl` will be created in the checkpoint to store the generated content.
- use_vllm: Whether to use vLLM as the back-end for sampling generation; default is False, using it is recommended to speed up training.
- vllm_device: Device for deploying vLLM, default is auto, meaning the first unused GPU. Use cuda:x to specify a particular card.
- vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
- vllm_max_model_len: vLLM passthrough parameter, default is None.
- vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
- vllm_enforce_eager: vLLM passthrough parameter, default is False.
- vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
- vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
- vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
- vllm_server_port: The service port of the vLLM server. Default is 8000.
- vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
- num_iterations: number of iterations per batch. Default is 1.
- epsilon: epsilon value for clipping. Default is 0.2.
- epsilon_high: Upper clip coefficient, default is None. When set, it is used together with epsilon so that the probability ratio is clipped to the range [1 - epsilon, 1 + epsilon_high].
- async_generate: Whether to use asynchronous rollout to improve training speed. Default is `false`.
- sleep_level: vLLM-specific. When the actor and rollout share the same GPU, vLLM can be put to sleep while the model is training.
- move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches.
- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM/LMDeploy. The default is `False`.
- offload_model: Whether to offload the model itself during inference with vLLM/LMDeploy. The default is `False`.
- Note: If this parameter is set to True and the grad_norm remains zero during training, please install vllm==0.7.3.
- gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is `False`.
- multi_turn_func: The name of the multi-turn GRPO plugin. Add your multi-turn implementation in plugin/multi_turn.py.
- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False.
- max_resample_times: The maximum number of resampling attempts when dynamic_sample is enabled. Default is 3.
- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False.
The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions).
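The relationship between `num_generations` and the batch settings can be checked with a few lines of arithmetic. This sketch assumes that the number of completions processed per optimizer step must be a multiple of `num_generations`, which is how the example scripts in this document are configured; `check_num_generations` is a hypothetical helper.

```python
def check_num_generations(num_generations, per_device_batch_size,
                          gradient_accumulation_steps, nproc_per_node):
    """Return True if each step holds a whole number of prompt groups."""
    completions_per_step = (per_device_batch_size
                            * gradient_accumulation_steps
                            * nproc_per_node)
    return completions_per_step % num_generations == 0


# 7 training processes, batch size 1, accumulation 2, 7 generations per prompt.
async_ok = check_num_generations(7, 1, 2, 7)
# 8 training processes, batch size 1, accumulation 2, 8 generations per prompt.
colocate_ok = check_num_generations(8, 1, 2, 8)
# 14 completions per step cannot be split into groups of 8.
mismatch = check_num_generations(8, 1, 2, 7)
```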
You can use vLLM and LMDeploy as sampling backends to accelerate training.
Multi-GPU vLLM
```bash
# async mode
# The requirement is that num_infer_workers (deployment) + NPROC_PER_NODE (training) = device_count.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=7 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format cosine repetition \
--use_vllm true \
--vllm_device auto \
--vllm_gpu_memory_utilization 0.7 \
--vllm_max_model_len 8192 \
--num_infer_workers 1 \
--train_type full \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#5000' \
--max_completion_length 2048 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 2 \
--eval_steps 200 \
--save_steps 200 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 4096 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 7 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2 \
--log_completions true
# colocate mode
# The requirement is that num_infer_workers (deployment) = NPROC_PER_NODE (training) = device_count.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-1.5B \
--reward_funcs accuracy format \
--use_vllm true \
--vllm_device auto \
--vllm_gpu_memory_utilization 0.7 \
--vllm_max_model_len 8192 \
--num_infer_workers 8 \
--train_type full \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#5000' \
--max_completion_length 2048 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-6 \
--gradient_accumulation_steps 2 \
--eval_steps 200 \
--save_steps 200 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 4096 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 8 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero2 \
--log_completions true \
--sleep_level 1 \
--offload_model true \
--offload_optimizer true \
--gc_collect_after_offload true
```
Single-GPU
```bash
# PT backend
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--reward_funcs accuracy format cosine repetition \
--train_type lora \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#1000' \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
# vLLM backend
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--vllm_gpu_memory_utilization 0.5 \
--use_vllm true \
--sleep_level 1 \
--offload_model true \
--offload_optimizer true \
--gc_collect_after_offload true \
--reward_funcs accuracy format \
--train_type lora \
--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--torch_dtype bfloat16 \
--dataset 'AI-MO/NuminaMath-TIR#1000' \
--max_completion_length 1024 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--max_length 2048 \
--output_dir output \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 0.9 \
--system 'examples/train/grpo/prompt.txt' \
--log_completions true
```
For multi-node training, refer to [here](../../../examples/train/grpo/multi_node/).
Note: In the internal integration mode, the GPU configurations and training parameters must be identical across different nodes.
## Customized Reward Models
By default, a reward model refers to a classification model that includes a value head, commonly known as an Outcome Reward Model (ORM). Such a model scores the outputs of other models, producing a scalar value that represents the quality of the response.
Currently, we can leverage the **reward_model_plugin** to flexibly customize the processing logic of these reward models. This enables the implementation of advanced techniques such as Generative Reward Models, which include:
- Customizing the Model's System Prompt: Defining specific instructions and context to guide the evaluation process.
- Handling Model Interaction History: Managing the conversational context to provide meaningful and contextually aware evaluations.
- Defining Custom Evaluation Criteria: Setting unique standards and metrics for assessing the model's responses beyond default accuracy and relevance measures.
Through the **reward_model_plugin**, developers can tailor the reward evaluation process to meet the specific requirements of their applications. This flexibility allows for more nuanced and effective reward-based training strategies.
We provide a simple generative reward model example (GenRMPlugin) in [rm_plugin.py](../../../swift/plugin/rm_plugin.py).
You can also customize your reward model plugin in [plugin.py](../../../examples/train/grpo/plugin/plugin.py) and register it with the `external_plugins` argument.
Here is an example training script that trains GRPO with two reward models: one ORM and one Gen-RM (using Qwen2.5-3B-Instruct in this case):
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-7B \
--dataset AI-MO/NuminaMath-TIR#5000 \
--external_plugins examples/train/grpo/plugin/plugin.py \
--reward_funcs format \
--reward_model Qwen/Qwen2.5-3B-Instruct Shanghai_AI_Laboratory/internlm2-7b-reward \
--reward_model_plugin genrm my_rmplugin \
--reward_weights 0.1 1 1 \
--num_infer_workers 8 \
--vllm_gpu_memory_utilization 0.5 \
--sleep_level 1 \
--offload_model true \
--offload_optimizer true \
--gc_collect_after_offload true \
--log_completions true \
--deepspeed zero3
```
Notes:
1. In the GRPOTrainer, reward_model instances are appended sequentially to reward_funcs. Therefore, the order of reward_weights corresponds to [reward_funcs, reward_model].
2. The default value for reward_model_plugin is `default`, which uses the ORM processing logic.
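The ordering in note 1 can be illustrated with a small weighted-sum sketch. `combine_rewards` is a hypothetical helper, and combining sources by a weighted sum is an assumption about how the trainer aggregates rewards; the point is only that the weights line up with reward functions first and reward models second.

```python
def combine_rewards(func_rewards, model_rewards, reward_weights=None):
    """Weighted sum per sample; sources are ordered [reward_funcs, reward_model]."""
    per_source = list(func_rewards) + list(model_rewards)
    if reward_weights is None:
        reward_weights = [1.0] * len(per_source)  # equal weighting by default
    if len(reward_weights) != len(per_source):
        raise ValueError("need one weight per reward function and reward model")
    num_samples = len(per_source[0])
    return [sum(w * rewards[i] for w, rewards in zip(reward_weights, per_source))
            for i in range(num_samples)]


# One reward function (format) and two reward models, weighted 0.1, 1, 1.
totals = combine_rewards(
    func_rewards=[[1.0, 0.0]],
    model_rewards=[[0.5, 0.2], [0.3, 0.9]],
    reward_weights=[0.1, 1, 1],
)
```

Here the first sample scores 0.1 * 1.0 + 0.5 + 0.3 and the second scores 0.1 * 0.0 + 0.2 + 0.9.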
## DAPO
Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) introduces several tricks based on GRPO, which are:
- Clip Higher
- Dynamic Sampling
- Overlong Filtering
- Token level Loss
- Soft Overlong Punishment
Among these, Token level Loss is implemented by default and does not require additional settings. For the other tricks, we can achieve the desired setup based on GRPOTrainer by configuring the following parameters.
| Parameter | Type | Value |
|----------------------|-----------|-------------|
| `--epsilon_high` | `float` | `0.28` |
| `--dynamic_sample` | `bool` | `true` |
| `--overlong_filter` | `bool` | `true` |
| `--reward_funcs` | `str` | `soft_overlong`|
| `--max_resample_times` | `int` | `3` |
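Two of these settings can be illustrated in a few lines. The sketch below is not the trainer's implementation: `keep_group` and `clip_ratio` are hypothetical helpers that show the dynamic sampling filter (drop groups whose rewards have zero standard deviation) and the asymmetric clipping range used by Clip Higher.

```python
from statistics import pstdev


def keep_group(group_rewards):
    """Dynamic sampling: keep a prompt's group only if its rewards vary."""
    return pstdev(group_rewards) > 0.0


def clip_ratio(ratio, epsilon=0.2, epsilon_high=0.28):
    """Clip Higher: bound the probability ratio to [1 - epsilon, 1 + epsilon_high]."""
    return min(max(ratio, 1.0 - epsilon), 1.0 + epsilon_high)


all_correct = keep_group([1.0, 1.0, 1.0, 1.0])   # no learning signal, resample
mixed = keep_group([1.0, 0.0, 1.0, 0.0])         # informative group, keep
upper = clip_ratio(1.5)                          # clipped at the upper bound
lower = clip_ratio(0.5)                          # clipped at the lower bound
```

A group in which every completion receives the same reward yields zero advantage for all samples, which is why such groups are excluded and replaced, up to `max_resample_times` attempts.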
Reference training script (for 8-card colocate mode):
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
WANDB_API_KEY=xxx \
swift rlhf \
--rlhf_type grpo \
--model Qwen/Qwen2.5-1.5B \
--reward_funcs accuracy soft_overlong \
--max_completion_length 4096 \
--soft_cache_length 819 \
--epsilon 0.2 \
--epsilon_high 0.28 \
--dynamic_sample true \
--overlong_filter true \
--max_resample_times 3 \
--use_vllm true \
--vllm_gpu_memory_utilization 0.6 \
--num_infer_workers 8 \
--train_type full \
--torch_dtype bfloat16 \
--dataset AI-MO/NuminaMath-TIR#5000 \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--learning_rate 1e-6 \
--eval_steps 1000 \
--save_steps 1000 \
--save_total_limit 2 \
--logging_steps 5 \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 8 \
--temperature 1.0 \
--top_p 1.0 \
--deepspeed zero2 \
--log_completions true \
--num_iterations 1 \
--report_to tensorboard wandb \
--beta 0.0 \
--sleep_level 1 \
--offload_model true \
--offload_optimizer true \
--gc_collect_after_offload true
```