Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- ICL/DAPO/verl-recipe/collabllm/config/agent.yaml +2 -0
- ICL/DAPO/verl-recipe/collabllm/config/collabllm_interaction_config.yaml +10 -0
- ICL/DAPO/verl-recipe/collabllm/metrics/token_amount.py +26 -0
- ICL/DAPO/verl-recipe/dapo/config/dapo_trainer.yaml +28 -0
- ICL/DAPO/verl-recipe/deepeyes/configs/deepeyes_multiturn_grpo.yaml +32 -0
- ICL/DAPO/verl-recipe/entropy/config/entropy_trainer.yaml +39 -0
- ICL/DAPO/verl-recipe/entropy/reward_score/__init__.py +38 -0
- ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/__init__.py +1062 -0
- ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/grader.py +384 -0
- ICL/DAPO/verl-recipe/fapo/config/rm_config.yaml +49 -0
- ICL/DAPO/verl-recipe/fapo/run_fapo_genrm_train.sh +138 -0
- ICL/DAPO/verl-recipe/fapo/runtime_env.yaml +5 -0
- ICL/DAPO/verl-recipe/fault_recover/agent_loop/__init__.py +20 -0
- ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_single_turn_agent_loop.py +111 -0
- ICL/DAPO/verl-recipe/fault_recover/config/fault_recover_ppo_megatron_trainer.yaml +265 -0
- ICL/DAPO/verl-recipe/fault_recover/vllm_rollout/__init__.py +13 -0
- ICL/DAPO/verl-recipe/fault_recover/vllm_rollout/vllm_async_server.py +104 -0
- ICL/DAPO/verl-recipe/flowrl/config/flowrl_trainer.yaml +33 -0
- ICL/DAPO/verl-recipe/flowrl/figures/file.svg +135 -0
- ICL/DAPO/verl-recipe/flowrl/figures/flowrl.pdf +0 -0
- ICL/DAPO/verl-recipe/flowrl/prepare/prepare_data.sh +43 -0
- ICL/DAPO/verl-recipe/flowrl/prepare/prepare_model.sh +10 -0
- ICL/DAPO/verl-recipe/gkd/megatron/megatron_utils.py +200 -0
- ICL/DAPO/verl-recipe/gvpo/config/gvpo_trainer.yaml +15 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/README.md +138 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/agent.yaml +2 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/create_dataset.py +290 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/math_expression.py +38 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh +143 -0
- ICL/DAPO/verl-recipe/langgraph_agent/example/run_qwen2.5_3b.sh +145 -0
- ICL/DAPO/verl-recipe/open_math_reasoning/run_eval.sh +7 -0
- ICL/DAPO/verl-recipe/prime/config/prime_trainer.yaml +77 -0
- ICL/DAPO/verl-recipe/qat/config/nvfp4_w4a16.json +34 -0
- ICL/DAPO/verl-recipe/qat/config/nvfp4_w4a4.json +45 -0
- ICL/DAPO/verl-recipe/r1/config/evaluation.yaml +14 -0
- ICL/DAPO/verl-recipe/r1/tasks/math_reward.py +35 -0
- ICL/DAPO/verl-recipe/r1_ascend/figures/response_len.png +0 -0
- ICL/DAPO/verl-recipe/r1_ascend/figures/rewards.png +0 -0
- ICL/DAPO/verl-recipe/r1_ascend/figures/val_score.png +0 -0
- ICL/DAPO/verl-recipe/rep_exp/config/_generated_ppo_megatron_trainer.yaml +594 -0
- ICL/DAPO/verl-recipe/rep_exp/config/_generated_ppo_trainer.yaml +563 -0
- ICL/DAPO/verl-recipe/rep_exp/config/actor/actor.yaml +215 -0
- ICL/DAPO/verl-recipe/rep_exp/config/actor/dp_actor.yaml +43 -0
- ICL/DAPO/verl-recipe/rep_exp/config/actor/megatron_actor.yaml +20 -0
- ICL/DAPO/verl-recipe/rep_exp/config/algorithm/rollout_correction.yaml +30 -0
- ICL/DAPO/verl-recipe/rep_exp/config/critic/critic.yaml +176 -0
- ICL/DAPO/verl-recipe/rep_exp/config/critic/dp_critic.yaml +66 -0
- ICL/DAPO/verl-recipe/rep_exp/config/critic/megatron_critic.yaml +43 -0
- ICL/DAPO/verl-recipe/rep_exp/config/data/legacy_data.yaml +131 -0
- ICL/DAPO/verl-recipe/rep_exp/config/engine/fsdp.yaml +56 -0
ICL/DAPO/verl-recipe/collabllm/config/agent.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- name: collabllm_agent
|
| 2 |
+
_target_: recipe.collabllm.collabllm_agent_loop.CollabLLMAgentLoop
|
ICL/DAPO/verl-recipe/collabllm/config/collabllm_interaction_config.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
interaction:
|
| 2 |
+
- name: "collabllm"
|
| 3 |
+
class_name: "recipe.collabllm.collabllm_interation.CollabLLMInteraction"
|
| 4 |
+
config: {
|
| 5 |
+
"user_model": "gpt-4o-mini",
|
| 6 |
+
"num_retries": 3,
|
| 7 |
+
"max_tokens": 512,
|
| 8 |
+
"temperature": 1.0,
|
| 9 |
+
"enable_log": True
|
| 10 |
+
}
|
ICL/DAPO/verl-recipe/collabllm/metrics/token_amount.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 CollabLLM team and/or its affiliates
|
| 2 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def compute_score(data_source, messages, ground_truth, extra_info, **kwargs):
|
| 18 |
+
prompt = extra_info["prompt"]
|
| 19 |
+
|
| 20 |
+
# Calculate the token penalty based on the length of the prompt
|
| 21 |
+
future_conv = messages[len(prompt) :]
|
| 22 |
+
|
| 23 |
+
# simple length estimation
|
| 24 |
+
total_tokens = sum(len(m.content.split()) for m in future_conv)
|
| 25 |
+
|
| 26 |
+
return total_tokens
|
ICL/DAPO/verl-recipe/dapo/config/dapo_trainer.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
gen_batch_size: ${data.train_batch_size}
|
| 11 |
+
|
| 12 |
+
reward_model:
|
| 13 |
+
reward_manager: dapo
|
| 14 |
+
overlong_buffer:
|
| 15 |
+
enable: False # We try to avoid forgetting to set enable
|
| 16 |
+
len: 0
|
| 17 |
+
penalty_factor: 0.0
|
| 18 |
+
log: False
|
| 19 |
+
|
| 20 |
+
algorithm:
|
| 21 |
+
filter_groups:
|
| 22 |
+
_target_: verl.trainer.config.FilterGroupsConfig
|
| 23 |
+
enable: False # We try to avoid forgetting to set enable
|
| 24 |
+
metric: null # acc / score / seq_reward / seq_final_reward / ...
|
| 25 |
+
max_num_gen_batches: 0 # Non-positive values mean no upper limit
|
| 26 |
+
|
| 27 |
+
trainer:
|
| 28 |
+
project_name: verl-dapo
|
ICL/DAPO/verl-recipe/deepeyes/configs/deepeyes_multiturn_grpo.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
max_prompt_length: 2048
|
| 11 |
+
max_response_length: 2048
|
| 12 |
+
train_batch_size: 256
|
| 13 |
+
return_raw_chat: True
|
| 14 |
+
return_multi_modal_inputs: False
|
| 15 |
+
custom_cls:
|
| 16 |
+
path: "recipe/deepeyes/deepeyes.py"
|
| 17 |
+
name: CustomRLHFDataset
|
| 18 |
+
|
| 19 |
+
actor_rollout_ref:
|
| 20 |
+
hybrid_engine: True
|
| 21 |
+
model:
|
| 22 |
+
custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{%- if messages[0]['content'] is string %}{{- messages[0]['content'] }}{%- else %}{{- messages[0]['content'][0]['text'] }}{%- endif %}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif 
%}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n<tool_call>\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n<tool_response>\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: 
{% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n<tool_call>\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n</tool_call>' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n<tool_response>\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
| 23 |
+
rollout:
|
| 24 |
+
name: sglang
|
| 25 |
+
multi_turn:
|
| 26 |
+
enable: True
|
| 27 |
+
max_assistant_turns: 5
|
| 28 |
+
tool_config_path: "recipe/deepeyes/config/image_zoom_in_tool_config.yaml"
|
| 29 |
+
|
| 30 |
+
custom_reward_function:
|
| 31 |
+
path: "recipe/deepeyes/deepeyes.py"
|
| 32 |
+
name: compute_score
|
ICL/DAPO/verl-recipe/entropy/config/entropy_trainer.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
gen_batch_size: ${data.train_batch_size}
|
| 11 |
+
|
| 12 |
+
reward_model:
|
| 13 |
+
reward_kwargs:
|
| 14 |
+
overlong_buffer_cfg: ${reward_model.overlong_buffer}
|
| 15 |
+
reward_manager: dapo
|
| 16 |
+
overlong_buffer:
|
| 17 |
+
enable: False
|
| 18 |
+
len: 0
|
| 19 |
+
penalty_factor: 0.0
|
| 20 |
+
log: False
|
| 21 |
+
|
| 22 |
+
algorithm:
|
| 23 |
+
filter_groups:
|
| 24 |
+
enable: False # We try to avoid forgetting to set enable
|
| 25 |
+
metric: null # acc / score / seq_reward / seq_final_reward / ...
|
| 26 |
+
max_num_gen_batches: 0 # Non-positive values mean no upper limit
|
| 27 |
+
|
| 28 |
+
trainer:
|
| 29 |
+
project_name: verl-entropy
|
| 30 |
+
|
| 31 |
+
actor_rollout_ref:
|
| 32 |
+
actor:
|
| 33 |
+
policy_loss:
|
| 34 |
+
loss_mode: "vanilla" # /clip-cov / kl-cov from https://arxiv.org/abs/2505.
|
| 35 |
+
clip_cov_ratio: 0.0002 # for clip-cov loss
|
| 36 |
+
clip_cov_lb: 1.0 # for clip-cov loss
|
| 37 |
+
clip_cov_ub: 5.0 # for clip-cov loss
|
| 38 |
+
kl_cov_ratio: 0.0002 # for kl-cov loss
|
| 39 |
+
ppo_kl_coef: 0.1 # for kl-cov loss
|
ICL/DAPO/verl-recipe/entropy/reward_score/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# from . import gsm8k, math, prime_math, prime_code
|
| 15 |
+
|
| 16 |
+
import traceback
|
| 17 |
+
|
| 18 |
+
from . import entropy_math
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _default_compute_score(
|
| 22 |
+
data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None
|
| 23 |
+
):
|
| 24 |
+
try:
|
| 25 |
+
res = entropy_math.compute_score(solution_str, str(ground_truth))
|
| 26 |
+
# print(f"data_source: {data_source}")
|
| 27 |
+
# raise NotImplementedError(f"Reward function is not implemented for {data_source=}")
|
| 28 |
+
|
| 29 |
+
if isinstance(res, dict):
|
| 30 |
+
return res
|
| 31 |
+
elif isinstance(res, int | float | bool):
|
| 32 |
+
return float(res)
|
| 33 |
+
else:
|
| 34 |
+
return float(res[0])
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"[ERROR] Error in process_completion for task : {str(e)}")
|
| 37 |
+
traceback.print_exc() # 打印完整堆栈
|
| 38 |
+
raise # 重新抛出异常以便上层捕获
|
ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/__init__.py
ADDED
|
@@ -0,0 +1,1062 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 PRIME team and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except Exception in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Provides a math answer grading function with high recall.
|
| 16 |
+
Based on HF math_verify, verl, open reasoner zero, etc.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import re
|
| 21 |
+
import signal
|
| 22 |
+
from itertools import islice, zip_longest
|
| 23 |
+
from math import isclose
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
import sympy
|
| 27 |
+
from latex2sympy2_extended import latex2sympy
|
| 28 |
+
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
|
| 29 |
+
from pylatexenc import latex2text
|
| 30 |
+
from sympy import N, simplify
|
| 31 |
+
from sympy.parsing import sympy_parser
|
| 32 |
+
from sympy.parsing.latex import parse_latex
|
| 33 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
This code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py).
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def timeout_ours(timeout_seconds: int = 8):
|
| 41 |
+
if os.name == "posix":
|
| 42 |
+
import signal
|
| 43 |
+
|
| 44 |
+
def decorator(func):
|
| 45 |
+
def handler(signum, frame):
|
| 46 |
+
raise TimeoutError("Operation timed out!")
|
| 47 |
+
|
| 48 |
+
def wrapper(*args, **kwargs):
|
| 49 |
+
old_handler = signal.getsignal(signal.SIGALRM)
|
| 50 |
+
signal.signal(signal.SIGALRM, handler)
|
| 51 |
+
signal.alarm(timeout_seconds)
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
return func(*args, **kwargs)
|
| 55 |
+
finally:
|
| 56 |
+
signal.alarm(0)
|
| 57 |
+
signal.signal(signal.SIGALRM, old_handler)
|
| 58 |
+
|
| 59 |
+
return wrapper
|
| 60 |
+
|
| 61 |
+
return decorator
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError(f"Unsupported OS: {os.name}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Dan Hendrycks' code
|
| 67 |
+
def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:
|
| 68 |
+
if answer is None:
|
| 69 |
+
return None
|
| 70 |
+
answer = answer.strip()
|
| 71 |
+
try:
|
| 72 |
+
# Remove enclosing `\text{}`.
|
| 73 |
+
m = re.search(r"^\\text\{(?P<text>.+?)\}$", answer)
|
| 74 |
+
if m is not None:
|
| 75 |
+
answer = m.group("text").strip()
|
| 76 |
+
return _strip_string(answer)
|
| 77 |
+
except Exception:
|
| 78 |
+
return answer
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# units mainly from MathQA
|
| 82 |
+
unit_texts = [
|
| 83 |
+
"east",
|
| 84 |
+
"degree",
|
| 85 |
+
"mph",
|
| 86 |
+
"kmph",
|
| 87 |
+
"ft",
|
| 88 |
+
"m square",
|
| 89 |
+
" m east",
|
| 90 |
+
"sq m",
|
| 91 |
+
"deg",
|
| 92 |
+
"mile",
|
| 93 |
+
"q .",
|
| 94 |
+
"monkey",
|
| 95 |
+
"prime",
|
| 96 |
+
"ratio",
|
| 97 |
+
"profit of rs",
|
| 98 |
+
"rd",
|
| 99 |
+
"o",
|
| 100 |
+
"gm",
|
| 101 |
+
"p . m",
|
| 102 |
+
"lb",
|
| 103 |
+
"tile",
|
| 104 |
+
"per",
|
| 105 |
+
"dm",
|
| 106 |
+
"lt",
|
| 107 |
+
"gain",
|
| 108 |
+
"ab",
|
| 109 |
+
"way",
|
| 110 |
+
"west",
|
| 111 |
+
"a .",
|
| 112 |
+
"b .",
|
| 113 |
+
"c .",
|
| 114 |
+
"d .",
|
| 115 |
+
"e .",
|
| 116 |
+
"f .",
|
| 117 |
+
"g .",
|
| 118 |
+
"h .",
|
| 119 |
+
"t",
|
| 120 |
+
"a",
|
| 121 |
+
"h",
|
| 122 |
+
"no change",
|
| 123 |
+
"men",
|
| 124 |
+
"soldier",
|
| 125 |
+
"pie",
|
| 126 |
+
"bc",
|
| 127 |
+
"excess",
|
| 128 |
+
"st",
|
| 129 |
+
"inches",
|
| 130 |
+
"noon",
|
| 131 |
+
"percent",
|
| 132 |
+
"by",
|
| 133 |
+
"gal",
|
| 134 |
+
"kmh",
|
| 135 |
+
"c",
|
| 136 |
+
"acre",
|
| 137 |
+
"rise",
|
| 138 |
+
"a . m",
|
| 139 |
+
"th",
|
| 140 |
+
"π r 2",
|
| 141 |
+
"sq",
|
| 142 |
+
"mark",
|
| 143 |
+
"l",
|
| 144 |
+
"toy",
|
| 145 |
+
"coin",
|
| 146 |
+
"sq . m",
|
| 147 |
+
"gallon",
|
| 148 |
+
"° f",
|
| 149 |
+
"profit",
|
| 150 |
+
"minw",
|
| 151 |
+
"yr",
|
| 152 |
+
"women",
|
| 153 |
+
"feet",
|
| 154 |
+
"am",
|
| 155 |
+
"pm",
|
| 156 |
+
"hr",
|
| 157 |
+
"cu cm",
|
| 158 |
+
"square",
|
| 159 |
+
"v â € ™",
|
| 160 |
+
"are",
|
| 161 |
+
"rupee",
|
| 162 |
+
"rounds",
|
| 163 |
+
"cubic",
|
| 164 |
+
"cc",
|
| 165 |
+
"mtr",
|
| 166 |
+
"s",
|
| 167 |
+
"ohm",
|
| 168 |
+
"number",
|
| 169 |
+
"kmph",
|
| 170 |
+
"day",
|
| 171 |
+
"hour",
|
| 172 |
+
"minute",
|
| 173 |
+
"min",
|
| 174 |
+
"second",
|
| 175 |
+
"man",
|
| 176 |
+
"woman",
|
| 177 |
+
"sec",
|
| 178 |
+
"cube",
|
| 179 |
+
"mt",
|
| 180 |
+
"sq inch",
|
| 181 |
+
"mp",
|
| 182 |
+
"∏ cm ³",
|
| 183 |
+
"hectare",
|
| 184 |
+
"more",
|
| 185 |
+
"sec",
|
| 186 |
+
"unit",
|
| 187 |
+
"cu . m",
|
| 188 |
+
"cm 2",
|
| 189 |
+
"rs .",
|
| 190 |
+
"rs",
|
| 191 |
+
"kg",
|
| 192 |
+
"g",
|
| 193 |
+
"month",
|
| 194 |
+
"km",
|
| 195 |
+
"m",
|
| 196 |
+
"cm",
|
| 197 |
+
"mm",
|
| 198 |
+
"apple",
|
| 199 |
+
"liter",
|
| 200 |
+
"loss",
|
| 201 |
+
"yard",
|
| 202 |
+
"pure",
|
| 203 |
+
"year",
|
| 204 |
+
"increase",
|
| 205 |
+
"decrease",
|
| 206 |
+
"d",
|
| 207 |
+
"less",
|
| 208 |
+
"Surface",
|
| 209 |
+
"litre",
|
| 210 |
+
"pi sq m",
|
| 211 |
+
"s .",
|
| 212 |
+
"metre",
|
| 213 |
+
"meter",
|
| 214 |
+
"inch",
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
unit_texts.extend([t + "s" for t in unit_texts])
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _strip_string(string):
    """Canonicalize a LaTeX answer string (Hendrycks MATH-style normalization).

    Removes layout commands, units, degree marks, dollar/percent signs and
    whitespace, then rewrites shorthand fractions/roots into braced form.
    The order of the replacement passes below is significant.
    """

    def _fix_fracs(string):
        # Rewrite "\frac12" / "\frac1{2}" into "\frac{1}{2}"-style braced form.
        substrs = string.split("\\frac")
        new_str = substrs[0]
        if len(substrs) > 1:
            substrs = substrs[1:]
            for substr in substrs:
                new_str += "\\frac"
                if substr[0] == "{":
                    new_str += substr
                else:
                    try:
                        # Need at least two characters for numerator+denominator.
                        assert len(substr) >= 2
                    except Exception:
                        return string
                    a = substr[0]
                    b = substr[1]
                    if b != "{":
                        if len(substr) > 2:
                            post_substr = substr[2:]
                            new_str += "{" + a + "}{" + b + "}" + post_substr
                        else:
                            new_str += "{" + a + "}{" + b + "}"
                    else:
                        if len(substr) > 2:
                            post_substr = substr[2:]
                            new_str += "{" + a + "}" + b + post_substr
                        else:
                            new_str += "{" + a + "}" + b
        string = new_str
        return string

    def _fix_a_slash_b(string):
        # Rewrite a plain "a/b" (both ints, no extra text) into "\frac{a}{b}".
        if len(string.split("/")) != 2:
            return string
        a = string.split("/")[0]
        b = string.split("/")[1]
        try:
            a = int(a)
            b = int(b)
            assert string == "{}/{}".format(a, b)
            new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
            return new_string
        except Exception:
            return string

    def _remove_right_units(string):
        # "\\text{ " only ever occurs (at least in the val set) when describing units
        if "\\text{ " in string:
            splits = string.split("\\text{ ")
            # NOTE(review): asserts exactly one unit annotation; raises otherwise.
            assert len(splits) == 2
            return splits[0]
        else:
            return string

    def _fix_sqrt(string):
        # Rewrite "\sqrt3" into "\sqrt{3}" (single-character radicand only).
        if "\\sqrt" not in string:
            return string
        splits = string.split("\\sqrt")
        new_string = splits[0]
        for split in splits[1:]:
            if split[0] != "{":
                a = split[0]
                new_substr = "\\sqrt{" + a + "}" + split[1:]
            else:
                new_substr = "\\sqrt" + split
            new_string += new_substr
        return new_string

    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # matrix: normalize array/bmatrix environments to pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove unit: miles, dollars if after is not none
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    # Remove unit: texts (uses module-level `unit_texts`; two passes catch
    # units that become word-adjacent after the first removal)
    for _ in range(2):
        for unit_text in unit_texts:
            # use regex, the prefix should be either the start of the string or a non-alphanumeric character
            # the suffix should be either the end of the string or a non-alphanumeric character
            _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string)
            if _string != "":
                string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\\\%", "")
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    # Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# (before, after) replacement pairs applied in order by normalize_final_answer.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# Substrings stripped wholesale from answers by normalize_final_answer
# (unit words, filler words, and LaTeX decorations that carry no value).
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.
    This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18.

    Applies SUBSTITUTIONS and strips REMOVED_EXPRESSIONS, then unwraps
    common LaTeX decorations ($...$, \\text, \\textbf, \\overline, \\boxed),
    braces shorthand frac/sqrt arguments, and drops thousands commas from
    pure digit strings. Regex pass order below is significant.
    """
    # final_answer = final_answer.split("=")[-1]

    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def repeatness(s: str):
    """Return True when `s` is highly self-repetitive (degenerate model output).

    Builds a suffix array and its LCP array; the sum of LCP values measures
    how much of the string repeats. Returns True when the repetition ratio
    2*sum(lcp)/(n*(n+1)) exceeds 0.2, and 0 (falsy) for strings of length <= 1.
    Relies on module-level `zip_longest` and `islice` imports.
    """

    def ranks(seq):
        # Map each distinct value to its rank in sorted order.
        index = {v: i for i, v in enumerate(sorted(set(seq)))}
        return [index[v] for v in seq]

    def suffixArray(s):
        # Prefix-doubling suffix-array construction; returns (rank array, suffix array).
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            # Rank pairs (rank[i], rank[i+k]); -1 pads past the end.
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        # Kasai's algorithm: longest-common-prefix of adjacent suffixes in sorted order.
        n, ans, k = len(arr), [0] * len(arr), 0

        for i in range(n):
            if inv_suff[i] == n - 1:
                k = 0
                continue

            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1

            ans[inv_suff[i]] = k
            if k > 0:
                k -= 1

        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))

    return (cnt * 2 / (n * (n + 1))) > 0.2
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class timeout:
    """Context manager raising TimeoutError after `seconds` via SIGALRM.

    Unix-only and main-thread-only (signal handlers can only be installed
    from the main thread). Fix over the original: the previous SIGALRM
    handler is saved on entry and restored on exit instead of being
    permanently clobbered.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # Save whatever handler was installed so we can restore it on exit.
        self._old_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, type, value, traceback):
        # Cancel the pending alarm and restore the previous handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, self._old_handler)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def latex_eval(latex):
    """Parse a LaTeX string with sympy and return (symbolic expr, numeric value)."""
    expression = parse_latex(latex)
    return expression, expression.evalf()
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def numeric_equal(prediction: float, reference: float):
    """Return True when `prediction` matches `reference` within a 1e-4 relative tolerance.

    The relative tolerance has significant impact on the result of the
    synthesized GSM-Hard dataset, so it is kept deliberately loose.
    """
    return isclose(reference, prediction, rel_tol=1e-4)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
@timeout_ours(timeout_seconds=5)
def symbolic_equal(a, b):
    """Best-effort symbolic equality of two answer strings via sympy.

    Parses each side with several parsers, then tries progressively more
    expensive notions of equality; any failure falls through to the next
    strategy. Wrapped in a 5-second timeout because sympy can hang.
    """

    def _parse(s):
        # Try each parser, first with doubled backslashes collapsed, then raw;
        # if everything fails, keep the unparsed string.
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal (string form or structural equality)
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal (a - b simplifies to 0)
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal: for Eq objects, compare |lhs - rhs| of both sides
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric equal after evaluating both sides to floats
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix
    try:
        # if a and b are matrix: compare element-wise after rounding to 3 d.p.
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def _is_latex_equal(str1, str2):
    """Compare two LaTeX strings symbolically, falling back to normalized forms.

    First tries latex_eval on the raw strings; if that fails (or they differ),
    normalizes both with normalize_final_answer and tries again, finally
    falling back to plain string equality of the normalized forms.

    Fix over the original: if normalize_final_answer itself raised, the old
    code hit `return norm1 == norm2` with `norm1` unbound and leaked a
    NameError; normalization is now done outside that try block.
    """
    try:
        sym1, val1 = latex_eval(str1)
        sym2, val2 = latex_eval(str2)
        if sym1 == sym2 or val1 == val2:
            return True
    except Exception:
        pass

    # Fall back to the normalized representations.
    try:
        norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2)
    except Exception:
        return False
    try:
        sym1, val1 = latex_eval(norm1)
        sym2, val2 = latex_eval(norm2)
        if sym1 == sym2 or val1 == val2:
            return True
    except Exception:
        # Symbolic comparison failed entirely; compare normalized strings.
        return norm1 == norm2
    return False
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def is_latex_equal(given_answer: str, ground_truth: str) -> bool:
    """Check LaTeX answer equality via normalization then math_verify.

    Returns False for highly repetitive long strings (which hang the parsers),
    True on a normalized-string match, otherwise delegates to
    math_verify's parse/verify with tight per-step timeouts. Any exception
    or timeout yields False.

    Fix over the original: `str.replace` returns a new string — the original
    called it on both answers and discarded the result, so newlines were
    never actually stripped; the results are now assigned back.
    """

    def _parse_for_verify(text):
        # Shared math_verify extraction settings for both sides.
        return parse(
            text,
            extraction_config=(
                LatexExtractionConfig(boxed_match_priority=0),
                ExprExtractionConfig(),
            ),
            fallback_mode="no_fallback",
            extraction_mode=["first_match"],
            parsing_timeout=1,
        )

    try:
        with timeout(1):
            try:
                # Degenerate repetitive strings make the parsers hang; bail early.
                if (len(given_answer) > 128 and repeatness(given_answer)) or (
                    len(ground_truth) > 128 and repeatness(ground_truth)
                ):
                    return False
                # First conduct normalized string matching.
                ground_truth_normalized = _normalize(ground_truth)
                given_normalized = _normalize(given_answer)
                if ground_truth_normalized is None:
                    return False
                if ground_truth_normalized == given_normalized:
                    return True

                # Next call math verify. Strip newlines (assigning the result —
                # the original discarded it) and ensure both are $-wrapped.
                given_answer = given_answer.replace("\n", "")
                ground_truth = ground_truth.replace("\n", "")
                if "$" not in given_answer:
                    given_answer = f"${given_answer}$"
                if "$" not in ground_truth:
                    ground_truth = f"${ground_truth}$"
                return verify(
                    _parse_for_verify(ground_truth),
                    _parse_for_verify(given_answer),
                    timeout_seconds=1,
                )
                # or symbolic_equal(ground_truth, given_answer)
            except Exception:
                return False
    except TimeoutError:
        return False
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def is_value_equal(given_answer: str, ground_truth: str) -> bool:
    """Compare mathd-normalized answers as strings, also accepting float equality."""
    assert ground_truth is not None
    gt_norm = mathd_normalize_answer(ground_truth)
    given_norm = mathd_normalize_answer(given_answer)

    same_string = gt_norm == given_norm
    try:
        # Accept "2" vs "2.0" style matches as well.
        return same_string or float(gt_norm) == float(given_norm)
    except Exception:
        return same_string
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# sympy might hang -- we don't care about trying to be lenient in these cases
BAD_SUBSTRINGS = ["^{", "^("]
# Stacked or multi-digit exponents are a known sympy.simplify hang risk.
BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"]
# Characters that may wrap a tuple/interval-style answer.
TUPLE_CHARS = "()[]"
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def _sympy_parse(expr: str):
    """Parses an expression with sympy (with '^' treated as exponentiation)."""
    transformations = sympy_parser.standard_transformations + (
        sympy_parser.implicit_multiplication_application,
    )
    return sympy_parser.parse_expr(expr.replace("^", "**"), transformations=transformations)
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
def _parse_latex(expr: str) -> str:
    """Attempts to parse latex to an expression sympy can read."""
    expr = expr.replace("\\tfrac", "\\frac").replace("\\dfrac", "\\frac")
    expr = expr.replace("\\frac", " \\frac")  # Play nice with mixed numbers.
    expr = latex2text.LatexNodes2Text().latex_to_text(expr)

    # Map the unicode characters the converter emits onto sympy-friendly tokens.
    for src, dst in (("√", "sqrt"), ("π", "pi"), ("∞", "inf"), ("∪", "U"), ("·", "*"), ("×", "*")):
        expr = expr.replace(src, dst)

    return expr.strip()
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def _is_float(num: str) -> bool:
|
| 729 |
+
try:
|
| 730 |
+
float(num)
|
| 731 |
+
return True
|
| 732 |
+
except ValueError:
|
| 733 |
+
return False
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
def _is_int(x: float) -> bool:
|
| 737 |
+
try:
|
| 738 |
+
return abs(x - int(round(x))) <= 1e-7
|
| 739 |
+
except Exception:
|
| 740 |
+
return False
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def _is_frac(expr: str) -> bool:
|
| 744 |
+
return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def _str_is_int(x: str) -> bool:
    """Return True when the comma-stripped string parses to (nearly) an integer."""
    try:
        value = float(_strip_properly_formatted_commas(x))
        return abs(value - int(round(value))) <= 1e-7
    except Exception:
        return False
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def _str_to_int(x: str) -> bool:
|
| 757 |
+
x = x.replace(",", "")
|
| 758 |
+
x = float(x)
|
| 759 |
+
return int(x)
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
def _inject_implicit_mixed_number(step: str):
|
| 763 |
+
"""
|
| 764 |
+
Automatically make a mixed number evalable
|
| 765 |
+
e.g. 7 3/4 => 7+3/4
|
| 766 |
+
"""
|
| 767 |
+
p1 = re.compile("([0-9]) +([0-9])")
|
| 768 |
+
step = p1.sub("\\1+\\2", step) ## implicit mults
|
| 769 |
+
return step
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def _strip_properly_formatted_commas(expr: str):
|
| 773 |
+
# We want to be careful because we don't want to strip tuple commas
|
| 774 |
+
p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)")
|
| 775 |
+
while True:
|
| 776 |
+
next_expr = p1.sub("\\1\\3\\4", expr)
|
| 777 |
+
if next_expr == expr:
|
| 778 |
+
break
|
| 779 |
+
expr = next_expr
|
| 780 |
+
return next_expr
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def _normalize(expr: str) -> str:
    """Normalize answer expressions.

    Strips LaTeX decorations, currency/percent signs, unit words and
    whitespace, evaluates latex to plain text where possible, lowercases,
    and canonicalizes integers. Returns None for None input. Pass order
    below is significant.
    """
    if expr is None:
        return None

    # Remove enclosing `\text{}`.
    m = re.search(r"^\\text\{(?P<text>.+?)\}$", expr)
    if m is not None:
        expr = m.group("text")

    expr = expr.replace("\\%", "%")
    expr = expr.replace("\\$", "$")
    expr = expr.replace("$", "")
    expr = expr.replace("%", "")
    # Treat "or"/"and" between alternatives as list separators.
    expr = expr.replace(" or ", " , ")
    expr = expr.replace(" and ", " , ")

    # Spell out large-number words as powers of ten.
    expr = expr.replace("million", "*10^6")
    expr = expr.replace("billion", "*10^9")
    expr = expr.replace("trillion", "*10^12")

    # Drop unit words (with optional plural suffix and power, e.g. "cm^2").
    for unit in [
        "degree",
        "cm",
        "centimeter",
        "meter",
        "mile",
        "second",
        "minute",
        "hour",
        "day",
        "week",
        "month",
        "year",
        "foot",
        "feet",
        "inch",
        "yard",
    ]:
        expr = re.sub(f"{unit}(es)?(s)? *(\\^[0-9]+)?", "", expr)
    expr = re.sub(r"\^ *\\circ", "", expr)

    # Unwrap a single pair of enclosing braces.
    if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
        expr = expr[1:-1]

    # Remove LaTeX thin-space thousands separators like ",\!".
    expr = re.sub(",\\\\! *", "", expr)
    if _is_float(expr) and _is_int(float(expr)):
        expr = str(int(round(float(expr))))
    if "\\" in expr:
        try:
            expr = _parse_latex(expr)
        except Exception:
            pass

    # edge case with mixed numbers and negative signs
    expr = re.sub("- *", "-", expr)

    expr = _inject_implicit_mixed_number(expr)
    expr = expr.replace(" ", "")

    # if we somehow still have latex braces here, just drop them
    expr = expr.replace("{", "")
    expr = expr.replace("}", "")

    # don't be case sensitive for text answers
    expr = expr.lower()

    if _str_is_int(expr):
        expr = str(_str_to_int(expr))

    return expr
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
def count_unknown_letters_in_expr(expr: str):
    """Count distinct alphabetic characters after removing 'sqrt'/'frac' keywords."""
    cleaned = expr.replace("sqrt", "").replace("frac", "")
    return len({ch for ch in cleaned if ch.isalpha()})
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def should_allow_eval(expr: str):
    """Heuristic gate before sympy evaluation.

    Rejects expressions with more than two unknown letters or any of the
    substrings/regex patterns known to make sympy hang.
    """
    if count_unknown_letters_in_expr(expr) > 2:
        return False

    if any(bad_string in expr for bad_string in BAD_SUBSTRINGS):
        return False

    return all(re.search(bad_regex, expr) is None for bad_regex in BAD_REGEXES)
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
@timeout_ours(timeout_seconds=5)
def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
    """Return True when (ground_truth - given) simplifies to 0 under sympy.

    Any parsing/simplification failure, or an expression the heuristics deem
    unsafe to evaluate, yields False. Wrapped in a 5-second timeout.
    """
    try:
        diff_expr = f"({ground_truth_normalized})-({given_normalized})"
        if not should_allow_eval(diff_expr):
            return False
        return sympy.simplify(_sympy_parse(diff_expr)) == 0
    except Exception:
        return False
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def split_tuple(expr: str):
    """
    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
    """
    expr = _strip_properly_formatted_commas(expr)
    if not expr:
        return []

    looks_like_tuple = (
        len(expr) > 2
        and expr[0] in TUPLE_CHARS
        and expr[-1] in TUPLE_CHARS
        and not any(ch in expr[1:-1] for ch in TUPLE_CHARS)
    )
    if looks_like_tuple:
        return [piece.strip() for piece in expr[1:-1].split(",")]
    return [expr]
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
def last_boxed_only_string(string):
    """Return the last '\\boxed{...}' (or '\\fbox{...}') span in `string`, or None.

    Scans forward from the command, tracking brace depth, and returns the
    span through the matching closing brace; None if the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def remove_boxed(s):
    """Strip a leading '\\boxed{' and trailing '}'; return None for any other shape."""
    prefix = "\\boxed{"
    try:
        if s.startswith(prefix) and s.endswith("}"):
            return s[len(prefix) : -1]
        return None
    except Exception:
        # Non-string input (including None) is simply rejected.
        return None
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def extract_boxed_answer(solution: str) -> str:
    """Extract the answer from inside a LaTeX \\boxed{} command"""
    return remove_boxed(last_boxed_only_string(solution))
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
    """Grade `given_answer` against `ground_truth` with normalization + sympy.

    Exact match after _normalize wins immediately; otherwise both answers are
    split into tuple elements and compared pairwise (strict string match for
    fractions and int-vs-non-int mismatches, sympy equality otherwise).
    """
    ground_truth_normalized = _normalize(ground_truth)
    given_normalized = _normalize(given_answer)

    if ground_truth_normalized is None:
        return False

    if ground_truth_normalized == given_normalized:
        return True

    if len(given_normalized) == 0:
        return False

    ground_truth_elems = split_tuple(ground_truth_normalized)
    given_elems = split_tuple(given_normalized)

    # For multi-element answers the surrounding bracket characters must match
    # too (e.g. open vs closed interval).
    if len(ground_truth_elems) > 1 and (
        ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]
    ):
        is_correct = False
    elif len(ground_truth_elems) != len(given_elems):
        is_correct = False
    else:
        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):
            if _is_frac(ground_truth_elem) and _is_frac(given_elem):
                # if fractions aren't reduced, then shouldn't be marked as correct
                # so, we don't want to allow sympy.simplify in this case
                is_correct = ground_truth_elem == given_elem
            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
                # if the ground truth answer is an integer, we require the given answer to be a strict match
                # (no sympy.simplify)
                is_correct = False
            else:
                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
            if not is_correct:
                break

    return is_correct
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:
    """Return True when both answers normalize identically under mathd rules."""
    # be at least as lenient as mathd
    return mathd_normalize_answer(ground_truth) == mathd_normalize_answer(given_answer)
|
| 1005 |
+
|
| 1006 |
+
|
| 1007 |
+
def extract_answer(passage: str) -> str:
    """Return the content of the last \\boxed{} in `passage`, or None if absent."""
    if "\\boxed" not in passage:
        return None
    return extract_boxed_answer(passage)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def grade(model_answer: str, gt_answer: str, fast: bool = True):
    """Grade an extracted model answer against the ground truth.

    Fast mode uses the mathd and sympy graders only; slow mode additionally
    consults math_verify (is_latex_equal) to recall answers the fast
    graders miss, at some cost and sensitivity to bad inputs.
    """
    if "\\boxed" in gt_answer:
        gt_answer = extract_answer(gt_answer)
    is_right = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer)
    if not fast and not is_right:
        is_right = is_latex_equal(model_answer, gt_answer)
    return is_right
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
def compute_score(model_response, gt_answer, fast=False):
    """Score a model response against one or more ground-truth answers.

    Returns a dict with:
      - score: 1.0 when correct, else 0.0
      - format_score: 1.0 when a \\boxed answer could be extracted, else 0.0
      - acc: boolean correctness
      - extracted_gt: the (possibly stringified) ground truth

    Improvements over the original: duplicated result-dict literals are
    factored out, `isinstance` uses a type tuple, and list ground truths
    short-circuit on the first match instead of grading every candidate.
    """
    model_answer = extract_answer(model_response)
    if model_answer is None:
        # Cannot even parse a boxed answer out of the response.
        return {
            "score": 0.0,
            "format_score": 0.0,
            "acc": False,
            "extracted_gt": gt_answer,
        }

    if isinstance(gt_answer, (int, float)):
        gt_answer = str(gt_answer)

    is_correct = False
    if isinstance(gt_answer, str):
        is_correct = grade(model_answer, gt_answer, fast)
    elif isinstance(gt_answer, list):
        # Any matching candidate counts as correct.
        is_correct = any(grade(model_answer, gt, fast) for gt in gt_answer)

    return {
        "score": 1.0 if is_correct else 0.0,
        "format_score": 1.0,
        "acc": is_correct,
        "extracted_gt": gt_answer,
    }
|
ICL/DAPO/verl-recipe/entropy/reward_score/entropy_math/grader.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright (c) Microsoft Corporation.
|
| 16 |
+
#
|
| 17 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 18 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 19 |
+
# in the Software without restriction, including without limitation the rights
|
| 20 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 21 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 22 |
+
# furnished to do so, subject to the following conditions:
|
| 23 |
+
#
|
| 24 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 25 |
+
# copies or substantial portions of the Software.
|
| 26 |
+
#
|
| 27 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 28 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 29 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 30 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 31 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 32 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 33 |
+
# SOFTWARE
|
| 34 |
+
|
| 35 |
+
# Copyright (c) 2023 OpenAI
|
| 36 |
+
#
|
| 37 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 38 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 39 |
+
# in the Software without restriction, including without limitation the rights
|
| 40 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 41 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 42 |
+
# furnished to do so, subject to the following conditions:
|
| 43 |
+
|
| 44 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 45 |
+
# copies or substantial portions of the Software.
|
| 46 |
+
#
|
| 47 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 48 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 49 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 50 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 51 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 52 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 53 |
+
# SOFTWARE.
|
| 54 |
+
|
| 55 |
+
# Copyright (c) 2021 Dan Hendrycks
|
| 56 |
+
#
|
| 57 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 58 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 59 |
+
# in the Software without restriction, including without limitation the rights
|
| 60 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 61 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 62 |
+
# furnished to do so, subject to the following conditions:
|
| 63 |
+
#
|
| 64 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 65 |
+
# copies or substantial portions of the Software.
|
| 66 |
+
#
|
| 67 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 68 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 69 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 70 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 71 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 72 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 73 |
+
# SOFTWARE.
|
| 74 |
+
|
| 75 |
+
# Copyright 2024 PRIME team and/or its affiliates
|
| 76 |
+
#
|
| 77 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 78 |
+
# you may not use this file except in compliance with the License.
|
| 79 |
+
# You may obtain a copy of the License at
|
| 80 |
+
#
|
| 81 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 82 |
+
#
|
| 83 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 84 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 85 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 86 |
+
# See the License for the specific language governing permissions and
|
| 87 |
+
# limitations under the License.
|
| 88 |
+
"""
|
| 89 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
| 90 |
+
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
|
| 91 |
+
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
| 92 |
+
- https://github.com/openai/prm800k
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
import contextlib
|
| 96 |
+
import math
|
| 97 |
+
import re
|
| 98 |
+
from math import isclose
|
| 99 |
+
|
| 100 |
+
# sympy related
|
| 101 |
+
from sympy import N, simplify
|
| 102 |
+
from sympy.parsing.latex import parse_latex
|
| 103 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 104 |
+
|
| 105 |
+
# verl related
|
| 106 |
+
from verl.utils.py_functional import timeout_limit
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def is_digit(s):
|
| 110 |
+
try:
|
| 111 |
+
if "{,}" in str(s):
|
| 112 |
+
num = float(str(s).replace("{,}", ""))
|
| 113 |
+
return True, num
|
| 114 |
+
|
| 115 |
+
num = float(str(s).replace(",", ""))
|
| 116 |
+
return True, num
|
| 117 |
+
except ValueError:
|
| 118 |
+
return False, None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def normalize(answer, pi) -> str:
|
| 122 |
+
# checking if answer is $<number> and removing $ in that case to compare
|
| 123 |
+
if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)):
|
| 124 |
+
return answer[1:]
|
| 125 |
+
|
| 126 |
+
# checking if answer is <number>% or <number>\\% and removing %
|
| 127 |
+
if isinstance(answer, str) and (
|
| 128 |
+
bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))
|
| 129 |
+
):
|
| 130 |
+
return answer.replace("\\%", "").replace("%", "")
|
| 131 |
+
|
| 132 |
+
# handle base
|
| 133 |
+
answer = handle_base(answer)
|
| 134 |
+
|
| 135 |
+
# handle pi
|
| 136 |
+
answer = handle_pi(answer, pi)
|
| 137 |
+
|
| 138 |
+
return answer
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def handle_base(x) -> str:
|
| 142 |
+
if isinstance(x, str) and "_" in x:
|
| 143 |
+
# Due to base
|
| 144 |
+
x = x.split("_")[0]
|
| 145 |
+
x = float(x)
|
| 146 |
+
return int(x)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def handle_pi(string, pi):
|
| 151 |
+
if isinstance(string, str) and "\\pi" in string:
|
| 152 |
+
# Find the first occurrence of "\pi"
|
| 153 |
+
idx = string.find("\\pi")
|
| 154 |
+
|
| 155 |
+
# Iterate over the string and find all occurrences of "\pi" with a valid previous character
|
| 156 |
+
while idx != -1:
|
| 157 |
+
if idx > 0 and string[idx - 1].isdigit():
|
| 158 |
+
# Replace "\pi" with "*math.pi" if the previous character is a digit
|
| 159 |
+
string = string[:idx] + f"*{pi}" + string[idx + 3 :]
|
| 160 |
+
else:
|
| 161 |
+
# Replace "\pi" with "1*math.pi" if the previous character is not a digit
|
| 162 |
+
string = string[:idx] + f"1*{pi}" + string[idx + 3 :]
|
| 163 |
+
|
| 164 |
+
# Find the next occurrence of "\pi"
|
| 165 |
+
idx = string.find("\\pi", idx + 1)
|
| 166 |
+
|
| 167 |
+
# Evaluate the expression using eval() function
|
| 168 |
+
with contextlib.suppress(Exception):
|
| 169 |
+
string = eval(string)
|
| 170 |
+
|
| 171 |
+
return string
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def math_equal(
|
| 175 |
+
prediction: bool | float | str,
|
| 176 |
+
reference: float | str,
|
| 177 |
+
include_percentage: bool = True,
|
| 178 |
+
tolerance: float = 1e-4,
|
| 179 |
+
timeout: float = 10.0,
|
| 180 |
+
pi: float = math.pi,
|
| 181 |
+
) -> bool:
|
| 182 |
+
"""
|
| 183 |
+
Exact match of math if and only if:
|
| 184 |
+
1. numerical equal: both can convert to float and are equal
|
| 185 |
+
2. symbolic equal: both can convert to sympy expression and are equal
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
prediction = normalize(prediction, pi)
|
| 189 |
+
reference = normalize(reference, pi)
|
| 190 |
+
|
| 191 |
+
if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases
|
| 192 |
+
prediction = prediction[:1000]
|
| 193 |
+
|
| 194 |
+
# 0. string comparison
|
| 195 |
+
if isinstance(prediction, str) and isinstance(reference, str):
|
| 196 |
+
if prediction.strip().lower() == reference.strip().lower():
|
| 197 |
+
return True
|
| 198 |
+
if prediction.replace(" ", "") == reference.replace(" ", ""):
|
| 199 |
+
return True
|
| 200 |
+
|
| 201 |
+
try: # 1. numerical equal
|
| 202 |
+
if is_digit(prediction)[0] and is_digit(reference)[0]:
|
| 203 |
+
prediction = is_digit(prediction)[1]
|
| 204 |
+
reference = is_digit(reference)[1]
|
| 205 |
+
# number questions
|
| 206 |
+
gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference]
|
| 207 |
+
for item in gt_result:
|
| 208 |
+
try:
|
| 209 |
+
if isclose(item, prediction, rel_tol=tolerance):
|
| 210 |
+
return True
|
| 211 |
+
except Exception:
|
| 212 |
+
continue
|
| 213 |
+
return False
|
| 214 |
+
except Exception:
|
| 215 |
+
pass
|
| 216 |
+
|
| 217 |
+
if not prediction and prediction not in [0, False]:
|
| 218 |
+
return False
|
| 219 |
+
|
| 220 |
+
# 2. symbolic equal
|
| 221 |
+
reference = str(reference).strip()
|
| 222 |
+
prediction = str(prediction).strip()
|
| 223 |
+
|
| 224 |
+
## deal with [], (), {}
|
| 225 |
+
prediction = format_intervals(prediction)
|
| 226 |
+
|
| 227 |
+
pred_str, ref_str = prediction, reference
|
| 228 |
+
if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (
|
| 229 |
+
prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")
|
| 230 |
+
):
|
| 231 |
+
pred_str = pred_str.strip("[]()")
|
| 232 |
+
ref_str = ref_str.strip("[]()")
|
| 233 |
+
for s in ["{", "}", "(", ")"]:
|
| 234 |
+
ref_str = ref_str.replace(s, "")
|
| 235 |
+
pred_str = pred_str.replace(s, "")
|
| 236 |
+
if pred_str == ref_str:
|
| 237 |
+
return True
|
| 238 |
+
|
| 239 |
+
## [a, b] vs. [c, d], return a==c and b==d
|
| 240 |
+
if (
|
| 241 |
+
prediction
|
| 242 |
+
and reference
|
| 243 |
+
and prediction[0] in "(["
|
| 244 |
+
and prediction[-1] in ")]"
|
| 245 |
+
and prediction[0] == reference[0]
|
| 246 |
+
and prediction[-1] == reference[-1]
|
| 247 |
+
):
|
| 248 |
+
pred_parts = prediction[1:-1].split(",")
|
| 249 |
+
ref_parts = reference[1:-1].split(",")
|
| 250 |
+
if len(pred_parts) == len(ref_parts) and all(
|
| 251 |
+
[
|
| 252 |
+
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
|
| 253 |
+
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
|
| 254 |
+
]
|
| 255 |
+
):
|
| 256 |
+
return True
|
| 257 |
+
|
| 258 |
+
if "," in prediction and "," in reference:
|
| 259 |
+
pred_parts = [item.strip() for item in prediction.split(",")]
|
| 260 |
+
ref_parts = [item.strip() for item in reference.split(",")]
|
| 261 |
+
|
| 262 |
+
if len(pred_parts) == len(ref_parts):
|
| 263 |
+
return bool(
|
| 264 |
+
all(
|
| 265 |
+
[
|
| 266 |
+
math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance)
|
| 267 |
+
for i in range(len(pred_parts))
|
| 268 |
+
]
|
| 269 |
+
)
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# if we have point == tuple of values
|
| 273 |
+
if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")":
|
| 274 |
+
pred_parts = prediction[prediction.find("(") + 1 : -1].split(",")
|
| 275 |
+
ref_parts = reference[1:-1].split(",")
|
| 276 |
+
if len(pred_parts) == len(ref_parts) and all(
|
| 277 |
+
[
|
| 278 |
+
math_equal(pred_pt, ref_pt, include_percentage, tolerance)
|
| 279 |
+
for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True)
|
| 280 |
+
]
|
| 281 |
+
):
|
| 282 |
+
return True
|
| 283 |
+
|
| 284 |
+
# if reference is a matrix
|
| 285 |
+
if r"\begin{pmatrix}" in reference and prediction.startswith("Matrix"):
|
| 286 |
+
try:
|
| 287 |
+
pred_matrix = parse_expr(prediction)
|
| 288 |
+
ref_matrix_items = reference.split()[1:-1:2]
|
| 289 |
+
if len(pred_matrix) == len(ref_matrix_items) and all(
|
| 290 |
+
[
|
| 291 |
+
math_equal(pred, ref, include_percentage, tolerance)
|
| 292 |
+
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
|
| 293 |
+
]
|
| 294 |
+
):
|
| 295 |
+
return True
|
| 296 |
+
except Exception:
|
| 297 |
+
pass
|
| 298 |
+
elif r"\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"):
|
| 299 |
+
if isinstance(eval(prediction), list):
|
| 300 |
+
try:
|
| 301 |
+
pred_matrix = eval(prediction)
|
| 302 |
+
# ref_matrix_items = reference.split()[1:-1:2]
|
| 303 |
+
ref_matrix_items = (
|
| 304 |
+
reference.removeprefix(r"\\begin{pmatrix}")
|
| 305 |
+
.removeprefix(r"\begin{pmatrix}")
|
| 306 |
+
.removesuffix(r"\\end{pmatrix}")
|
| 307 |
+
.removesuffix(r"\end{pmatrix}")
|
| 308 |
+
)
|
| 309 |
+
ref_matrix_items = ref_matrix_items.split("\\")
|
| 310 |
+
ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items]
|
| 311 |
+
if len(pred_matrix) == len(ref_matrix_items) and all(
|
| 312 |
+
[
|
| 313 |
+
math_equal(pred, ref, include_percentage, tolerance)
|
| 314 |
+
for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True)
|
| 315 |
+
]
|
| 316 |
+
):
|
| 317 |
+
return True
|
| 318 |
+
except Exception:
|
| 319 |
+
pass
|
| 320 |
+
|
| 321 |
+
return symbolic_equal(prediction, reference, tolerance, timeout)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def symbolic_equal(a, b, tolerance, timeout=10.0):
|
| 325 |
+
def _parse(s):
|
| 326 |
+
for f in [parse_expr, parse_latex]:
|
| 327 |
+
try:
|
| 328 |
+
with timeout_limit(seconds=timeout):
|
| 329 |
+
return f(s)
|
| 330 |
+
except TimeoutError:
|
| 331 |
+
print(f"Parsing timed out for {s}")
|
| 332 |
+
continue
|
| 333 |
+
except Exception:
|
| 334 |
+
continue
|
| 335 |
+
return s
|
| 336 |
+
|
| 337 |
+
a = _parse(a)
|
| 338 |
+
b = _parse(b)
|
| 339 |
+
|
| 340 |
+
try:
|
| 341 |
+
with timeout_limit(seconds=timeout):
|
| 342 |
+
if simplify(a - b) == 0:
|
| 343 |
+
return True
|
| 344 |
+
except TimeoutError:
|
| 345 |
+
print(f"Simplification timed out for {a} - {b}")
|
| 346 |
+
pass
|
| 347 |
+
except Exception:
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
try:
|
| 351 |
+
with timeout_limit(seconds=timeout):
|
| 352 |
+
if isclose(N(a), N(b), rel_tol=tolerance):
|
| 353 |
+
return True
|
| 354 |
+
except TimeoutError:
|
| 355 |
+
print(f"Numerical evaluation timed out for {a}, {b}")
|
| 356 |
+
pass
|
| 357 |
+
except Exception:
|
| 358 |
+
pass
|
| 359 |
+
return False
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def format_intervals(prediction):
|
| 363 |
+
patterns = {
|
| 364 |
+
"Interval(": r"^Interval\((.*)\)$",
|
| 365 |
+
"Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$",
|
| 366 |
+
"Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$",
|
| 367 |
+
"Interval.open(": r"^Interval\.open\((.*)\)$",
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
for key, pattern in patterns.items():
|
| 371 |
+
match = re.match(pattern, prediction)
|
| 372 |
+
if match:
|
| 373 |
+
inner_content = match.group(1)
|
| 374 |
+
|
| 375 |
+
if key == "Interval(": # Intarval(a, b) == [a, b]
|
| 376 |
+
return f"[{inner_content}]"
|
| 377 |
+
elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b)
|
| 378 |
+
return f"[{inner_content})"
|
| 379 |
+
elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b]
|
| 380 |
+
return f"({inner_content}]"
|
| 381 |
+
elif key == "Interval.open(": # Intarval.open(a, b) == (a, b)
|
| 382 |
+
return f"({inner_content})"
|
| 383 |
+
|
| 384 |
+
return prediction
|
ICL/DAPO/verl-recipe/fapo/config/rm_config.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
reward_model:
|
| 10 |
+
_target_: verl.workers.config.RewardModelConfig
|
| 11 |
+
|
| 12 |
+
reward_manager: dapo
|
| 13 |
+
enable: False
|
| 14 |
+
|
| 15 |
+
# Whether to deploy the model to a separate resource pool.
|
| 16 |
+
enable_resource_pool: False
|
| 17 |
+
n_gpus_per_node: 0
|
| 18 |
+
nnodes: 0
|
| 19 |
+
|
| 20 |
+
model:
|
| 21 |
+
type: discriminative
|
| 22 |
+
path: ~/models/FsfairX-LLaMA3-RM-v0.1
|
| 23 |
+
external_lib: ${actor_rollout_ref.model.external_lib}
|
| 24 |
+
trust_remote_code: False
|
| 25 |
+
|
| 26 |
+
rollout:
|
| 27 |
+
_target_: verl.workers.config.RolloutConfig
|
| 28 |
+
name: ???
|
| 29 |
+
dtype: bfloat16
|
| 30 |
+
gpu_memory_utilization: 0.5
|
| 31 |
+
enforce_eager: true
|
| 32 |
+
cudagraph_capture_sizes: null
|
| 33 |
+
free_cache_engine: true
|
| 34 |
+
data_parallel_size: 1
|
| 35 |
+
expert_parallel_size: 1
|
| 36 |
+
tensor_model_parallel_size: 2
|
| 37 |
+
max_num_batched_tokens: 8192
|
| 38 |
+
max_model_len: null
|
| 39 |
+
max_num_seqs: 1024
|
| 40 |
+
load_format: auto
|
| 41 |
+
engine_kwargs: {}
|
| 42 |
+
limit_images: null
|
| 43 |
+
enable_chunked_prefill: true
|
| 44 |
+
enable_prefix_caching: true
|
| 45 |
+
disable_log_stats: true
|
| 46 |
+
skip_tokenizer_init: true
|
| 47 |
+
|
| 48 |
+
prompt_length: 512
|
| 49 |
+
response_length: 512
|
ICL/DAPO/verl-recipe/fapo/run_fapo_genrm_train.sh
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
project_name='FAPO-Reproduce'
|
| 5 |
+
exp_name='FAPO-GenRM-4B'
|
| 6 |
+
|
| 7 |
+
adv_estimator=grpo
|
| 8 |
+
|
| 9 |
+
use_kl_in_reward=False
|
| 10 |
+
kl_coef=0.0
|
| 11 |
+
use_kl_loss=False
|
| 12 |
+
kl_loss_coef=0.0
|
| 13 |
+
|
| 14 |
+
clip_ratio_low=0.2
|
| 15 |
+
clip_ratio_high=0.28
|
| 16 |
+
|
| 17 |
+
max_prompt_length=$((1024 * 5))
|
| 18 |
+
max_response_length=$((1024 * 8))
|
| 19 |
+
enable_overlong_buffer=True
|
| 20 |
+
overlong_buffer_len=$((1024 * 4))
|
| 21 |
+
overlong_penalty_factor=1.0
|
| 22 |
+
|
| 23 |
+
loss_agg_mode="token-mean"
|
| 24 |
+
|
| 25 |
+
train_prompt_bsz=512
|
| 26 |
+
n_resp_per_prompt=16
|
| 27 |
+
train_prompt_mini_bsz=32
|
| 28 |
+
|
| 29 |
+
# Ray
|
| 30 |
+
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
| 31 |
+
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
| 32 |
+
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
| 33 |
+
NNODES=${NNODES:-4}
|
| 34 |
+
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
| 35 |
+
# Paths
|
| 36 |
+
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
| 37 |
+
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
|
| 38 |
+
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-4B-Instruct-2507"}
|
| 39 |
+
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
| 40 |
+
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/train.parquet"}
|
| 41 |
+
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/test.parquet"}
|
| 42 |
+
|
| 43 |
+
# Algorithm
|
| 44 |
+
temperature=1.2
|
| 45 |
+
top_p=1.0
|
| 46 |
+
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
| 47 |
+
val_temperature=0.6
|
| 48 |
+
val_top_p=0.95
|
| 49 |
+
|
| 50 |
+
# Performance Related Parameter
|
| 51 |
+
sp_size=1
|
| 52 |
+
use_dynamic_bsz=True
|
| 53 |
+
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
| 54 |
+
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
| 55 |
+
offload=True
|
| 56 |
+
gen_tp=1
|
| 57 |
+
fsdp_size=8
|
| 58 |
+
|
| 59 |
+
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
|
| 60 |
+
--address "${RAY_ADDRESS}" \
|
| 61 |
+
--working-dir "${WORKING_DIR}" \
|
| 62 |
+
-- python3 -m verl.trainer.main_ppo \
|
| 63 |
+
data.train_files="${TRAIN_FILE}" \
|
| 64 |
+
data.val_files="${TEST_FILE}" \
|
| 65 |
+
data.prompt_key=prompt \
|
| 66 |
+
data.truncation='left' \
|
| 67 |
+
data.max_prompt_length=${max_prompt_length} \
|
| 68 |
+
data.max_response_length=${max_response_length} \
|
| 69 |
+
data.train_batch_size=${train_prompt_bsz} \
|
| 70 |
+
data.return_raw_chat=True \
|
| 71 |
+
data.filter_overlong_prompts=True \
|
| 72 |
+
data.truncation='error' \
|
| 73 |
+
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
| 74 |
+
algorithm.adv_estimator=${adv_estimator} \
|
| 75 |
+
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
| 76 |
+
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
| 77 |
+
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
| 78 |
+
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
| 79 |
+
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
| 80 |
+
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
| 81 |
+
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
| 82 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 83 |
+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
| 84 |
+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 85 |
+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 86 |
+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
| 87 |
+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 88 |
+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 89 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 90 |
+
actor_rollout_ref.rollout.mode=async \
|
| 91 |
+
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
| 92 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 93 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 94 |
+
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
| 95 |
+
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
| 96 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
| 97 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
| 98 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
| 99 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 100 |
+
actor_rollout_ref.actor.grad_clip=1.0 \
|
| 101 |
+
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
| 102 |
+
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
| 103 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
| 104 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
| 105 |
+
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
| 106 |
+
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
| 107 |
+
actor_rollout_ref.rollout.temperature=${temperature} \
|
| 108 |
+
actor_rollout_ref.rollout.top_p=${top_p} \
|
| 109 |
+
actor_rollout_ref.rollout.top_k=${top_k} \
|
| 110 |
+
actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
|
| 111 |
+
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
| 112 |
+
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
| 113 |
+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
| 114 |
+
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
| 115 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
| 116 |
+
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
| 117 |
+
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
| 118 |
+
reward_model.reward_manager=dapo \
|
| 119 |
+
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
| 120 |
+
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
| 121 |
+
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
| 122 |
+
+reward_model.reward_kwargs.overlong_buffer_cfg.log=True \
|
| 123 |
+
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
| 124 |
+
custom_reward_function.path=recipe/fapo/reward_fn_genrm.py \
|
| 125 |
+
custom_reward_function.name=compute_score_fapo_genrm \
|
| 126 |
+
trainer.logger='["console","wandb"]' \
|
| 127 |
+
trainer.project_name="${project_name}" \
|
| 128 |
+
trainer.experiment_name="${exp_name}" \
|
| 129 |
+
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
| 130 |
+
trainer.nnodes="${NNODES}" \
|
| 131 |
+
trainer.val_before_train=True \
|
| 132 |
+
trainer.test_freq=10 \
|
| 133 |
+
trainer.save_freq=10 \
|
| 134 |
+
trainer.total_epochs=10 \
|
| 135 |
+
trainer.total_training_steps=500 \
|
| 136 |
+
trainer.default_local_dir="${CKPTS_DIR}" \
|
| 137 |
+
trainer.resume_mode=auto \
|
| 138 |
+
trainer.log_val_generations=10
|
ICL/DAPO/verl-recipe/fapo/runtime_env.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
working_dir: ./
|
| 2 |
+
excludes: ["/.git/"]
|
| 3 |
+
env_vars:
|
| 4 |
+
TORCH_NCCL_AVOID_RECORD_STREAMS: "1"
|
| 5 |
+
VLLM_USE_V1: "1"
|
ICL/DAPO/verl-recipe/fault_recover/agent_loop/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .fault_recover_agent_loop import FaultRecoverAgentLoopManager
|
| 16 |
+
from .fault_recover_single_turn_agent_loop import FaultRecoverSingleTurnAgentLoop
|
| 17 |
+
|
| 18 |
+
_ = [FaultRecoverSingleTurnAgentLoop, FaultRecoverAgentLoopManager]
|
| 19 |
+
|
| 20 |
+
__all__ = ["FaultRecoverSingleTurnAgentLoop", "FaultRecoverAgentLoopManager"]
|
ICL/DAPO/verl-recipe/fault_recover/agent_loop/fault_recover_single_turn_agent_loop.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
from typing import Any
|
| 17 |
+
from uuid import uuid4
|
| 18 |
+
|
| 19 |
+
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
|
| 20 |
+
from verl.tools.utils.tool_registry import initialize_tools_from_config
|
| 21 |
+
from verl.utils.profiler import simple_timer
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__file__)
|
| 24 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@register("fault_recover_single_turn_agent")
|
| 28 |
+
class FaultRecoverSingleTurnAgentLoop(AgentLoopBase):
|
| 29 |
+
"""Naive agent loop that only do single turn chat completion."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, *args, **kwargs):
|
| 32 |
+
super().__init__(*args, **kwargs)
|
| 33 |
+
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
|
| 34 |
+
self.response_length = self.config.actor_rollout_ref.rollout.response_length
|
| 35 |
+
|
| 36 |
+
tool_config_path = self.config.data.tool_config_path
|
| 37 |
+
tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
|
| 38 |
+
self.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
|
| 39 |
+
|
| 40 |
+
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
|
| 41 |
+
messages = list(kwargs["raw_prompt"])
|
| 42 |
+
|
| 43 |
+
# 1. extract images and videos from messages
|
| 44 |
+
multi_modal_data = await self.process_vision_info(messages)
|
| 45 |
+
images = multi_modal_data.get("images")
|
| 46 |
+
videos = multi_modal_data.get("videos")
|
| 47 |
+
|
| 48 |
+
# 2. apply chat template and tokenize
|
| 49 |
+
prompt_ids = await self.apply_chat_template(
|
| 50 |
+
messages,
|
| 51 |
+
tools=self.tool_schemas,
|
| 52 |
+
images=images,
|
| 53 |
+
videos=videos,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# 3. generate sequences
|
| 57 |
+
metrics = {}
|
| 58 |
+
request_id = uuid4().hex
|
| 59 |
+
new_token_ids = kwargs.get("new_token_ids", [])
|
| 60 |
+
finished = kwargs.get("finished", False)
|
| 61 |
+
num_preempted = kwargs.get("num_preempted")
|
| 62 |
+
if finished:
|
| 63 |
+
with simple_timer("generate_sequences", metrics):
|
| 64 |
+
response_mask = [1] * len(new_token_ids)
|
| 65 |
+
if metrics.get("num_preempted") is None:
|
| 66 |
+
metrics["num_preempted"] = num_preempted if num_preempted is not None else -1
|
| 67 |
+
return AgentLoopOutput(
|
| 68 |
+
prompt_ids=prompt_ids,
|
| 69 |
+
response_ids=new_token_ids[: self.response_length],
|
| 70 |
+
response_mask=response_mask[: self.response_length],
|
| 71 |
+
response_logprobs=kwargs.get("log_probs"),
|
| 72 |
+
routed_experts=kwargs.get("routed_experts"),
|
| 73 |
+
multi_modal_data=multi_modal_data,
|
| 74 |
+
num_turns=2,
|
| 75 |
+
metrics=metrics,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
origin_prompt_length = len(prompt_ids)
|
| 79 |
+
prompt_ids += new_token_ids
|
| 80 |
+
|
| 81 |
+
with simple_timer("generate_sequences", metrics):
|
| 82 |
+
output = await self.server_manager.generate(
|
| 83 |
+
request_id=request_id,
|
| 84 |
+
prompt_ids=prompt_ids,
|
| 85 |
+
sampling_params=sampling_params,
|
| 86 |
+
image_data=images,
|
| 87 |
+
video_data=videos,
|
| 88 |
+
global_id=kwargs.get("global_id"),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if metrics.get("num_preempted") is None:
|
| 92 |
+
metrics["num_preempted"] = output.num_preempted if output.num_preempted is not None else -1
|
| 93 |
+
|
| 94 |
+
all_token_ids = new_token_ids + output.token_ids
|
| 95 |
+
response_mask = [1] * len(all_token_ids)
|
| 96 |
+
|
| 97 |
+
output = AgentLoopOutput(
|
| 98 |
+
prompt_ids=prompt_ids[:origin_prompt_length],
|
| 99 |
+
response_ids=all_token_ids[: self.response_length],
|
| 100 |
+
response_mask=response_mask[: self.response_length],
|
| 101 |
+
response_logprobs=output.log_probs[: self.response_length] if output.log_probs else None,
|
| 102 |
+
routed_experts=(
|
| 103 |
+
output.routed_experts[: len(prompt_ids) + self.response_length]
|
| 104 |
+
if output.routed_experts is not None
|
| 105 |
+
else None
|
| 106 |
+
),
|
| 107 |
+
multi_modal_data=multi_modal_data,
|
| 108 |
+
num_turns=2,
|
| 109 |
+
metrics=metrics,
|
| 110 |
+
)
|
| 111 |
+
return output
|
ICL/DAPO/verl-recipe/fault_recover/config/fault_recover_ppo_megatron_trainer.yaml
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
# specify the default per-component configs
|
| 6 |
+
defaults:
|
| 7 |
+
# <folder_name>@<field_name>.<field_name>: <yaml_file_name>
|
| 8 |
+
# actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml
|
| 9 |
+
- actor@actor_rollout_ref.actor: megatron_actor
|
| 10 |
+
# data: trainer/config/data/legacy_data.yaml
|
| 11 |
+
- data@data: legacy_data
|
| 12 |
+
# (Rule-based) Reward manager config.
|
| 13 |
+
- reward_manager@reward_manager
|
| 14 |
+
# load the reference default config, then apply the fields in the current yaml
|
| 15 |
+
# Reference model config.
|
| 16 |
+
# Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True.
|
| 17 |
+
- ref@actor_rollout_ref.ref: megatron_ref
|
| 18 |
+
# Rollout model config.
|
| 19 |
+
- rollout@actor_rollout_ref.rollout: rollout
|
| 20 |
+
# Model config.
|
| 21 |
+
- model@actor_rollout_ref.model: hf_model
|
| 22 |
+
# Critic model config.
|
| 23 |
+
- critic@critic: megatron_critic
|
| 24 |
+
# Reward model config.
|
| 25 |
+
- reward_model@reward_model: megatron_reward_loop
|
| 26 |
+
# Rollout correction config.
|
| 27 |
+
- algorithm@algorithm.rollout_correction: rollout_correction
|
| 28 |
+
- _self_
|
| 29 |
+
|
| 30 |
+
actor_rollout_ref:
|
| 31 |
+
hybrid_engine: True
|
| 32 |
+
|
| 33 |
+
nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
|
| 34 |
+
|
| 35 |
+
model:
|
| 36 |
+
override_config:
|
| 37 |
+
model_config: {}
|
| 38 |
+
moe_config:
|
| 39 |
+
freeze_moe_router: False
|
| 40 |
+
|
| 41 |
+
use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency)
|
| 42 |
+
|
| 43 |
+
trust_remote_code: False
|
| 44 |
+
|
| 45 |
+
# Whether to remove padding tokens in inputs during training
|
| 46 |
+
use_remove_padding: false
|
| 47 |
+
|
| 48 |
+
# LoRA (Low-Rank Adaptation) configuration for parameter-efficient fine-tuning
|
| 49 |
+
lora:
|
| 50 |
+
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
|
| 51 |
+
type: lora
|
| 52 |
+
|
| 53 |
+
# LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA
|
| 54 |
+
rank: 0 # typical values: 8, 16, 32, 64
|
| 55 |
+
|
| 56 |
+
# Weighting factor for the low-rank projection. Defaults to 32
|
| 57 |
+
alpha: 32
|
| 58 |
+
|
| 59 |
+
# Dropout rate for the low-rank projection. Defaults to 0.0
|
| 60 |
+
dropout: 0.0
|
| 61 |
+
|
| 62 |
+
# A list of module names to apply LoRA to.
|
| 63 |
+
# For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'].
|
| 64 |
+
# For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"]
|
| 65 |
+
# - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention
|
| 66 |
+
# - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention
|
| 67 |
+
# - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP
|
| 68 |
+
# - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP
|
| 69 |
+
# Target modules can also contain wildcards. For example, you can specify
|
| 70 |
+
# target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers
|
| 71 |
+
target_modules:
|
| 72 |
+
- linear_qkv
|
| 73 |
+
- linear_proj
|
| 74 |
+
- linear_fc1
|
| 75 |
+
- linear_fc2
|
| 76 |
+
|
| 77 |
+
# A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name
|
| 78 |
+
# does not match any string in exclude_modules. If used, will require target_modules to be empty list or None
|
| 79 |
+
exclude_modules: []
|
| 80 |
+
|
| 81 |
+
# Position for applying dropout, can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre'
|
| 82 |
+
dropout_position: pre
|
| 83 |
+
|
| 84 |
+
# Initialization method for the low-rank matrix A. Defaults to "xavier".
|
| 85 |
+
lora_A_init_method: xavier
|
| 86 |
+
|
| 87 |
+
# Initialization method for the low-rank matrix B. Defaults to "zero".
|
| 88 |
+
lora_B_init_method: zero
|
| 89 |
+
|
| 90 |
+
# Enables the experimental All-to-All (A2A) communication strategy. Defaults to False
|
| 91 |
+
a2a_experimental: False
|
| 92 |
+
|
| 93 |
+
# Parameter data type for LoRA weights. Default to null, which will use model's dtype.
|
| 94 |
+
dtype: null
|
| 95 |
+
|
| 96 |
+
# Path to pre-trained LoRA adapter weights (null to train from scratch)
|
| 97 |
+
adapter_path: null
|
| 98 |
+
|
| 99 |
+
# VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen.
|
| 100 |
+
# For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully
|
| 101 |
+
# finetune the vision model.
|
| 102 |
+
freeze_vision_model: True
|
| 103 |
+
freeze_vision_projection: True
|
| 104 |
+
freeze_language_model: True
|
| 105 |
+
|
| 106 |
+
rollout:
|
| 107 |
+
quantization: null
|
| 108 |
+
|
| 109 |
+
layer_name_map:
|
| 110 |
+
qkv_layer_name: qkv
|
| 111 |
+
gate_proj_layer_name: gate_up
|
| 112 |
+
|
| 113 |
+
custom_reward_function:
|
| 114 |
+
path: null
|
| 115 |
+
name: compute_score
|
| 116 |
+
|
| 117 |
+
algorithm:
|
| 118 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 119 |
+
_target_: verl.trainer.config.AlgoConfig
|
| 120 |
+
gamma: 1.0
|
| 121 |
+
lam: 1.0
|
| 122 |
+
adv_estimator: gae
|
| 123 |
+
norm_adv_by_std_in_grpo: True
|
| 124 |
+
use_kl_in_reward: False
|
| 125 |
+
kl_penalty: kl # how to estimate kl divergence
|
| 126 |
+
kl_ctrl:
|
| 127 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 128 |
+
_target_: verl.trainer.config.KLControlConfig
|
| 129 |
+
type: fixed
|
| 130 |
+
kl_coef: 0.001
|
| 131 |
+
horizon: 10000
|
| 132 |
+
target_kl: 0.1
|
| 133 |
+
use_pf_ppo: False
|
| 134 |
+
pf_ppo:
|
| 135 |
+
reweight_method: pow # ["pow", "max_min", "max_random"]
|
| 136 |
+
weight_pow: 2.0
|
| 137 |
+
|
| 138 |
+
trainer:
|
| 139 |
+
balance_batch: True
|
| 140 |
+
total_epochs: 30
|
| 141 |
+
total_training_steps: null
|
| 142 |
+
project_name: verl_examples
|
| 143 |
+
experiment_name: gsm8k
|
| 144 |
+
logger: ["console", "wandb"]
|
| 145 |
+
log_val_generations: 0
|
| 146 |
+
nnodes: 1
|
| 147 |
+
n_gpus_per_node: 8
|
| 148 |
+
save_freq: -1
|
| 149 |
+
esi_redundant_time: 0
|
| 150 |
+
|
| 151 |
+
# auto: find the last ckpt to resume. If can't find, start from scratch
|
| 152 |
+
resume_mode: auto # or disable or resume_path if resume_from_path is set
|
| 153 |
+
resume_from_path: null
|
| 154 |
+
del_local_ckpt_after_load: False
|
| 155 |
+
val_before_train: True
|
| 156 |
+
test_freq: -1
|
| 157 |
+
critic_warmup: 0
|
| 158 |
+
default_hdfs_dir: null
|
| 159 |
+
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
| 160 |
+
max_actor_ckpt_to_keep: null
|
| 161 |
+
max_critic_ckpt_to_keep: null
|
| 162 |
+
# The timeout for ray worker group to wait for the register center to be ready
|
| 163 |
+
ray_wait_register_center_timeout: 300
|
| 164 |
+
device: cuda
|
| 165 |
+
# Directory for logging rollout data; no dump if null
|
| 166 |
+
rollout_data_dir: null
|
| 167 |
+
|
| 168 |
+
# whether to use legacy worker implementation
|
| 169 |
+
# mode: "auto", "enable", or "disable"
|
| 170 |
+
use_legacy_worker_impl: auto
|
| 171 |
+
|
| 172 |
+
global_profiler:
|
| 173 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 174 |
+
tool: null # choose between nsys, npu, torch, torch_memory
|
| 175 |
+
steps: null # profile steps
|
| 176 |
+
profile_continuous_steps: False
|
| 177 |
+
save_path: "outputs/profile" # profiler saving path
|
| 178 |
+
# Specific tool configs, can use +profiler.tool_config.[tool].xxx to config
|
| 179 |
+
global_tool_config:
|
| 180 |
+
# nsys config
|
| 181 |
+
nsys:
|
| 182 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 183 |
+
discrete: False
|
| 184 |
+
|
| 185 |
+
# controller Nvidia Nsight Systems Options. Must set when profile_steps is not None.
|
| 186 |
+
## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html
|
| 187 |
+
## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html
|
| 188 |
+
controller_nsight_options:
|
| 189 |
+
# Select the API(s) to be traced.
|
| 190 |
+
trace: "cuda,nvtx,cublas,ucx"
|
| 191 |
+
|
| 192 |
+
# Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
|
| 193 |
+
cuda-memory-usage: "true"
|
| 194 |
+
|
| 195 |
+
# CUDA graphs will be traced as a whole
|
| 196 |
+
cuda-graph-trace: "graph"
|
| 197 |
+
|
| 198 |
+
# worker Nvidia Nsight Systems Options. Must set when profile_steps is not None.
|
| 199 |
+
worker_nsight_options:
|
| 200 |
+
# Select the API(s) to be traced.
|
| 201 |
+
trace: "cuda,nvtx,cublas,ucx"
|
| 202 |
+
|
| 203 |
+
# Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
|
| 204 |
+
cuda-memory-usage: "true"
|
| 205 |
+
|
| 206 |
+
# CUDA graphs will be traced as a whole
|
| 207 |
+
cuda-graph-trace: "graph"
|
| 208 |
+
|
| 209 |
+
# Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.
|
| 210 |
+
capture-range: "cudaProfilerApi"
|
| 211 |
+
|
| 212 |
+
# Specify the desired behavior when a capture range ends.
|
| 213 |
+
# In verl we need the torch.cuda.profiler.start/stop pair to repeats n times.
|
| 214 |
+
# valid values are "repeat-shutdown:n" or null.
|
| 215 |
+
# For normal whole step profiling, n = len(profile_steps);
|
| 216 |
+
# but for discrete profiling, n = len(profile_steps) * Number(subtasks).
|
| 217 |
+
# Or you can just leave it null and the program will use n = len(profile_steps) * 6;
|
| 218 |
+
capture-range-end: null
|
| 219 |
+
|
| 220 |
+
# Send signal to the target application's process group. We let the program to exit by itself.
|
| 221 |
+
kill: none
|
| 222 |
+
|
| 223 |
+
# enable memory visualization for debugging memory usage
|
| 224 |
+
torch_memory:
|
| 225 |
+
# Maximum number of allocation entries to record
|
| 226 |
+
trace_alloc_max_entries: 100_000
|
| 227 |
+
# The depth of the call stack to capture for each allocation
|
| 228 |
+
stack_depth: 32
|
| 229 |
+
# 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both.
|
| 230 |
+
context: "all"
|
| 231 |
+
# 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both.
|
| 232 |
+
stacks: "all"
|
| 233 |
+
# devices, record_context etc.
|
| 234 |
+
kw_args: {}
|
| 235 |
+
|
| 236 |
+
# configs for TransferQueue
|
| 237 |
+
transfer_queue:
|
| 238 |
+
# Whether to enable transfer queue
|
| 239 |
+
enable: False
|
| 240 |
+
|
| 241 |
+
ray_kwargs:
|
| 242 |
+
ray_init:
|
| 243 |
+
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
|
| 244 |
+
timeline_json_file: null
|
| 245 |
+
|
| 246 |
+
fault_manager:
|
| 247 |
+
enable: False
|
| 248 |
+
# max retry times for other training phases except rollout (restart ray)
|
| 249 |
+
max_reschedule_times: 1
|
| 250 |
+
# max retry times for rollout phase (rebuild worker group)
|
| 251 |
+
max_rebuild_times: 1
|
| 252 |
+
# timeout of waiting cluster to be ready for reschedule
|
| 253 |
+
timeout_reschedule: 300
|
| 254 |
+
# timeout of waiting cluster to be ready for rebuild
|
| 255 |
+
timeout_rebuild: 300
|
| 256 |
+
# check chips usage interval during rollout, set -1 to disable timeout check
|
| 257 |
+
timeout_task_check_interval: 10
|
| 258 |
+
# timeout of chips usage being free, set -1 to disable chip check and
|
| 259 |
+
# 'timeout_task_check_interval' will be the whole time limit of rollout
|
| 260 |
+
# which means you should increase it
|
| 261 |
+
timeout_chip_free: 30
|
| 262 |
+
# file path for token saving
|
| 263 |
+
tokens_save_file: ./tokens_ckpt/tokens.pt
|
| 264 |
+
# interval of saving tokens to disk
|
| 265 |
+
tokens_save_interval: 10
|
ICL/DAPO/verl-recipe/fault_recover/vllm_rollout/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
ICL/DAPO/verl-recipe/fault_recover/vllm_rollout/vllm_async_server.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import inspect
|
| 16 |
+
import logging
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
import ray
|
| 20 |
+
import vllm
|
| 21 |
+
from packaging import version
|
| 22 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
| 23 |
+
from vllm.entrypoints.openai.api_server import (
|
| 24 |
+
build_app,
|
| 25 |
+
init_app_state,
|
| 26 |
+
)
|
| 27 |
+
from vllm.usage.usage_lib import UsageContext
|
| 28 |
+
|
| 29 |
+
from verl.workers.config import HFModelConfig, RolloutConfig
|
| 30 |
+
from verl.workers.rollout.utils import run_unvicorn
|
| 31 |
+
from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMHttpServer, vLLMReplica
|
| 32 |
+
|
| 33 |
+
_VLLM_VERSION = version.parse(vllm.__version__)
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__file__)
|
| 36 |
+
logger.setLevel(logging.INFO)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class FaultRecovervLLMHttpServer(vLLMHttpServer):
    """vLLM http server in single node, this is equivalent to launch server with command line:
    ```
    vllm serve --tensor-parallel-size=8 ...
    ```

    Fault-recovery variant of ``vLLMHttpServer``: the engine is built from the
    recipe-local ``AsyncFaultRecoverLLM`` instead of vLLM's stock ``AsyncLLM``,
    and :meth:`clear_engine` is exposed so the engine can be shut down (e.g. for
    a rebuild after a fault).
    """

    async def run_server(self, args: argparse.Namespace):
        """Create the async engine and start serving the OpenAI-compatible app.

        Args:
            args: vLLM CLI-style namespace; consumed by
                ``AsyncEngineArgs.from_cli_args``, ``build_app`` and
                ``init_app_state``.
        """
        # Use the fault-recovery AsyncLLM implementation from this recipe.
        from recipe.fault_recover.async_llm import AsyncFaultRecoverLLM as AsyncLLM

        engine_args = AsyncEngineArgs.from_cli_args(args)
        usage_context = UsageContext.OPENAI_API_SERVER
        vllm_config = engine_args.create_engine_config(usage_context=usage_context)
        # DP master port is assigned by the replica; override whatever the CLI produced.
        vllm_config.parallel_config.data_parallel_master_port = self._dp_master_port

        # from_vllm_config's signature differs across vLLM versions; only forward
        # the logging kwargs that this installed version actually accepts.
        fn_args = set(dict(inspect.signature(AsyncLLM.from_vllm_config).parameters).keys())
        kwargs = {}
        if "enable_log_requests" in fn_args:
            kwargs["enable_log_requests"] = engine_args.enable_log_requests
        if "disable_log_stats" in fn_args:
            kwargs["disable_log_stats"] = engine_args.disable_log_stats

        engine_client = AsyncLLM.from_vllm_config(vllm_config=vllm_config, usage_context=usage_context, **kwargs)

        # Don't keep the dummy data in memory
        await engine_client.reset_mm_cache()
        # Patch the model on every engine worker so its vocab matches the tokenizer.
        await engine_client.collective_rpc(
            method="monkey_patch_model", kwargs={"vocab_size": len(self.model_config.tokenizer)}
        )

        # build_app gained a `supported_tasks` parameter in newer vLLM versions;
        # probe the signature instead of pinning a version.
        build_app_sig = inspect.signature(build_app)
        supported_tasks: tuple[Any, ...] = ()
        if "supported_tasks" in build_app_sig.parameters:
            supported_tasks = await engine_client.get_supported_tasks()
            app = build_app(args, supported_tasks)
        else:
            app = build_app(args)

        # init_app_state likewise changed shape across versions: it may take the
        # vllm_config, the supported_tasks tuple, or neither.
        init_app_sig = inspect.signature(init_app_state)
        if "vllm_config" in init_app_sig.parameters:
            await init_app_state(engine_client, vllm_config, app.state, args)
        elif "supported_tasks" in init_app_sig.parameters:
            await init_app_state(engine_client, app.state, args, supported_tasks)
        else:
            await init_app_state(engine_client, app.state, args)
        # Log the full engine config only once across the whole deployment.
        if self.replica_rank == 0 and self.node_rank == 0:
            logger.info(f"Initializing a V1 LLM engine with config: {vllm_config}")

        self.engine = engine_client
        self._server_port, self._server_task = await run_unvicorn(app, args, self._server_address)

    def clear_engine(self):
        """Shut down the underlying engine so it can be rebuilt on recovery."""
        self.engine.shutdown()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class FaultRecovervLLMReplica(vLLMReplica):
    """vLLM replica whose HTTP servers are the fault-recovery variant.

    Identical to ``vLLMReplica`` except that the Ray actor class used to spawn
    per-node HTTP servers is ``FaultRecovervLLMHttpServer``.
    """

    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
        is_reward_model: bool = False,
    ):
        super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model)
        # Replace the server actor class set by the base __init__ with the
        # fault-recovery server so spawned actors support clear_engine().
        self.server_class = ray.remote(FaultRecovervLLMHttpServer)
|
ICL/DAPO/verl-recipe/flowrl/config/flowrl_trainer.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
data:
|
| 10 |
+
gen_batch_size: ${data.train_batch_size}
|
| 11 |
+
|
| 12 |
+
reward_model:
|
| 13 |
+
reward_manager: dapo
|
| 14 |
+
overlong_buffer:
|
| 15 |
+
enable: False # We try to avoid forgetting to set enable
|
| 16 |
+
len: 0
|
| 17 |
+
penalty_factor: 0.0
|
| 18 |
+
log: False
|
| 19 |
+
|
| 20 |
+
algorithm:
|
| 21 |
+
# _target_: verl.trainer.config.AlgoConfig
|
| 22 |
+
|
| 23 |
+
# # FlowRL trajectory balance coefficient (β)
|
| 24 |
+
# tb_coef: 15.0
|
| 25 |
+
|
| 26 |
+
filter_groups:
|
| 27 |
+
_target_: verl.trainer.config.FilterGroupsConfig
|
| 28 |
+
enable: False # We try to avoid forgetting to set enable
|
| 29 |
+
metric: null # acc / score / seq_reward / seq_final_reward / ...
|
| 30 |
+
max_num_gen_batches: 0 # Non-positive values mean no upper limit
|
| 31 |
+
|
| 32 |
+
trainer:
|
| 33 |
+
project_name: verl-flowrl
|
ICL/DAPO/verl-recipe/flowrl/figures/file.svg
ADDED
|
|
ICL/DAPO/verl-recipe/flowrl/figures/flowrl.pdf
ADDED
|
Binary file (52.5 kB). View file
|
|
|
ICL/DAPO/verl-recipe/flowrl/prepare/prepare_data.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Download the DAPO-Math-17k training set and the AIME-2024 test set from the
# HuggingFace Hub and place their parquet files under ${DATA_DIR}.
#
# -u: fail on unset variables, -x: trace commands, -o pipefail: propagate
# pipeline failures. (-e is deliberately omitted so a failure on one dataset
# does not abort the rest of the script.)
set -uxo pipefail

export DOWNLOAD_DIR=${DOWNLOAD_DIR:-"downloads"}
export DATA_DIR=${DATA_DIR:-"${DOWNLOAD_DIR}/data"}

# Create final data directory
mkdir -p "${DATA_DIR}"

# Download DAPO-Math-17k dataset.
# All expansions are quoted to guard against word splitting / globbing.
DATASET_NAME_TRAIN="BytedTsinghua-SIA/DAPO-Math-17k"
echo "Downloading ${DATASET_NAME_TRAIN}..."
huggingface-cli download "${DATASET_NAME_TRAIN}" \
    --repo-type dataset \
    --resume-download \
    --local-dir "${DOWNLOAD_DIR}/${DATASET_NAME_TRAIN}" \
    --local-dir-use-symlinks False

# Move the parquet file to data directory; warn instead of silently skipping
# when the expected file is absent (e.g. the download failed).
if [ -f "${DOWNLOAD_DIR}/${DATASET_NAME_TRAIN}/data/dapo-math-17k.parquet" ]; then
    mv "${DOWNLOAD_DIR}/${DATASET_NAME_TRAIN}/data/dapo-math-17k.parquet" "${DATA_DIR}/dapo-math-17k.parquet"
    echo "✓ Moved dapo-math-17k.parquet to ${DATA_DIR}/"
else
    echo "WARNING: dapo-math-17k.parquet not found under ${DOWNLOAD_DIR}/${DATASET_NAME_TRAIN}/data/" >&2
fi

# Download AIME-2024 dataset
DATASET_NAME_TEST="BytedTsinghua-SIA/AIME-2024"
echo "Downloading ${DATASET_NAME_TEST}..."
huggingface-cli download "${DATASET_NAME_TEST}" \
    --repo-type dataset \
    --resume-download \
    --local-dir "${DOWNLOAD_DIR}/${DATASET_NAME_TEST}" \
    --local-dir-use-symlinks False

# Move the parquet file to data directory
if [ -f "${DOWNLOAD_DIR}/${DATASET_NAME_TEST}/data/aime-2024.parquet" ]; then
    mv "${DOWNLOAD_DIR}/${DATASET_NAME_TEST}/data/aime-2024.parquet" "${DATA_DIR}/aime-2024.parquet"
    echo "✓ Moved aime-2024.parquet to ${DATA_DIR}/"
else
    echo "WARNING: aime-2024.parquet not found under ${DOWNLOAD_DIR}/${DATASET_NAME_TEST}/data/" >&2
fi

echo ""
echo "Data preparation completed!"
echo "Training file: ${DATA_DIR}/dapo-math-17k.parquet"
echo "Test file: ${DATA_DIR}/aime-2024.parquet"
|
ICL/DAPO/verl-recipe/flowrl/prepare/prepare_model.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Download the Qwen2.5-7B base model weights into downloads/models/,
# skipping legacy .pth checkpoint files.

MODEL_NAME=Qwen/Qwen2.5-7B

# Quote the exclude pattern: an unquoted *.pth would be glob-expanded by the
# shell if any matching file exists in the current directory, silently
# changing what gets excluded. Variable expansions are quoted for the same
# reason (word splitting / globbing).
huggingface-cli download "${MODEL_NAME}" \
    --repo-type model \
    --resume-download \
    --local-dir "downloads/models/${MODEL_NAME}" \
    --local-dir-use-symlinks False \
    --exclude "*.pth"
|
ICL/DAPO/verl-recipe/gkd/megatron/megatron_utils.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
# Copyright 2023-2024 SGLang Team
|
| 4 |
+
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
| 5 |
+
# Copyright 2025 Individual Contributor: Brilliant Hanabi, furunding
|
| 6 |
+
#
|
| 7 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 8 |
+
# you may not use this file except in compliance with the License.
|
| 9 |
+
# You may obtain a copy of the License at
|
| 10 |
+
#
|
| 11 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 12 |
+
#
|
| 13 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 14 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 15 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 16 |
+
# See the License for the specific language governing permissions and
|
| 17 |
+
# limitations under the License.
|
| 18 |
+
import torch
|
| 19 |
+
from megatron.core import parallel_state as mpu
|
| 20 |
+
|
| 21 |
+
import verl.utils.megatron.tensor_parallel as tp_utils
|
| 22 |
+
from verl.utils.device import get_device_id
|
| 23 |
+
from verl.utils.megatron_utils import default_tp_concat_fn, unwrap_model
|
| 24 |
+
from verl.utils.model import normalize_model_name
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def per_tensor_generator(
    actor_module,
    model_config,
    weight_converter,
    transformer_config,
    layer_name_mapping,
    convert_qkv_gate_up_by_simple_split=True,
):
    """Lazily yield ``(hf_name, tensor)`` pairs covering the full model.

    Walks every virtual-pipeline chunk of the Megatron ``actor_module``,
    reshards each parameter across pipeline / tensor / expert parallel groups
    via collective ops, and converts names and layouts to the HuggingFace
    convention through ``weight_converter``.

    NOTE(review): this is collective code — every rank in the PP/TP/EP groups
    must iterate the generator in lockstep, otherwise the send/recv/gather
    calls below deadlock.

    Args:
        actor_module: list of per-VPP-chunk Megatron model modules.
        model_config: HF-style model config (``tie_word_embeddings`` is read).
        weight_converter: object exposing ``mcore_config``, ``hf_config`` and
            ``convert_param(name, params)``.
        transformer_config: Megatron transformer config used by
            ``normalize_model_name``.
        layer_name_mapping: mapping forwarded to ``default_tp_concat_fn``.
        convert_qkv_gate_up_by_simple_split: forwarded to
            ``default_tp_concat_fn``; presumably selects a simple-split qkv /
            gate-up conversion — TODO confirm against that helper.

    Yields:
        ``(converted_name, detached_tensor)`` pairs.
    """
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    ep_rank = mpu.get_expert_model_parallel_rank()
    etp_rank = mpu.get_expert_tensor_parallel_rank()
    ep_size = mpu.get_expert_model_parallel_world_size()
    etp_size = mpu.get_expert_tensor_parallel_world_size()
    ep_group = mpu.get_expert_model_parallel_group()
    etp_group = mpu.get_expert_tensor_parallel_group()
    # Number of virtual-pipeline chunks held by this rank.
    vpp_size = len(actor_module)
    tp_group = mpu.get_tensor_model_parallel_group()
    tp_size = torch.distributed.get_world_size(group=tp_group)

    def tensor_generator():
        # Yield this rank's local (name, param) pairs in deterministic order;
        # consumed one-at-a-time below so the whole model never sits in memory.
        for scan_vpp_idx in range(vpp_size):
            existing_keys = set()
            model = unwrap_model(actor_module[scan_vpp_idx])
            for name, param in model.named_parameters():
                existing_keys.add(name)
                yield name, param
            # note
            # there is a bug in megatron GPTModel
            # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in
            # state_dict(). for now we patch it by adding those keys to extra_keys.
            extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys]
            for name in extra_keys:
                yield name, model.state_dict()[name].to(get_device_id())

    def get_tensor_spec(tensor):
        # Capture enough metadata to reconstruct an empty tensor of the same
        # shape/dtype with Megatron's TP attributes on a remote rank.
        shape = tensor.shape
        dtype = tensor.dtype
        tensor_parallel = getattr(tensor, "tensor_model_parallel", None)
        partition_dim = getattr(tensor, "partition_dim", None)
        tensor_spec = (shape, dtype, tensor_parallel, partition_dim)
        return tensor_spec

    def make_tensor(tensor_spec):
        # Inverse of get_tensor_spec: allocate an uninitialized device tensor
        # and restore the Megatron TP attributes when present.
        tensor = torch.empty(size=tensor_spec[0], dtype=tensor_spec[1], device=get_device_id())
        if tensor_spec[2] is not None:
            tensor.tensor_model_parallel = tensor_spec[2]
        if tensor_spec[3] is not None:
            tensor.partition_dim = tensor_spec[3]
        return tensor

    # we need first make all rank get full model information
    meta_info = []
    for scan_vpp_idx in range(vpp_size):
        existing_keys = set()
        model = unwrap_model(actor_module[scan_vpp_idx])
        for idx, (name, param) in enumerate(model.named_parameters()):
            existing_keys.add(name)
            meta_info.append((pp_rank, scan_vpp_idx, idx, name, get_tensor_spec(param)))
        extra_keys = [
            (x, y) for x, y in model.state_dict().items() if "_extra_state" not in x and x not in existing_keys
        ]
        # NOTE(review): `idx` here is the stale value left over from the
        # enumerate() loop above, so all extra keys share one index. It is
        # never consumed below, but confirm before relying on it.
        for name, param in extra_keys:
            meta_info.append((pp_rank, scan_vpp_idx, idx, name, get_tensor_spec(param)))

    # Exchange per-rank metadata so every PP rank knows the full layer list.
    obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size()
    torch.distributed.all_gather_object(
        object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group()
    )
    layer_list_meta = [item for sublist in obj_spec_output for item in sublist]

    gen_func = tensor_generator()

    # lazy load tensor for full model
    for cur_pp_rank, scan_vpp_idx, idx, name, tensor_spec in layer_list_meta:
        # fp.write(f"DEBUG: ({cur_pp_rank}, {scan_vpp_idx}, {name})\n")
        if model_config.tie_word_embeddings and ("output_layers" in name):
            import warnings

            warnings.warn(
                "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2
            )
            continue

        cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config)

        # Only the owning PP rank has the real tensor; others start from None.
        if cur_pp_rank == pp_rank:
            _, cur_tensor = next(gen_func)

        else:
            cur_tensor = None

        # Point-to-point transfer so PP rank 0 ends up with every tensor.
        if pp_rank == 0:
            if cur_tensor is None:
                cur_tensor = make_tensor(tensor_spec)
                torch.distributed.recv(cur_tensor, group=mpu.get_pipeline_model_parallel_group(), group_src=cur_pp_rank)
        else:
            # NOTE(review): on non-zero PP ranks a non-owned tensor is
            # allocated but never filled — downstream gathers appear to use
            # only rank 0's result; confirm against the TP/EP gather logic.
            if cur_tensor is None:
                cur_tensor = make_tensor(tensor_spec)
            else:
                torch.distributed.send(cur_tensor, group=mpu.get_pipeline_model_parallel_group(), group_dst=0)

        # (xya): this is a hack to fix the name of the parameters
        while cur_name.startswith("module."):
            cur_name = cur_name[len("module.") :]

        def gather(tensor, gather_list, group, group_dst, group_rank):
            # Destination rank supplies the gather list; all other ranks must
            # pass None per torch.distributed.gather's contract.
            if group_rank == group_dst:
                torch.distributed.gather(tensor, gather_list, group=group, group_dst=group_dst)
            else:
                torch.distributed.gather(tensor, None, group=group, group_dst=group_dst)

        # EP
        if ".mlp.experts.linear_fc" in cur_name and ep_size > 1:
            num_experts = weight_converter.mcore_config.num_moe_experts
            num_experts_per_rank = num_experts // ep_size
            infer_params = [torch.empty_like(cur_tensor) for _ in range(ep_size)]
            gather(cur_tensor, infer_params, group=ep_group, group_dst=0, group_rank=ep_rank)

            # Megatron names local experts ".weight<k>"; translate the local
            # expert id into one global id per EP rank.
            name_prefix, local_expert_id = cur_name.split(".weight")
            local_expert_id = int(local_expert_id)
            global_expert_ids = [num_experts_per_rank * _ep_rank + local_expert_id for _ep_rank in range(ep_size)]
            global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids]

            for name, param in zip(global_expert_names, infer_params, strict=True):
                if etp_size > 1:
                    # gather etp
                    etp_params = [torch.empty_like(param) for _ in range(etp_size)]
                    gather(param, etp_params, group=etp_group, group_dst=0, group_rank=etp_rank)
                    params = etp_params
                else:
                    params = [param]

                merge_params = default_tp_concat_fn(
                    layer_name_mapping,
                    name,
                    cur_tensor,
                    params,
                    model_config,
                    weight_converter.hf_config,
                    convert_qkv_gate_up_by_simple_split,
                )
                if not isinstance(merge_params, list):
                    merge_params = [merge_params]
                converted_names, converted_params = weight_converter.convert_param(name, merge_params)

                yield from zip(converted_names, [param.detach() for param in converted_params], strict=True)

            continue
        # tp all gather
        if tp_utils.is_tensor_parallel_param(cur_tensor):
            # allocate a new tensor with proper size
            if tp_size <= 1:
                infer_params = [cur_tensor]
            else:
                infer_params = [torch.empty_like(cur_tensor) for _ in range(tp_size)]
                gather(cur_tensor, infer_params, tp_group, group_dst=0, group_rank=tp_rank)
            infer_params = default_tp_concat_fn(
                layer_name_mapping,
                cur_name,
                cur_tensor,
                infer_params,
                model_config,
                weight_converter.hf_config,
                convert_qkv_gate_up_by_simple_split,
            )
        else:
            infer_params = cur_tensor

        if not isinstance(infer_params, list):
            infer_params = [infer_params]
        converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)

        yield from zip(converted_names, [param.detach() for param in converted_params], strict=True)
|
ICL/DAPO/verl-recipe/gvpo/config/gvpo_trainer.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
searchpath:
|
| 3 |
+
- file://verl/trainer/config
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- ppo_trainer
|
| 7 |
+
- _self_
|
| 8 |
+
|
| 9 |
+
actor_rollout_ref:
|
| 10 |
+
actor:
|
| 11 |
+
_target_: recipe.gvpo.gvpo_actor_config.FSDPActorConfig
|
| 12 |
+
gvpo_beta: 0.1
|
| 13 |
+
|
| 14 |
+
trainer:
|
| 15 |
+
project_name: gvpo
|
ICL/DAPO/verl-recipe/langgraph_agent/example/README.md
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MathExpression: LangGraph Agent Example
|
| 2 |
+
|
| 3 |
+
MathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/).
|
| 4 |
+
|
| 5 |
+
### Define react agent with tool
|
| 6 |
+
Firstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`:
|
| 7 |
+
```python
|
| 8 |
+
@tool(parse_docstring=True)
|
| 9 |
+
def calculate(a: int, b: int, operand: str) -> int:
|
| 10 |
+
"""
|
| 11 |
+
Compute the results using operand with two integers
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
a: the first operand
|
| 15 |
+
b: the second operand
|
| 16 |
+
operand: '+' or '-' or '*' or '@'
|
| 17 |
+
"""
|
| 18 |
+
assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}"
|
| 19 |
+
if operand == "@":
|
| 20 |
+
return 3 * a - 2 * b
|
| 21 |
+
return eval(f"{a} {operand} {b}")
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Without calling `calculate`, the ReactAgent cannot evaluate these math expressions correctly.
|
| 25 |
+
|
| 26 |
+
Then, we can equip ReactAgent with `calculate` tool:
|
| 27 |
+
```python
|
| 28 |
+
class MathExpressionReactAgentLoop(ReactAgentLoop):
|
| 29 |
+
@classmethod
|
| 30 |
+
def init_class(cls, config, tokenizer):
|
| 31 |
+
cls.tools = [calculate]
|
| 32 |
+
super().init_class(config, tokenizer)
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
We can define the agent loop config in a YAML file, which will be used by AgentLoopWorker to dynamically load the custom AgentLoop class.
|
| 36 |
+
```yaml
|
| 37 |
+
- name: math_expression
|
| 38 |
+
_target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Prepare dataset
|
| 42 |
+
Now, let's prepare two small datasets for training and evaluation:
|
| 43 |
+
```bash
|
| 44 |
+
python recipe/langgraph_agent/example/create_dataset.py
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
- Parameters: `--train_size` (default: 5000), `--test_size` (default: 500), `--output_dir` (default: `data/math_expression_tool`).
|
| 48 |
+
- Example with custom sizes/output:
|
| 49 |
+
```bash
|
| 50 |
+
python recipe/langgraph_agent/example/create_dataset.py \
|
| 51 |
+
--train_size 10000 \
|
| 52 |
+
--test_size 1000 \
|
| 53 |
+
--output_dir data/math_expression_tool
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Note that dataset should contain a column `agent_name` with `math_expression`, which is used by `AgentLoopWorker` to select the
|
| 57 |
+
agent loop class.
|
| 58 |
+
| prompt | reward_model | agent_name |
|
| 59 |
+
|--------------------------------------|------------------------------|-----------------|
|
| 60 |
+
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
|
| 61 |
+
| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression |
|
| 62 |
+
|
| 63 |
+
Generated math expressions are like below, requiring model to call `calculate` multiple times to solve sub expressions.
|
| 64 |
+
```
|
| 65 |
+
(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6
|
| 66 |
+
(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6
|
| 67 |
+
9 @ 4
|
| 68 |
+
((2 @ 2) @ (3 @ 3)) @ 4
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### Training
|
| 72 |
+
Hook all these up and start training:
|
| 73 |
+
```bash
|
| 74 |
+
bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
To submit on a SLURM cluster (the script contains SBATCH headers):
|
| 78 |
+
```bash
|
| 79 |
+
sbatch recipe/langgraph_agent/example/run_qwen2.5_3b.sh
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
**Note on `GPUS_PER_NODE` and `NNODES`:**
|
| 83 |
+
|
| 84 |
+
- `GPUS_PER_NODE`: GPUs per node.
|
| 85 |
+
Detection order: `SLURM_GPUS_ON_NODE` (if set) → `GPUS_PER_NODE` → `2`.
|
| 86 |
+
- `NNODES`: number of nodes.
|
| 87 |
+
Detection order: `SLURM_JOB_NUM_NODES` (if set) → `NNODES` → `1`.
|
| 88 |
+
- Total GPUs = `GPUS_PER_NODE × NNODES` (must be ≥ 2).
|
| 89 |
+
|
| 90 |
+
Local override (no `SLURM_*` set):
|
| 91 |
+
```bash
|
| 92 |
+
GPUS_PER_NODE=4 NNODES=2 bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
After a total of 39 steps, the model should achieve 100% accuracy on the test dataset:
|
| 96 |
+
- val-aux/lighteval/MATH/reward: 1.0
|
| 97 |
+
- val-aux/num_turns/mean: 9.0, the average number of messages, including assistant and tool turns.
|
| 98 |
+
|
| 99 |
+
We can see model can solve math expression by calling `calculate` tool:
|
| 100 |
+
```
|
| 101 |
+
system
|
| 102 |
+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
|
| 103 |
+
user
|
| 104 |
+
We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions:
|
| 105 |
+
Compute 5 @ (3 @ 6 @ 8)
|
| 106 |
+
assistant
|
| 107 |
+
Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\).
|
| 108 |
+
|
| 109 |
+
Let's compute \(3 @ 6\) first using the external tool.
|
| 110 |
+
<tool_call>
|
| 111 |
+
{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}}
|
| 112 |
+
</tool_call>user
|
| 113 |
+
<tool_response>
|
| 114 |
+
-3
|
| 115 |
+
</tool_response>
|
| 116 |
+
assistant
|
| 117 |
+
We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\).
|
| 118 |
+
|
| 119 |
+
Let's use the external tool to compute this.
|
| 120 |
+
<tool_call>
|
| 121 |
+
{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}}
|
| 122 |
+
</tool_call>user
|
| 123 |
+
<tool_response>
|
| 124 |
+
-25
|
| 125 |
+
</tool_response>
|
| 126 |
+
assistant
|
| 127 |
+
We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\).
|
| 128 |
+
|
| 129 |
+
Let's compute \(5 @ (-25)\) using the external tool.
|
| 130 |
+
<tool_call>
|
| 131 |
+
{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}}
|
| 132 |
+
</tool_call>user
|
| 133 |
+
<tool_response>
|
| 134 |
+
65
|
| 135 |
+
</tool_response>
|
| 136 |
+
assistant
|
| 137 |
+
The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\).
|
| 138 |
+
```
|
ICL/DAPO/verl-recipe/langgraph_agent/example/agent.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Agent loop registry: maps the dataset's `agent_name` column to the
# AgentLoop class that AgentLoopWorker dynamically loads for those samples.
- name: math_expression
  _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop
|
ICL/DAPO/verl-recipe/langgraph_agent/example/create_dataset.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Create dataset for calculator
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
|
| 22 |
+
import pandas as pd
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2):
    """
    Generate a random mathematical expression with parentheses and the
    custom '@' operator.

    Args:
        min_terms (int): Minimum number of terms in the expression.
        max_terms (int): Maximum number of terms in the expression.
        min_number (int): Minimum value for numbers in the expression.
        max_number (int): Maximum value for numbers in the expression.
        allow_decimals (bool): Whether to allow decimal numbers.
        max_depth (int): Maximum nesting depth for parentheses.

    Returns:
        str: A valid mathematical expression as a string.

    Raises:
        ValueError: If min_number is not strictly less than max_number.
    """
    # Validate once up front: the old code asserted inside the nested helper,
    # which re-checked a loop-invariant condition on every number and is
    # silently stripped under `python -O`.
    if min_number >= max_number:
        raise ValueError(f"min_number ({min_number}) must be < max_number ({max_number})")

    def generate_number():
        """Generate a random number (integer or float) as a string."""
        num = random.uniform(min_number, max_number)
        if not allow_decimals:
            num = int(num)
        else:
            num = round(num, random.randint(0, 2))  # Round to 0-2 decimal places
        return str(num)

    def generate_term(depth=0):
        """Generate a term (number or parenthesized sub-expression)."""
        if depth < max_depth and random.random() < 0.5:  # 50% chance to add parentheses
            expr = generate_expression(depth + 1)
            return f"({expr})"
        else:
            return generate_number()

    def generate_expression(depth=0):
        """Generate a full expression with multiple terms and operators."""
        num_terms = random.randint(min_terms, max_terms)
        terms = [generate_term(depth) for _ in range(num_terms)]

        operators = ["+", "-", "*", "/", "@"]
        expr = terms[0]

        for i in range(1, num_terms):
            # NOTE: with these weights only '@' can ever be selected, so every
            # operation requires the external tool. (The previous comment
            # claiming '+'/'-' were more likely was incorrect.)
            op = random.choices(
                operators,
                weights=[0, 0, 0, 0, 1],
            )[0]
            expr += f" {op} " + terms[i]

        return expr

    return generate_expression()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test():
    """Print a few sample expressions covering the main generator options."""
    # Expressions are random, so these outputs are only representative, e.g.:
    #   {}                      -> (3 + 7) * 2 - 5
    #   allow_decimals          -> 4.5 / (2.1 + 3.7) - 1.2
    #   more terms, deeper      -> ((5 * 2) - (3 + 1)) / (7 - 2) + 4
    #   simplified              -> 4 - 2 * 3
    sample_kwargs = [
        {},
        {"allow_decimals": True},
        {"max_terms": 6, "max_depth": 3},
        {"min_terms": 2, "max_terms": 3, "max_number": 5},
    ]
    for kwargs in sample_kwargs:
        print(generate_math_expression(**kwargs))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def calculate(expression: str) -> float:
    """
    Evaluate a mathematical expression with +, -, *, /, @, and parentheses.
    The @ operator is defined as: a @ b = 3a - 2b.

    Unary minus on numeric literals (e.g. "5 @ (-25)") is also supported,
    matching the intermediate results the agent produces at rollout time.

    Args:
        expression (str): Input mathematical expression (e.g., "3@2+4").

    Returns:
        float: Result of the evaluated expression (int when integral).

    Raises:
        ValueError: For invalid expressions (e.g., mismatched parentheses,
            division by zero, unknown characters).
    """

    def is_number(token: str) -> bool:
        """True for numeric literal tokens, including negative ones like '-25'."""
        body = token[1:] if token.startswith("-") and len(token) > 1 else token
        return body.isdigit() or "." in body

    def tokenize(s: str) -> list:
        """Convert the input string into tokens (numbers, operators, parentheses)."""
        tokens = []
        i = 0
        while i < len(s):
            # A '-' at the start, or right after an operator or '(', that is
            # followed by a digit/'.' is a unary minus and is folded into the
            # number literal (e.g. "(-25)" -> ["(", "-25", ")"]).
            unary_minus = (
                s[i] == "-"
                and (not tokens or tokens[-1] in "+-*/@(")
                and i + 1 < len(s)
                and (s[i + 1].isdigit() or s[i + 1] == ".")
            )
            if s[i].isdigit() or s[i] == "." or unary_minus:
                # Parse number (integer or float), with optional leading '-'
                j = i + 1 if unary_minus else i
                while j < len(s) and (s[j].isdigit() or s[j] == "."):
                    j += 1
                tokens.append(s[i:j])
                i = j
            elif s[i] in "+-*/@()":
                # Operator or parenthesis
                tokens.append(s[i])
                i += 1
            elif s[i].isspace():
                # Skip whitespace
                i += 1
            else:
                raise ValueError(f"Invalid character: {s[i]}")
        return tokens

    def infix_to_postfix(tokens: list) -> list:
        """Convert infix notation to postfix notation (Reverse Polish Notation)."""
        output = []
        stack = []
        # Higher precedence for @ (above * and /); equal precedences pop,
        # which yields left-associative evaluation.
        precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1}

        for token in tokens:
            if is_number(token):
                output.append(token)
            elif token == "(":
                stack.append(token)
            elif token == ")":
                while stack and stack[-1] != "(":
                    output.append(stack.pop())
                if not stack or stack[-1] != "(":
                    raise ValueError("Mismatched parentheses")
                stack.pop()  # Discard '('
            else:  # Operator
                while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0):
                    output.append(stack.pop())
                stack.append(token)

        # Pop remaining operators
        while stack:
            if stack[-1] in "()":
                raise ValueError("Mismatched parentheses")
            output.append(stack.pop())

        return output

    def evaluate_postfix(postfix: list) -> float:
        """Evaluate postfix expression using a stack."""
        stack = []
        for token in postfix:
            if is_number(token):
                stack.append(float(token))
            else:
                if len(stack) < 2:
                    raise ValueError("Invalid expression")
                b = stack.pop()
                a = stack.pop()
                if token == "+":
                    res = a + b
                elif token == "-":
                    res = a - b
                elif token == "*":
                    res = a * b
                elif token == "/":
                    if b == 0:
                        raise ValueError("Division by zero")
                    res = a / b
                elif token == "@":
                    res = 3 * a - 2 * b  # Custom @ operator implementation
                else:
                    raise ValueError(f"Invalid operator: {token}")
                stack.append(res)

        if len(stack) != 1:
            raise ValueError("Invalid expression")
        return stack[0]

    # Remove spaces and validate parentheses counts early.
    expression = expression.replace(" ", "")
    if expression.count("(") != expression.count(")"):
        raise ValueError("Mismatched parentheses")

    tokens = tokenize(expression)
    postfix = infix_to_postfix(tokens)
    result = evaluate_postfix(postfix)

    # Return ints for integral results (e.g. 6.0 -> 6) for cleaner ground truths.
    if result.is_integer():
        return int(result)
    return result
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def generate_data(total_num_dataset, split, agent_name="math_expression"):
    """Build a pandas DataFrame of RL samples for one dataset split.

    Args:
        total_num_dataset: Number of rows to generate.
        split: Split tag stored in each row's extra_info (e.g. "train"/"test").
        agent_name: Value for the `agent_name` column, used by AgentLoopWorker
            to pick the agent loop class.

    Returns:
        pd.DataFrame with columns prompt / data_source / ability /
        reward_model / extra_info / agent_name.
    """
    columns = {
        "prompt": [],
        "data_source": [],
        "ability": [],
        "reward_model": [],
        "extra_info": [],
        "agent_name": [],
    }

    for sample_idx in range(total_num_dataset):
        # Keep drawing until the expression evaluates cleanly
        # (calculate() raises on e.g. division by zero).
        while True:
            try:
                expression: str = generate_math_expression(
                    min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1
                )
                answer = str(calculate(expression))
                break
            except Exception as e:
                print(e)
                continue

        # One tool call is expected per binary operator in the expression.
        expected_tool_calls = sum(expression.count(op) for op in "+-*@")

        question = (
            "We define a new math operator @, where you can only call an external tool to compute. "
            "Please put your final answer inside \\boxed{} only in the last turn. Now answer the "
            f"following questions:\nCompute {expression}"
        )

        columns["prompt"].append([{"role": "user", "content": question}])
        columns["data_source"].append("lighteval/MATH")
        columns["ability"].append("math")
        columns["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer})
        columns["extra_info"].append(
            {"index": sample_idx, "expression": expression, "split": split, "expected_tool_calls": expected_tool_calls}
        )
        columns["agent_name"].append(agent_name)

    return pd.DataFrame(data=columns)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
    # CLI entry point: generate train/test parquet files for the
    # math-expression tool-calling recipe.
    parser = argparse.ArgumentParser(description="Math Expression Dataset Generator")
    parser.add_argument("--train_size", type=int, default=5000, help="Number of training samples")
    parser.add_argument("--test_size", type=int, default=500, help="Number of testing samples")
    parser.add_argument("--output_dir", default="data/math_expression_tool", help="Directory to save the dataset")
    parser.add_argument("--agent_name", default="math_expression", help="Name of the agent")
    args = parser.parse_args()

    # Generate train first, then test (keeps the RNG draw order stable).
    datasets = {
        "train": generate_data(total_num_dataset=args.train_size, split="train", agent_name=args.agent_name),
        "test": generate_data(total_num_dataset=args.test_size, split="test", agent_name=args.agent_name),
    }

    # Make sure the dataset directory exists before writing.
    os.makedirs(args.output_dir, exist_ok=True)

    # Save each split to its parquet file.
    for split_name, frame in datasets.items():
        frame.to_parquet(os.path.join(args.output_dir, f"{split_name}.parquet"))
|
ICL/DAPO/verl-recipe/langgraph_agent/example/math_expression.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from langchain_core.tools import tool
|
| 15 |
+
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@tool(parse_docstring=True)
def calculate(a: int, b: int, operand: str) -> int:
    """
    Compute the results using operand with two integers

    Args:
        a: the first operand
        b: the second operand
        operand: '+' or '-' or '*' or '@'
    """
    # Dispatch table instead of eval(): `operand` comes from the model, so
    # evaluating it as code is unsafe — especially since the previous
    # `assert` validation is stripped under `python -O`.
    operations = {
        "+": lambda x, y: x + y,
        "-": lambda x, y: x - y,
        "*": lambda x, y: x * y,
        "@": lambda x, y: 3 * x - 2 * y,  # custom operator: a @ b = 3a - 2b
    }
    if operand not in operations:
        raise ValueError(f"unknown operand {operand}")
    return operations[operand](a, b)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MathExpressionReactAgentLoop(ReactAgentLoop):
    """React agent loop that registers the `calculate` tool for the '@' operator."""

    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        # Register the tool list on the class before the base class builds
        # the LangGraph react agent from it.
        cls.tools = [calculate]
        # NOTE(review): **kwargs is accepted but not forwarded to
        # super().init_class() — any extra options are silently dropped here.
        # Confirm against ReactAgentLoop's signature whether they should be
        # passed through.
        super().init_class(config, tokenizer)
|
ICL/DAPO/verl-recipe/langgraph_agent/example/run_gpt_oss_20b_bf16.sh
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
#SBATCH --job-name=rl-langgraph-3B
#SBATCH --partition=main
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --gres=gpu:4
#SBATCH --mem=0
#SBATCH --time=10:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

set -xeuo pipefail

# ================= cluster topology =================
# Detection order: SLURM_* (if set) -> env override -> default.
export GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-2}} # GPUs on this node
NNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}
export NNODES
export RAY_NUM_NODES=$NNODES

# Require at least 2 GPUs
TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
if [ "$TOTAL_GPUS" -lt 2 ]; then
    echo "Error: at least 2 GPUs are required, detected $TOTAL_GPUS." >&2
    exit 1
fi

echo "Using $NNODES nodes and $GPUS_PER_NODE GPUs per node..."

# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}

# Model is pulled from the HF hub (no local fallback in this script,
# unlike run_qwen2.5_3b.sh).
model_path="lmsys/gpt-oss-20b-bf16"

# Use the default output directory produced by create_dataset.py
train_files=$DATA_ROOT/data/math_expression_tool/train.parquet
test_files=$DATA_ROOT/data/math_expression_tool/test.parquet

# Agent config
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml

# =================== wandb ===================
project_name=math_expression_tool
experiment_name=gpt-oss-20b-bf16
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name

# ================= algorithm =================
adv_estimator=grpo

use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

max_turns=8
max_prompt_length=1024
max_response_length=8192
actor_lr=1e-6

train_batch_size=128
ppo_mini_batch_size=16
n_resp_per_prompt=8
n_resp_per_prompt_val=1

# =================== logging ===================
export RAY_LOGGING_LEVEL=DEBUG
export HYDRA_FULL_ERROR=1

# ================= performance =================
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

infer_tp=2  # vLLM tensor parallel size
train_sp=4  # Ulysses sequence parallel size for actor
offload=true

actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))

train_files="['$train_files']"
test_files="['$test_files']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$adv_estimator \
    algorithm.use_kl_in_reward=$use_kl_in_reward \
    algorithm.kl_ctrl.kl_coef=$kl_coef \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=true \
    data.train_batch_size=$train_batch_size \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=true \
    data.truncation='error' \
    actor_rollout_ref.model.path="$model_path" \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.optim.lr=$actor_lr \
    actor_rollout_ref.actor.use_dynamic_bsz=true \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.format=gpt-oss \
    +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \
    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.n=$n_resp_per_prompt \
    actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
    trainer.logger='["console","wandb"]' \
    trainer.project_name=$project_name \
    trainer.experiment_name=$experiment_name \
    trainer.n_gpus_per_node="$GPUS_PER_NODE" \
    trainer.val_before_train=true \
    trainer.log_val_generations=50 \
    trainer.nnodes="$NNODES" \
    trainer.save_freq=-1 \
    trainer.default_local_dir="$default_local_dir" \
    trainer.test_freq=5 \
    trainer.total_epochs=1 "$@"
|
ICL/DAPO/verl-recipe/langgraph_agent/example/run_qwen2.5_3b.sh
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch GRPO training of Qwen2.5-3B-Instruct with a LangGraph multi-turn tool
# agent (math expression task) via verl's main_ppo entry point.
# Designed to run either under SLURM (sbatch) or standalone; SLURM-provided
# topology variables take precedence over local defaults.
#SBATCH --job-name=rl-langgraph-3B
#SBATCH --partition=main
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --gres=gpu:4
#SBATCH --mem=0
#SBATCH --time=10:00:00
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err

# Fail fast on any error, unset variable, or pipeline failure; echo commands.
set -xeuo pipefail

# ================= cluster topology =================
# Prefer SLURM-provided values; fall back to pre-set env vars, then defaults.
export GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-${GPUS_PER_NODE:-2}} # GPUs on this node
NNODES=${SLURM_JOB_NUM_NODES:-${NNODES:-1}}
export NNODES
export RAY_NUM_NODES=$NNODES

# Require at least 2 GPUs (one for training, one for vLLM rollout TP=2 below)
TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
if [ "$TOTAL_GPUS" -lt 2 ]; then
    echo "Error: at least 2 GPUs are required, detected $TOTAL_GPUS." >&2
    exit 1
fi

echo "Using $NNODES nodes and $GPUS_PER_NODE GPUs per node..."

# ================= data/model/tool =================
HDFS_ROOT=${HDFS_ROOT:-$PWD}
DATA_ROOT=${DATA_ROOT:-$PWD}

# Prefer local model if present, otherwise fall back to HF hub path
model_path=${model_path:-$DATA_ROOT/model/Qwen2.5-3B-Instruct}
if [ ! -d "$model_path" ]; then
    model_path=Qwen/Qwen2.5-3B-Instruct
fi

# Use the default output directory produced by create_dataset.py
train_files=$DATA_ROOT/data/math_expression_tool/train.parquet
test_files=$DATA_ROOT/data/math_expression_tool/test.parquet

# Agent config (LangGraph agent loop definition)
agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml

# =================== wandb ===================
project_name=math_expression_tool
experiment_name=qwen2.5-3b
default_local_dir=$DATA_ROOT/checkpoint/$experiment_name

# ================= algorithm =================
adv_estimator=grpo

# KL regularization fully disabled (neither in-reward KL nor KL loss).
use_kl_in_reward=false
kl_coef=0.0
use_kl_loss=false
kl_loss_coef=0.0

# Asymmetric PPO clipping (DAPO-style "clip-higher").
clip_ratio_low=0.2
clip_ratio_high=0.28

max_turns=8
max_prompt_length=1024
max_response_length=2048
actor_lr=1e-6

train_batch_size=128
ppo_mini_batch_size=16
n_resp_per_prompt=8       # rollouts per prompt during training (GRPO group size)
n_resp_per_prompt_val=1   # rollouts per prompt during validation

# =================== logging ===================
export RAY_LOGGING_LEVEL=DEBUG
export HYDRA_FULL_ERROR=1

# ================= performance =================
export NCCL_IBEXT_DISABLE=1
export NCCL_NVLS_ENABLE=1
export NCCL_IB_HCA=mlx5
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=FLASH_ATTN

infer_tp=2 # vLLM tensor parallel size
train_sp=4 # Ulysses sequence parallel size for actor
offload=true

# Token budgets for dynamic batching (per GPU).
actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 ))
log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 ))

# Hydra expects list-typed file arguments.
train_files="['$train_files']"
test_files="['$test_files']"

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=$adv_estimator \
    algorithm.use_kl_in_reward=$use_kl_in_reward \
    algorithm.kl_ctrl.kl_coef=$kl_coef \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=true \
    data.train_batch_size=$train_batch_size \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=true \
    data.truncation='error' \
    actor_rollout_ref.model.path="$model_path" \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \
    actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \
    actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \
    actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.optim.lr=$actor_lr \
    actor_rollout_ref.actor.use_dynamic_bsz=true \
    actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \
    actor_rollout_ref.actor.fsdp_config.param_offload=$offload \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \
    actor_rollout_ref.rollout.multi_turn.format=hermes \
    actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.n=$n_resp_per_prompt \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \
    actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \
    actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \
    trainer.logger='["console","wandb"]' \
    trainer.project_name=$project_name \
    trainer.experiment_name=$experiment_name \
    trainer.n_gpus_per_node="$GPUS_PER_NODE" \
    trainer.val_before_train=true \
    trainer.log_val_generations=50 \
    trainer.nnodes="$NNODES" \
    trainer.save_freq=-1 \
    trainer.default_local_dir="$default_local_dir" \
    trainer.test_freq=5 \
    trainer.total_epochs=1 "$@"
|
ICL/DAPO/verl-recipe/open_math_reasoning/run_eval.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Score previously generated responses with verl's offline evaluator, using the
# recipe's custom reward function keyed by data source.

# Evaluation
python3 -m verl.trainer.main_eval \
    data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \
    custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \
    custom_reward_function.name=compute_score_data_source
|
ICL/DAPO/verl-recipe/prime/config/prime_trainer.yaml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# the prime config will override default ppo_trainer.yaml
|
| 2 |
+
|
| 3 |
+
hydra:
|
| 4 |
+
searchpath:
|
| 5 |
+
- file://verl/trainer/config
|
| 6 |
+
|
| 7 |
+
defaults:
|
| 8 |
+
- ppo_trainer
|
| 9 |
+
- _self_
|
| 10 |
+
|
| 11 |
+
data:
|
| 12 |
+
filter_accuracy: True
|
| 13 |
+
accuracy_lower_bound: 0.2
|
| 14 |
+
accuracy_upper_bound: 0.8
|
| 15 |
+
oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized.
|
| 16 |
+
filter_truncate: True
|
| 17 |
+
truncation: right
|
| 18 |
+
|
| 19 |
+
actor_rollout_ref:
|
| 20 |
+
hybrid_engine: True
|
| 21 |
+
model:
|
| 22 |
+
use_remove_padding: True
|
| 23 |
+
rollout:
|
| 24 |
+
mode: sync
|
| 25 |
+
# number of responses (i.e. num sample times)
|
| 26 |
+
n: 4
|
| 27 |
+
actor:
|
| 28 |
+
entropy_coeff: 0.001
|
| 29 |
+
|
| 30 |
+
reward_model:
|
| 31 |
+
enable: True
|
| 32 |
+
strategy: fsdp
|
| 33 |
+
model:
|
| 34 |
+
ref_path: ${reward_model.model.path}
|
| 35 |
+
use_remove_padding: True
|
| 36 |
+
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
|
| 37 |
+
fused_kernel_options:
|
| 38 |
+
impl_backend: torch # triton, torch
|
| 39 |
+
tokenizer_path: ${actor_rollout_ref.model.path}
|
| 40 |
+
enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
|
| 41 |
+
ref_type: freeze
|
| 42 |
+
fsdp_config:
|
| 43 |
+
min_num_params: 0
|
| 44 |
+
param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
|
| 45 |
+
optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
|
| 46 |
+
update: before # ``before`` for double-forward, ``after`` for single-forward
|
| 47 |
+
optim:
|
| 48 |
+
lr: 1e-6
|
| 49 |
+
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
|
| 50 |
+
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
|
| 51 |
+
min_lr_ratio: null
|
| 52 |
+
warmup_style: null # deprecated
|
| 53 |
+
lr_scheduler_type: constant
|
| 54 |
+
total_training_steps: -1 # must be overridden by program
|
| 55 |
+
weight_decay: 0.
|
| 56 |
+
grad_clip: 10.0
|
| 57 |
+
beta_train: 0.05
|
| 58 |
+
loss_type: ce # currently only supports ce loss
|
| 59 |
+
prime_granularity: token
|
| 60 |
+
prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train
|
| 61 |
+
mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
|
| 62 |
+
reward_manager: prime
|
| 63 |
+
|
| 64 |
+
algorithm:
|
| 65 |
+
adv_estimator: rloo
|
| 66 |
+
# now supports rloo. it treats different source of reward separately.
|
| 67 |
+
kl_ctrl:
|
| 68 |
+
type: fixed
|
| 69 |
+
kl_coef: 0.000
|
| 70 |
+
reward_gt_coef: 5
|
| 71 |
+
reward_dpo_coef: 5
|
| 72 |
+
|
| 73 |
+
trainer:
|
| 74 |
+
project_name: prime
|
| 75 |
+
experiment_name: examples
|
| 76 |
+
val_before_train: False
|
| 77 |
+
balance_batch: False
|
ICL/DAPO/verl-recipe/qat/config/nvfp4_w4a16.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"quant_method": "compressed-tensors",
|
| 3 |
+
"format": "nvfp4-pack-quantized",
|
| 4 |
+
"quantization_status": "compressed",
|
| 5 |
+
"config_groups": {
|
| 6 |
+
"group_0": {
|
| 7 |
+
"format": "nvfp4-pack-quantized",
|
| 8 |
+
"targets": [
|
| 9 |
+
"Linear"
|
| 10 |
+
],
|
| 11 |
+
"weights": {
|
| 12 |
+
"actorder": null,
|
| 13 |
+
"block_structure": null,
|
| 14 |
+
"dynamic": false,
|
| 15 |
+
"group_size": 16,
|
| 16 |
+
"num_bits": 4,
|
| 17 |
+
"observer": "minmax",
|
| 18 |
+
"observer_kwargs": {},
|
| 19 |
+
"strategy": "tensor_group",
|
| 20 |
+
"symmetric": true,
|
| 21 |
+
"type": "float"
|
| 22 |
+
},
|
| 23 |
+
"input_activations": null,
|
| 24 |
+
"output_activations": null
|
| 25 |
+
}
|
| 26 |
+
},
|
| 27 |
+
"ignore": [
|
| 28 |
+
"lm_head"
|
| 29 |
+
],
|
| 30 |
+
"kv_cache_scheme": null,
|
| 31 |
+
"sparsity_config": {},
|
| 32 |
+
"transform_config": {},
|
| 33 |
+
"global_compression_ratio": null
|
| 34 |
+
}
|
ICL/DAPO/verl-recipe/qat/config/nvfp4_w4a4.json
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"quant_method": "compressed-tensors",
|
| 3 |
+
"format": "nvfp4-pack-quantized",
|
| 4 |
+
"quantization_status": "compressed",
|
| 5 |
+
"config_groups": {
|
| 6 |
+
"group_0": {
|
| 7 |
+
"format": "nvfp4-pack-quantized",
|
| 8 |
+
"targets": [
|
| 9 |
+
"Linear"
|
| 10 |
+
],
|
| 11 |
+
"weights": {
|
| 12 |
+
"num_bits": 4,
|
| 13 |
+
"type": "float",
|
| 14 |
+
"symmetric": true,
|
| 15 |
+
"strategy": "tensor_group",
|
| 16 |
+
"group_size": 16,
|
| 17 |
+
"dynamic": false,
|
| 18 |
+
"observer": "minmax",
|
| 19 |
+
"observer_kwargs": {},
|
| 20 |
+
"actorder": null,
|
| 21 |
+
"block_structure": null
|
| 22 |
+
},
|
| 23 |
+
"input_activations": {
|
| 24 |
+
"num_bits": 4,
|
| 25 |
+
"type": "float",
|
| 26 |
+
"symmetric": true,
|
| 27 |
+
"strategy": "tensor_group",
|
| 28 |
+
"group_size": 16,
|
| 29 |
+
"dynamic": "local",
|
| 30 |
+
"observer": "minmax",
|
| 31 |
+
"observer_kwargs": {},
|
| 32 |
+
"actorder": null,
|
| 33 |
+
"block_structure": null
|
| 34 |
+
},
|
| 35 |
+
"output_activations": null
|
| 36 |
+
}
|
| 37 |
+
},
|
| 38 |
+
"ignore": [
|
| 39 |
+
"lm_head"
|
| 40 |
+
],
|
| 41 |
+
"kv_cache_scheme": null,
|
| 42 |
+
"sparsity_config": {},
|
| 43 |
+
"transform_config": {},
|
| 44 |
+
"global_compression_ratio": null
|
| 45 |
+
}
|
ICL/DAPO/verl-recipe/r1/config/evaluation.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
path: /tmp/math_Qwen2-7B-Instruct.parquet
|
| 3 |
+
prompt_key: prompt
|
| 4 |
+
response_key: responses
|
| 5 |
+
data_source_key: data_source
|
| 6 |
+
reward_model_key: reward_model
|
| 7 |
+
|
| 8 |
+
custom_reward_function:
|
| 9 |
+
path: null
|
| 10 |
+
name: compute_score
|
| 11 |
+
|
| 12 |
+
ray_kwargs:
|
| 13 |
+
ray_init:
|
| 14 |
+
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
|
ICL/DAPO/verl-recipe/r1/tasks/math_reward.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import contextlib
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from math_verify.metric import math_metric
|
| 18 |
+
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig
|
| 19 |
+
except ImportError:
|
| 20 |
+
print("To use Math-Verify, please install it first by running `pip install math-verify`.")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def compute_score(model_output: str, ground_truth: str) -> float:
    """Score a model answer against a ground-truth answer with Math-Verify.

    The ground truth is wrapped in ``\\boxed{}`` so the LaTeX extractor can
    locate it; the model output is parsed with both expression and LaTeX
    extraction. Any parsing/verification error is suppressed and yields 0.0.

    Args:
        model_output: Raw model completion containing the candidate answer.
        ground_truth: Reference answer string (without ``\\boxed{}``).

    Returns:
        The Math-Verify score as a float (0.0 when no match or on error).
        NOTE: annotation fixed from ``bool`` — ``ret_score`` is a float and
        ``math_metric`` returns a float score, never a bool.
    """
    verify_func = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    ret_score = 0.0

    # Wrap the ground truth in \boxed{} format for verification
    ground_truth_boxed = "\\boxed{" + ground_truth + "}"
    # Best-effort: math_verify can raise on malformed LaTeX; treat that as 0.0.
    with contextlib.suppress(Exception):
        ret_score, _ = verify_func([ground_truth_boxed], [model_output])

    return ret_score
|
ICL/DAPO/verl-recipe/r1_ascend/figures/response_len.png
ADDED
|
ICL/DAPO/verl-recipe/r1_ascend/figures/rewards.png
ADDED
|
ICL/DAPO/verl-recipe/r1_ascend/figures/val_score.png
ADDED
|
ICL/DAPO/verl-recipe/rep_exp/config/_generated_ppo_megatron_trainer.yaml
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
|
| 2 |
+
# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file.
|
| 3 |
+
# Do not modify this file directly.
|
| 4 |
+
# The file is usually only for reference and never used.
|
| 5 |
+
|
| 6 |
+
actor_rollout_ref:
|
| 7 |
+
actor:
|
| 8 |
+
optim:
|
| 9 |
+
_target_: verl.workers.config.McoreOptimizerConfig
|
| 10 |
+
lr: 1.0e-06
|
| 11 |
+
lr_warmup_steps_ratio: 0.0
|
| 12 |
+
total_training_steps: -1
|
| 13 |
+
weight_decay: 0.01
|
| 14 |
+
lr_warmup_steps: -1
|
| 15 |
+
betas:
|
| 16 |
+
- 0.9
|
| 17 |
+
- 0.999
|
| 18 |
+
clip_grad: 1.0
|
| 19 |
+
optimizer: adam
|
| 20 |
+
lr_warmup_init: 0.0
|
| 21 |
+
lr_decay_steps: null
|
| 22 |
+
lr_decay_style: constant
|
| 23 |
+
min_lr: 0.0
|
| 24 |
+
weight_decay_incr_style: constant
|
| 25 |
+
lr_wsd_decay_style: exponential
|
| 26 |
+
lr_wsd_decay_steps: null
|
| 27 |
+
use_checkpoint_opt_param_scheduler: false
|
| 28 |
+
override_optimizer_config: {}
|
| 29 |
+
megatron:
|
| 30 |
+
_target_: verl.workers.config.McoreEngineConfig
|
| 31 |
+
param_offload: false
|
| 32 |
+
grad_offload: false
|
| 33 |
+
optimizer_offload: false
|
| 34 |
+
tensor_model_parallel_size: 1
|
| 35 |
+
expert_model_parallel_size: 1
|
| 36 |
+
expert_tensor_parallel_size: null
|
| 37 |
+
pipeline_model_parallel_size: 1
|
| 38 |
+
virtual_pipeline_model_parallel_size: null
|
| 39 |
+
context_parallel_size: 1
|
| 40 |
+
sequence_parallel: true
|
| 41 |
+
use_distributed_optimizer: true
|
| 42 |
+
use_dist_checkpointing: false
|
| 43 |
+
dist_checkpointing_path: null
|
| 44 |
+
dist_checkpointing_prefix: ''
|
| 45 |
+
seed: 42
|
| 46 |
+
override_ddp_config: {}
|
| 47 |
+
override_transformer_config:
|
| 48 |
+
recompute_granularity: null
|
| 49 |
+
recompute_modules:
|
| 50 |
+
- core_attn
|
| 51 |
+
recompute_method: null
|
| 52 |
+
recompute_num_layers: null
|
| 53 |
+
attention_backend: flash
|
| 54 |
+
override_mcore_model_config: {}
|
| 55 |
+
use_mbridge: false
|
| 56 |
+
forward_only: false
|
| 57 |
+
dtype: bfloat16
|
| 58 |
+
_target_: verl.workers.config.McoreActorConfig
|
| 59 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 60 |
+
strategy: megatron
|
| 61 |
+
ppo_mini_batch_size: 256
|
| 62 |
+
ppo_micro_batch_size: null
|
| 63 |
+
ppo_micro_batch_size_per_gpu: null
|
| 64 |
+
use_dynamic_bsz: false
|
| 65 |
+
ppo_max_token_len_per_gpu: 16384
|
| 66 |
+
clip_ratio: 0.2
|
| 67 |
+
clip_ratio_low: 0.2
|
| 68 |
+
clip_ratio_high: 0.2
|
| 69 |
+
freeze_vision_tower: false
|
| 70 |
+
policy_loss:
|
| 71 |
+
_target_: verl.workers.config.PolicyLossConfig
|
| 72 |
+
loss_mode: vanilla
|
| 73 |
+
clip_cov_ratio: 0.0002
|
| 74 |
+
clip_cov_lb: 1.0
|
| 75 |
+
clip_cov_ub: 5.0
|
| 76 |
+
kl_cov_ratio: 0.0002
|
| 77 |
+
ppo_kl_coef: 0.1
|
| 78 |
+
clip_ratio_c: 3.0
|
| 79 |
+
loss_agg_mode: token-mean
|
| 80 |
+
entropy_coeff: 0
|
| 81 |
+
use_kl_loss: false
|
| 82 |
+
use_torch_compile: true
|
| 83 |
+
kl_loss_coef: 0.001
|
| 84 |
+
kl_loss_type: low_var_kl
|
| 85 |
+
ppo_epochs: 1
|
| 86 |
+
shuffle: false
|
| 87 |
+
checkpoint:
|
| 88 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 89 |
+
save_contents:
|
| 90 |
+
- model
|
| 91 |
+
- optimizer
|
| 92 |
+
- extra
|
| 93 |
+
load_contents: ${.save_contents}
|
| 94 |
+
async_save: false
|
| 95 |
+
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
|
| 96 |
+
profiler:
|
| 97 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 98 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 99 |
+
enable: false
|
| 100 |
+
all_ranks: false
|
| 101 |
+
ranks: []
|
| 102 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 103 |
+
tool_config:
|
| 104 |
+
nsys:
|
| 105 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 106 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 107 |
+
npu:
|
| 108 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 109 |
+
contents: []
|
| 110 |
+
level: level1
|
| 111 |
+
analysis: true
|
| 112 |
+
discrete: false
|
| 113 |
+
torch:
|
| 114 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 115 |
+
step_start: 0
|
| 116 |
+
step_end: null
|
| 117 |
+
torch_memory:
|
| 118 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 119 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 120 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 121 |
+
data_loader_seed: 42
|
| 122 |
+
load_weight: true
|
| 123 |
+
ref:
|
| 124 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 125 |
+
strategy: megatron
|
| 126 |
+
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
|
| 127 |
+
log_prob_micro_batch_size: null
|
| 128 |
+
log_prob_micro_batch_size_per_gpu: null
|
| 129 |
+
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 130 |
+
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
|
| 131 |
+
profiler:
|
| 132 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 133 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 134 |
+
enable: false
|
| 135 |
+
all_ranks: false
|
| 136 |
+
ranks: []
|
| 137 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 138 |
+
tool_config:
|
| 139 |
+
nsys:
|
| 140 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 141 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 142 |
+
npu:
|
| 143 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 144 |
+
contents: []
|
| 145 |
+
level: level1
|
| 146 |
+
analysis: true
|
| 147 |
+
discrete: false
|
| 148 |
+
torch:
|
| 149 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 150 |
+
step_start: 0
|
| 151 |
+
step_end: null
|
| 152 |
+
torch_memory:
|
| 153 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 154 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 155 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 156 |
+
megatron:
|
| 157 |
+
_target_: verl.workers.config.McoreEngineConfig
|
| 158 |
+
param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False}
|
| 159 |
+
grad_offload: false
|
| 160 |
+
optimizer_offload: false
|
| 161 |
+
tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1}
|
| 162 |
+
expert_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_model_parallel_size,1}
|
| 163 |
+
expert_tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_tensor_parallel_size,null}
|
| 164 |
+
pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1}
|
| 165 |
+
virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null}
|
| 166 |
+
context_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.context_parallel_size,1}
|
| 167 |
+
sequence_parallel: true
|
| 168 |
+
use_distributed_optimizer: true
|
| 169 |
+
use_dist_checkpointing: false
|
| 170 |
+
dist_checkpointing_path: null
|
| 171 |
+
dist_checkpointing_prefix: ''
|
| 172 |
+
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
|
| 173 |
+
override_ddp_config: {}
|
| 174 |
+
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
|
| 175 |
+
override_mcore_model_config: {}
|
| 176 |
+
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
|
| 177 |
+
forward_only: true
|
| 178 |
+
dtype: bfloat16
|
| 179 |
+
_target_: verl.workers.config.McoreActorConfig
|
| 180 |
+
load_weight: true
|
| 181 |
+
rollout:
|
| 182 |
+
_target_: verl.workers.config.RolloutConfig
|
| 183 |
+
name: ???
|
| 184 |
+
mode: async
|
| 185 |
+
temperature: 1.0
|
| 186 |
+
top_k: -1
|
| 187 |
+
top_p: 1
|
| 188 |
+
prompt_length: ${oc.select:data.max_prompt_length,512}
|
| 189 |
+
response_length: ${oc.select:data.max_response_length,512}
|
| 190 |
+
dtype: bfloat16
|
| 191 |
+
gpu_memory_utilization: 0.5
|
| 192 |
+
ignore_eos: false
|
| 193 |
+
enforce_eager: false
|
| 194 |
+
cudagraph_capture_sizes: null
|
| 195 |
+
free_cache_engine: true
|
| 196 |
+
tensor_model_parallel_size: 2
|
| 197 |
+
data_parallel_size: 1
|
| 198 |
+
expert_parallel_size: 1
|
| 199 |
+
pipeline_model_parallel_size: 1
|
| 200 |
+
max_num_batched_tokens: 8192
|
| 201 |
+
max_model_len: null
|
| 202 |
+
max_num_seqs: 1024
|
| 203 |
+
enable_chunked_prefill: true
|
| 204 |
+
enable_prefix_caching: true
|
| 205 |
+
load_format: dummy
|
| 206 |
+
log_prob_micro_batch_size: null
|
| 207 |
+
log_prob_micro_batch_size_per_gpu: null
|
| 208 |
+
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 209 |
+
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
|
| 210 |
+
disable_log_stats: true
|
| 211 |
+
do_sample: true
|
| 212 |
+
'n': 1
|
| 213 |
+
over_sample_rate: 0
|
| 214 |
+
multi_stage_wake_up: false
|
| 215 |
+
engine_kwargs:
|
| 216 |
+
vllm: {}
|
| 217 |
+
sglang: {}
|
| 218 |
+
val_kwargs:
|
| 219 |
+
_target_: verl.workers.config.SamplingConfig
|
| 220 |
+
top_k: -1
|
| 221 |
+
top_p: 1.0
|
| 222 |
+
temperature: 0
|
| 223 |
+
'n': 1
|
| 224 |
+
do_sample: false
|
| 225 |
+
multi_turn:
|
| 226 |
+
_target_: verl.workers.config.MultiTurnConfig
|
| 227 |
+
enable: false
|
| 228 |
+
max_assistant_turns: null
|
| 229 |
+
tool_config_path: null
|
| 230 |
+
max_user_turns: null
|
| 231 |
+
max_parallel_calls: 1
|
| 232 |
+
max_tool_response_length: 256
|
| 233 |
+
tool_response_truncate_side: middle
|
| 234 |
+
interaction_config_path: null
|
| 235 |
+
use_inference_chat_template: false
|
| 236 |
+
tokenization_sanity_check_mode: strict
|
| 237 |
+
format: hermes
|
| 238 |
+
num_repeat_rollouts: null
|
| 239 |
+
calculate_log_probs: false
|
| 240 |
+
agent:
|
| 241 |
+
_target_: verl.workers.config.AgentLoopConfig
|
| 242 |
+
num_workers: 8
|
| 243 |
+
default_agent_loop: single_turn_agent
|
| 244 |
+
agent_loop_config_path: null
|
| 245 |
+
custom_async_server:
|
| 246 |
+
_target_: verl.workers.config.CustomAsyncServerConfig
|
| 247 |
+
path: null
|
| 248 |
+
name: null
|
| 249 |
+
update_weights_bucket_megabytes: 512
|
| 250 |
+
trace:
|
| 251 |
+
_target_: verl.workers.config.TraceConfig
|
| 252 |
+
backend: null
|
| 253 |
+
token2text: false
|
| 254 |
+
max_samples_per_step_per_worker: null
|
| 255 |
+
skip_rollout: false
|
| 256 |
+
skip_dump_dir: /tmp/rollout_dump
|
| 257 |
+
skip_tokenizer_init: true
|
| 258 |
+
profiler:
|
| 259 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 260 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 261 |
+
enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}
|
| 262 |
+
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
|
| 263 |
+
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
|
| 264 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 265 |
+
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
|
| 266 |
+
prometheus:
|
| 267 |
+
_target_: verl.workers.config.PrometheusConfig
|
| 268 |
+
enable: false
|
| 269 |
+
port: 9090
|
| 270 |
+
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
|
| 271 |
+
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
|
| 272 |
+
quantization: null
|
| 273 |
+
layer_name_map:
|
| 274 |
+
qkv_layer_name: qkv
|
| 275 |
+
gate_proj_layer_name: gate_up
|
| 276 |
+
model:
|
| 277 |
+
_target_: verl.workers.config.HFModelConfig
|
| 278 |
+
path: ~/models/deepseek-llm-7b-chat
|
| 279 |
+
hf_config_path: null
|
| 280 |
+
tokenizer_path: null
|
| 281 |
+
use_shm: false
|
| 282 |
+
trust_remote_code: false
|
| 283 |
+
custom_chat_template: null
|
| 284 |
+
external_lib: null
|
| 285 |
+
override_config:
|
| 286 |
+
model_config: {}
|
| 287 |
+
moe_config:
|
| 288 |
+
freeze_moe_router: false
|
| 289 |
+
enable_gradient_checkpointing: true
|
| 290 |
+
enable_activation_offload: false
|
| 291 |
+
use_remove_padding: false
|
| 292 |
+
lora_rank: 0
|
| 293 |
+
lora_alpha: 16
|
| 294 |
+
target_modules: all-linear
|
| 295 |
+
exclude_modules: null
|
| 296 |
+
lora_adapter_path: null
|
| 297 |
+
use_liger: false
|
| 298 |
+
use_fused_kernels: false
|
| 299 |
+
fused_kernel_options:
|
| 300 |
+
impl_backend: torch
|
| 301 |
+
hybrid_engine: true
|
| 302 |
+
nccl_timeout: 600
|
| 303 |
+
data:
|
| 304 |
+
tokenizer: null
|
| 305 |
+
use_shm: false
|
| 306 |
+
train_files: ~/data/rlhf/gsm8k/train.parquet
|
| 307 |
+
val_files: ~/data/rlhf/gsm8k/test.parquet
|
| 308 |
+
train_max_samples: -1
|
| 309 |
+
val_max_samples: -1
|
| 310 |
+
prompt_key: prompt
|
| 311 |
+
reward_fn_key: data_source
|
| 312 |
+
max_prompt_length: 512
|
| 313 |
+
max_response_length: 512
|
| 314 |
+
train_batch_size: 1024
|
| 315 |
+
val_batch_size: null
|
| 316 |
+
tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,
|
| 317 |
+
null}
|
| 318 |
+
return_raw_input_ids: false
|
| 319 |
+
return_raw_chat: true
|
| 320 |
+
return_full_prompt: false
|
| 321 |
+
shuffle: true
|
| 322 |
+
seed: null
|
| 323 |
+
dataloader_num_workers: 8
|
| 324 |
+
image_patch_size: 14
|
| 325 |
+
validation_shuffle: false
|
| 326 |
+
filter_overlong_prompts: false
|
| 327 |
+
filter_overlong_prompts_workers: 1
|
| 328 |
+
truncation: error
|
| 329 |
+
image_key: images
|
| 330 |
+
video_key: videos
|
| 331 |
+
trust_remote_code: false
|
| 332 |
+
custom_cls:
|
| 333 |
+
path: null
|
| 334 |
+
name: null
|
| 335 |
+
return_multi_modal_inputs: true
|
| 336 |
+
sampler:
|
| 337 |
+
class_path: null
|
| 338 |
+
class_name: null
|
| 339 |
+
datagen:
|
| 340 |
+
path: null
|
| 341 |
+
name: null
|
| 342 |
+
apply_chat_template_kwargs: {}
|
| 343 |
+
critic:
|
| 344 |
+
optim:
|
| 345 |
+
_target_: verl.workers.config.McoreOptimizerConfig
|
| 346 |
+
lr: 1.0e-05
|
| 347 |
+
lr_warmup_steps_ratio: 0.0
|
| 348 |
+
total_training_steps: -1
|
| 349 |
+
weight_decay: 0.01
|
| 350 |
+
lr_warmup_steps: -1
|
| 351 |
+
betas:
|
| 352 |
+
- 0.9
|
| 353 |
+
- 0.999
|
| 354 |
+
clip_grad: 1.0
|
| 355 |
+
optimizer: adam
|
| 356 |
+
lr_warmup_init: 0.0
|
| 357 |
+
lr_decay_steps: null
|
| 358 |
+
lr_decay_style: constant
|
| 359 |
+
min_lr: 0.0
|
| 360 |
+
weight_decay_incr_style: constant
|
| 361 |
+
lr_wsd_decay_style: exponential
|
| 362 |
+
lr_wsd_decay_steps: null
|
| 363 |
+
use_checkpoint_opt_param_scheduler: false
|
| 364 |
+
override_optimizer_config: {}
|
| 365 |
+
megatron:
|
| 366 |
+
_target_: verl.workers.config.McoreEngineConfig
|
| 367 |
+
param_offload: false
|
| 368 |
+
grad_offload: false
|
| 369 |
+
optimizer_offload: false
|
| 370 |
+
tensor_model_parallel_size: 1
|
| 371 |
+
expert_model_parallel_size: 1
|
| 372 |
+
expert_tensor_parallel_size: null
|
| 373 |
+
pipeline_model_parallel_size: 1
|
| 374 |
+
virtual_pipeline_model_parallel_size: null
|
| 375 |
+
context_parallel_size: 1
|
| 376 |
+
sequence_parallel: true
|
| 377 |
+
use_distributed_optimizer: true
|
| 378 |
+
use_dist_checkpointing: false
|
| 379 |
+
dist_checkpointing_path: null
|
| 380 |
+
dist_checkpointing_prefix: ''
|
| 381 |
+
seed: 42
|
| 382 |
+
override_ddp_config: {}
|
| 383 |
+
override_transformer_config:
|
| 384 |
+
recompute_granularity: null
|
| 385 |
+
recompute_modules:
|
| 386 |
+
- core_attn
|
| 387 |
+
recompute_method: null
|
| 388 |
+
recompute_num_layers: null
|
| 389 |
+
attention_backend: flash
|
| 390 |
+
override_mcore_model_config: {}
|
| 391 |
+
use_mbridge: false
|
| 392 |
+
forward_only: false
|
| 393 |
+
dtype: bfloat16
|
| 394 |
+
_target_: verl.workers.config.McoreCriticConfig
|
| 395 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 396 |
+
strategy: megatron
|
| 397 |
+
enable: null
|
| 398 |
+
model:
|
| 399 |
+
path: ~/models/deepseek-llm-7b-chat
|
| 400 |
+
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
|
| 401 |
+
override_config:
|
| 402 |
+
model_config: {}
|
| 403 |
+
moe_config:
|
| 404 |
+
freeze_moe_router: false
|
| 405 |
+
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
|
| 406 |
+
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
|
| 407 |
+
_target_: verl.trainer.config.BaseModelConfig
|
| 408 |
+
ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
|
| 409 |
+
ppo_micro_batch_size: null
|
| 410 |
+
ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
|
| 411 |
+
use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 412 |
+
ppo_max_token_len_per_gpu: 32768
|
| 413 |
+
forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
|
| 414 |
+
ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}
|
| 415 |
+
shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}
|
| 416 |
+
cliprange_value: 0.5
|
| 417 |
+
loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
|
| 418 |
+
checkpoint:
|
| 419 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 420 |
+
save_contents:
|
| 421 |
+
- model
|
| 422 |
+
- optimizer
|
| 423 |
+
- extra
|
| 424 |
+
load_contents: ${.save_contents}
|
| 425 |
+
async_save: false
|
| 426 |
+
profiler:
|
| 427 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 428 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 429 |
+
enable: false
|
| 430 |
+
all_ranks: false
|
| 431 |
+
ranks: []
|
| 432 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 433 |
+
tool_config:
|
| 434 |
+
nsys:
|
| 435 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 436 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 437 |
+
npu:
|
| 438 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 439 |
+
contents: []
|
| 440 |
+
level: level1
|
| 441 |
+
analysis: true
|
| 442 |
+
discrete: false
|
| 443 |
+
torch:
|
| 444 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 445 |
+
step_start: 0
|
| 446 |
+
step_end: null
|
| 447 |
+
torch_memory:
|
| 448 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 449 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 450 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 451 |
+
nccl_timeout: 600
|
| 452 |
+
load_weight: true
|
| 453 |
+
data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
|
| 454 |
+
reward_model:
|
| 455 |
+
enable: false
|
| 456 |
+
enable_resource_pool: false
|
| 457 |
+
n_gpus_per_node: 0
|
| 458 |
+
nnodes: 0
|
| 459 |
+
strategy: megatron
|
| 460 |
+
model:
|
| 461 |
+
input_tokenizer: ${actor_rollout_ref.model.path}
|
| 462 |
+
path: ~/models/FsfairX-LLaMA3-RM-v0.1
|
| 463 |
+
external_lib: ${actor_rollout_ref.model.external_lib}
|
| 464 |
+
trust_remote_code: false
|
| 465 |
+
micro_batch_size: null
|
| 466 |
+
micro_batch_size_per_gpu: null
|
| 467 |
+
max_length: null
|
| 468 |
+
use_dynamic_bsz: ${critic.use_dynamic_bsz}
|
| 469 |
+
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
|
| 470 |
+
reward_manager: naive
|
| 471 |
+
launch_reward_fn_async: false
|
| 472 |
+
sandbox_fusion:
|
| 473 |
+
url: null
|
| 474 |
+
max_concurrent: 64
|
| 475 |
+
memory_limit_mb: 1024
|
| 476 |
+
profiler:
|
| 477 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 478 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 479 |
+
enable: false
|
| 480 |
+
all_ranks: false
|
| 481 |
+
ranks: []
|
| 482 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 483 |
+
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
|
| 484 |
+
nccl_timeout: 600
|
| 485 |
+
megatron:
|
| 486 |
+
_target_: verl.workers.config.MegatronEngineConfig
|
| 487 |
+
param_offload: false
|
| 488 |
+
tensor_model_parallel_size: 1
|
| 489 |
+
expert_model_parallel_size: 1
|
| 490 |
+
expert_tensor_parallel_size: null
|
| 491 |
+
pipeline_model_parallel_size: 1
|
| 492 |
+
virtual_pipeline_model_parallel_size: null
|
| 493 |
+
context_parallel_size: 1
|
| 494 |
+
sequence_parallel: true
|
| 495 |
+
use_distributed_optimizer: false
|
| 496 |
+
use_dist_checkpointing: false
|
| 497 |
+
dist_checkpointing_path: null
|
| 498 |
+
dist_checkpointing_prefix: ''
|
| 499 |
+
seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42}
|
| 500 |
+
override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
|
| 501 |
+
use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
|
| 502 |
+
dtype: bfloat16
|
| 503 |
+
load_weight: true
|
| 504 |
+
algorithm:
|
| 505 |
+
rollout_correction:
|
| 506 |
+
rollout_is: null
|
| 507 |
+
rollout_is_threshold: 2.0
|
| 508 |
+
rollout_rs: null
|
| 509 |
+
rollout_rs_threshold: null
|
| 510 |
+
rollout_rs_threshold_lower: null
|
| 511 |
+
rollout_token_veto_threshold: null
|
| 512 |
+
bypass_mode: false
|
| 513 |
+
use_policy_gradient: false
|
| 514 |
+
rollout_is_batch_normalize: false
|
| 515 |
+
_target_: verl.trainer.config.AlgoConfig
|
| 516 |
+
gamma: 1.0
|
| 517 |
+
lam: 1.0
|
| 518 |
+
adv_estimator: gae
|
| 519 |
+
norm_adv_by_std_in_grpo: true
|
| 520 |
+
use_kl_in_reward: false
|
| 521 |
+
kl_penalty: kl
|
| 522 |
+
kl_ctrl:
|
| 523 |
+
_target_: verl.trainer.config.KLControlConfig
|
| 524 |
+
type: fixed
|
| 525 |
+
kl_coef: 0.001
|
| 526 |
+
horizon: 10000
|
| 527 |
+
target_kl: 0.1
|
| 528 |
+
use_pf_ppo: false
|
| 529 |
+
pf_ppo:
|
| 530 |
+
reweight_method: pow
|
| 531 |
+
weight_pow: 2.0
|
| 532 |
+
custom_reward_function:
|
| 533 |
+
path: null
|
| 534 |
+
name: compute_score
|
| 535 |
+
trainer:
|
| 536 |
+
balance_batch: true
|
| 537 |
+
total_epochs: 30
|
| 538 |
+
total_training_steps: null
|
| 539 |
+
project_name: verl_examples
|
| 540 |
+
experiment_name: gsm8k
|
| 541 |
+
logger:
|
| 542 |
+
- console
|
| 543 |
+
- wandb
|
| 544 |
+
log_val_generations: 0
|
| 545 |
+
nnodes: 1
|
| 546 |
+
n_gpus_per_node: 8
|
| 547 |
+
save_freq: -1
|
| 548 |
+
esi_redundant_time: 0
|
| 549 |
+
resume_mode: auto
|
| 550 |
+
resume_from_path: null
|
| 551 |
+
del_local_ckpt_after_load: false
|
| 552 |
+
val_before_train: true
|
| 553 |
+
test_freq: -1
|
| 554 |
+
critic_warmup: 0
|
| 555 |
+
default_hdfs_dir: null
|
| 556 |
+
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
| 557 |
+
max_actor_ckpt_to_keep: null
|
| 558 |
+
max_critic_ckpt_to_keep: null
|
| 559 |
+
ray_wait_register_center_timeout: 300
|
| 560 |
+
device: cuda
|
| 561 |
+
rollout_data_dir: null
|
| 562 |
+
use_legacy_worker_impl: auto
|
| 563 |
+
global_profiler:
|
| 564 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 565 |
+
tool: null
|
| 566 |
+
steps: null
|
| 567 |
+
profile_continuous_steps: false
|
| 568 |
+
save_path: outputs/profile
|
| 569 |
+
global_tool_config:
|
| 570 |
+
nsys:
|
| 571 |
+
discrete: false
|
| 572 |
+
controller_nsight_options:
|
| 573 |
+
trace: cuda,nvtx,cublas,ucx
|
| 574 |
+
cuda-memory-usage: 'true'
|
| 575 |
+
cuda-graph-trace: graph
|
| 576 |
+
worker_nsight_options:
|
| 577 |
+
trace: cuda,nvtx,cublas,ucx
|
| 578 |
+
cuda-memory-usage: 'true'
|
| 579 |
+
cuda-graph-trace: graph
|
| 580 |
+
capture-range: cudaProfilerApi
|
| 581 |
+
capture-range-end: null
|
| 582 |
+
kill: none
|
| 583 |
+
torch_memory:
|
| 584 |
+
trace_alloc_max_entries: 100000
|
| 585 |
+
stack_depth: 32
|
| 586 |
+
context: all
|
| 587 |
+
stacks: all
|
| 588 |
+
kw_args: {}
|
| 589 |
+
transfer_queue:
|
| 590 |
+
enable: false
|
| 591 |
+
ray_kwargs:
|
| 592 |
+
ray_init:
|
| 593 |
+
num_cpus: null
|
| 594 |
+
timeline_json_file: null
|
ICL/DAPO/verl-recipe/rep_exp/config/_generated_ppo_trainer.yaml
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'
|
| 2 |
+
# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file.
|
| 3 |
+
# Do not modify this file directly.
|
| 4 |
+
# The file is usually only for reference and never used.
|
| 5 |
+
|
| 6 |
+
actor_rollout_ref:
|
| 7 |
+
actor:
|
| 8 |
+
optim:
|
| 9 |
+
_target_: verl.workers.config.FSDPOptimizerConfig
|
| 10 |
+
optimizer: AdamW
|
| 11 |
+
optimizer_impl: torch.optim
|
| 12 |
+
lr: 1.0e-06
|
| 13 |
+
lr_warmup_steps_ratio: 0.0
|
| 14 |
+
total_training_steps: -1
|
| 15 |
+
weight_decay: 0.01
|
| 16 |
+
lr_warmup_steps: -1
|
| 17 |
+
betas:
|
| 18 |
+
- 0.9
|
| 19 |
+
- 0.999
|
| 20 |
+
clip_grad: 1.0
|
| 21 |
+
min_lr_ratio: 0.0
|
| 22 |
+
num_cycles: 0.5
|
| 23 |
+
lr_scheduler_type: constant
|
| 24 |
+
warmup_style: null
|
| 25 |
+
override_optimizer_config: null
|
| 26 |
+
fsdp_config:
|
| 27 |
+
_target_: verl.workers.config.FSDPEngineConfig
|
| 28 |
+
wrap_policy:
|
| 29 |
+
min_num_params: 0
|
| 30 |
+
param_offload: false
|
| 31 |
+
optimizer_offload: false
|
| 32 |
+
offload_policy: false
|
| 33 |
+
reshard_after_forward: true
|
| 34 |
+
fsdp_size: -1
|
| 35 |
+
forward_prefetch: false
|
| 36 |
+
model_dtype: fp32
|
| 37 |
+
use_orig_params: false
|
| 38 |
+
ulysses_sequence_parallel_size: 1
|
| 39 |
+
entropy_from_logits_with_chunking: false
|
| 40 |
+
use_torch_compile: true
|
| 41 |
+
entropy_checkpointing: false
|
| 42 |
+
forward_only: false
|
| 43 |
+
strategy: fsdp
|
| 44 |
+
dtype: bfloat16
|
| 45 |
+
_target_: verl.workers.config.FSDPActorConfig
|
| 46 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 47 |
+
strategy: fsdp
|
| 48 |
+
ppo_mini_batch_size: 256
|
| 49 |
+
ppo_micro_batch_size: null
|
| 50 |
+
ppo_micro_batch_size_per_gpu: null
|
| 51 |
+
use_dynamic_bsz: false
|
| 52 |
+
ppo_max_token_len_per_gpu: 16384
|
| 53 |
+
clip_ratio: 0.2
|
| 54 |
+
clip_ratio_low: 0.2
|
| 55 |
+
clip_ratio_high: 0.2
|
| 56 |
+
freeze_vision_tower: false
|
| 57 |
+
policy_loss:
|
| 58 |
+
_target_: verl.workers.config.PolicyLossConfig
|
| 59 |
+
loss_mode: vanilla
|
| 60 |
+
clip_cov_ratio: 0.0002
|
| 61 |
+
clip_cov_lb: 1.0
|
| 62 |
+
clip_cov_ub: 5.0
|
| 63 |
+
kl_cov_ratio: 0.0002
|
| 64 |
+
ppo_kl_coef: 0.1
|
| 65 |
+
clip_ratio_c: 3.0
|
| 66 |
+
loss_agg_mode: token-mean
|
| 67 |
+
entropy_coeff: 0
|
| 68 |
+
use_kl_loss: false
|
| 69 |
+
use_torch_compile: true
|
| 70 |
+
kl_loss_coef: 0.001
|
| 71 |
+
kl_loss_type: low_var_kl
|
| 72 |
+
ppo_epochs: 1
|
| 73 |
+
shuffle: false
|
| 74 |
+
checkpoint:
|
| 75 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 76 |
+
save_contents:
|
| 77 |
+
- model
|
| 78 |
+
- optimizer
|
| 79 |
+
- extra
|
| 80 |
+
load_contents: ${.save_contents}
|
| 81 |
+
async_save: false
|
| 82 |
+
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
|
| 83 |
+
profiler:
|
| 84 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 85 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 86 |
+
enable: false
|
| 87 |
+
all_ranks: false
|
| 88 |
+
ranks: []
|
| 89 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 90 |
+
tool_config:
|
| 91 |
+
nsys:
|
| 92 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 93 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 94 |
+
npu:
|
| 95 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 96 |
+
contents: []
|
| 97 |
+
level: level1
|
| 98 |
+
analysis: true
|
| 99 |
+
discrete: false
|
| 100 |
+
torch:
|
| 101 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 102 |
+
step_start: 0
|
| 103 |
+
step_end: null
|
| 104 |
+
torch_memory:
|
| 105 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 106 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 107 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 108 |
+
grad_clip: 1.0
|
| 109 |
+
ulysses_sequence_parallel_size: 1
|
| 110 |
+
entropy_from_logits_with_chunking: false
|
| 111 |
+
entropy_checkpointing: false
|
| 112 |
+
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
|
| 113 |
+
ref:
|
| 114 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 115 |
+
strategy: ${actor_rollout_ref.actor.strategy}
|
| 116 |
+
use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true}
|
| 117 |
+
log_prob_micro_batch_size: null
|
| 118 |
+
log_prob_micro_batch_size_per_gpu: null
|
| 119 |
+
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 120 |
+
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
|
| 121 |
+
profiler:
|
| 122 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 123 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 124 |
+
enable: false
|
| 125 |
+
all_ranks: false
|
| 126 |
+
ranks: []
|
| 127 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 128 |
+
tool_config:
|
| 129 |
+
nsys:
|
| 130 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 131 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 132 |
+
npu:
|
| 133 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 134 |
+
contents: []
|
| 135 |
+
level: level1
|
| 136 |
+
analysis: true
|
| 137 |
+
discrete: false
|
| 138 |
+
torch:
|
| 139 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 140 |
+
step_start: 0
|
| 141 |
+
step_end: null
|
| 142 |
+
torch_memory:
|
| 143 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 144 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 145 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 146 |
+
fsdp_config:
|
| 147 |
+
_target_: verl.workers.config.FSDPEngineConfig
|
| 148 |
+
wrap_policy:
|
| 149 |
+
min_num_params: 0
|
| 150 |
+
param_offload: false
|
| 151 |
+
optimizer_offload: false
|
| 152 |
+
offload_policy: false
|
| 153 |
+
reshard_after_forward: true
|
| 154 |
+
fsdp_size: -1
|
| 155 |
+
forward_prefetch: false
|
| 156 |
+
model_dtype: fp32
|
| 157 |
+
use_orig_params: false
|
| 158 |
+
ulysses_sequence_parallel_size: 1
|
| 159 |
+
entropy_from_logits_with_chunking: false
|
| 160 |
+
use_torch_compile: true
|
| 161 |
+
entropy_checkpointing: false
|
| 162 |
+
forward_only: true
|
| 163 |
+
strategy: fsdp
|
| 164 |
+
dtype: bfloat16
|
| 165 |
+
_target_: verl.workers.config.FSDPActorConfig
|
| 166 |
+
ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1}
|
| 167 |
+
entropy_from_logits_with_chunking: false
|
| 168 |
+
entropy_checkpointing: false
|
| 169 |
+
rollout:
|
| 170 |
+
_target_: verl.workers.config.RolloutConfig
|
| 171 |
+
name: ???
|
| 172 |
+
mode: async
|
| 173 |
+
temperature: 1.0
|
| 174 |
+
top_k: -1
|
| 175 |
+
top_p: 1
|
| 176 |
+
prompt_length: ${oc.select:data.max_prompt_length,512}
|
| 177 |
+
response_length: ${oc.select:data.max_response_length,512}
|
| 178 |
+
dtype: bfloat16
|
| 179 |
+
gpu_memory_utilization: 0.5
|
| 180 |
+
ignore_eos: false
|
| 181 |
+
enforce_eager: false
|
| 182 |
+
cudagraph_capture_sizes: null
|
| 183 |
+
free_cache_engine: true
|
| 184 |
+
tensor_model_parallel_size: 2
|
| 185 |
+
data_parallel_size: 1
|
| 186 |
+
expert_parallel_size: 1
|
| 187 |
+
pipeline_model_parallel_size: 1
|
| 188 |
+
max_num_batched_tokens: 8192
|
| 189 |
+
max_model_len: null
|
| 190 |
+
max_num_seqs: 1024
|
| 191 |
+
enable_chunked_prefill: true
|
| 192 |
+
enable_prefix_caching: true
|
| 193 |
+
load_format: dummy
|
| 194 |
+
log_prob_micro_batch_size: null
|
| 195 |
+
log_prob_micro_batch_size_per_gpu: null
|
| 196 |
+
log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 197 |
+
log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384}
|
| 198 |
+
disable_log_stats: true
|
| 199 |
+
do_sample: true
|
| 200 |
+
'n': 1
|
| 201 |
+
over_sample_rate: 0
|
| 202 |
+
multi_stage_wake_up: false
|
| 203 |
+
engine_kwargs:
|
| 204 |
+
vllm: {}
|
| 205 |
+
sglang: {}
|
| 206 |
+
val_kwargs:
|
| 207 |
+
_target_: verl.workers.config.SamplingConfig
|
| 208 |
+
top_k: -1
|
| 209 |
+
top_p: 1.0
|
| 210 |
+
temperature: 0
|
| 211 |
+
'n': 1
|
| 212 |
+
do_sample: false
|
| 213 |
+
multi_turn:
|
| 214 |
+
_target_: verl.workers.config.MultiTurnConfig
|
| 215 |
+
enable: false
|
| 216 |
+
max_assistant_turns: null
|
| 217 |
+
tool_config_path: null
|
| 218 |
+
max_user_turns: null
|
| 219 |
+
max_parallel_calls: 1
|
| 220 |
+
max_tool_response_length: 256
|
| 221 |
+
tool_response_truncate_side: middle
|
| 222 |
+
interaction_config_path: null
|
| 223 |
+
use_inference_chat_template: false
|
| 224 |
+
tokenization_sanity_check_mode: strict
|
| 225 |
+
format: hermes
|
| 226 |
+
num_repeat_rollouts: null
|
| 227 |
+
calculate_log_probs: false
|
| 228 |
+
agent:
|
| 229 |
+
_target_: verl.workers.config.AgentLoopConfig
|
| 230 |
+
num_workers: 8
|
| 231 |
+
default_agent_loop: single_turn_agent
|
| 232 |
+
agent_loop_config_path: null
|
| 233 |
+
custom_async_server:
|
| 234 |
+
_target_: verl.workers.config.CustomAsyncServerConfig
|
| 235 |
+
path: null
|
| 236 |
+
name: null
|
| 237 |
+
update_weights_bucket_megabytes: 512
|
| 238 |
+
trace:
|
| 239 |
+
_target_: verl.workers.config.TraceConfig
|
| 240 |
+
backend: null
|
| 241 |
+
token2text: false
|
| 242 |
+
max_samples_per_step_per_worker: null
|
| 243 |
+
skip_rollout: false
|
| 244 |
+
skip_dump_dir: /tmp/rollout_dump
|
| 245 |
+
skip_tokenizer_init: true
|
| 246 |
+
profiler:
|
| 247 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 248 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 249 |
+
enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false}
|
| 250 |
+
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
|
| 251 |
+
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
|
| 252 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 253 |
+
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
|
| 254 |
+
prometheus:
|
| 255 |
+
_target_: verl.workers.config.PrometheusConfig
|
| 256 |
+
enable: false
|
| 257 |
+
port: 9090
|
| 258 |
+
file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml
|
| 259 |
+
served_model_name: ${oc.select:actor_rollout_ref.model.path,null}
|
| 260 |
+
layered_summon: false
|
| 261 |
+
model:
|
| 262 |
+
_target_: verl.workers.config.HFModelConfig
|
| 263 |
+
path: ~/models/deepseek-llm-7b-chat
|
| 264 |
+
hf_config_path: null
|
| 265 |
+
tokenizer_path: null
|
| 266 |
+
use_shm: false
|
| 267 |
+
trust_remote_code: false
|
| 268 |
+
custom_chat_template: null
|
| 269 |
+
external_lib: null
|
| 270 |
+
override_config: {}
|
| 271 |
+
enable_gradient_checkpointing: true
|
| 272 |
+
enable_activation_offload: false
|
| 273 |
+
use_remove_padding: false
|
| 274 |
+
lora_rank: 0
|
| 275 |
+
lora_alpha: 16
|
| 276 |
+
target_modules: all-linear
|
| 277 |
+
exclude_modules: null
|
| 278 |
+
lora_adapter_path: null
|
| 279 |
+
use_liger: false
|
| 280 |
+
use_fused_kernels: false
|
| 281 |
+
fused_kernel_options:
|
| 282 |
+
impl_backend: torch
|
| 283 |
+
hybrid_engine: true
|
| 284 |
+
nccl_timeout: 600
|
| 285 |
+
data:
|
| 286 |
+
tokenizer: null
|
| 287 |
+
use_shm: false
|
| 288 |
+
train_files: ~/data/rlhf/gsm8k/train.parquet
|
| 289 |
+
val_files: ~/data/rlhf/gsm8k/test.parquet
|
| 290 |
+
train_max_samples: -1
|
| 291 |
+
val_max_samples: -1
|
| 292 |
+
prompt_key: prompt
|
| 293 |
+
reward_fn_key: data_source
|
| 294 |
+
max_prompt_length: 512
|
| 295 |
+
max_response_length: 512
|
| 296 |
+
train_batch_size: 1024
|
| 297 |
+
val_batch_size: null
|
| 298 |
+
tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path,
|
| 299 |
+
null}
|
| 300 |
+
return_raw_input_ids: false
|
| 301 |
+
return_raw_chat: true
|
| 302 |
+
return_full_prompt: false
|
| 303 |
+
shuffle: true
|
| 304 |
+
seed: null
|
| 305 |
+
dataloader_num_workers: 8
|
| 306 |
+
image_patch_size: 14
|
| 307 |
+
validation_shuffle: false
|
| 308 |
+
filter_overlong_prompts: false
|
| 309 |
+
filter_overlong_prompts_workers: 1
|
| 310 |
+
truncation: error
|
| 311 |
+
image_key: images
|
| 312 |
+
video_key: videos
|
| 313 |
+
trust_remote_code: false
|
| 314 |
+
custom_cls:
|
| 315 |
+
path: null
|
| 316 |
+
name: null
|
| 317 |
+
return_multi_modal_inputs: true
|
| 318 |
+
sampler:
|
| 319 |
+
class_path: null
|
| 320 |
+
class_name: null
|
| 321 |
+
datagen:
|
| 322 |
+
path: null
|
| 323 |
+
name: null
|
| 324 |
+
apply_chat_template_kwargs: {}
|
| 325 |
+
critic:
|
| 326 |
+
optim:
|
| 327 |
+
_target_: verl.workers.config.FSDPOptimizerConfig
|
| 328 |
+
optimizer: AdamW
|
| 329 |
+
optimizer_impl: torch.optim
|
| 330 |
+
lr: 1.0e-05
|
| 331 |
+
lr_warmup_steps_ratio: 0.0
|
| 332 |
+
total_training_steps: -1
|
| 333 |
+
weight_decay: 0.01
|
| 334 |
+
lr_warmup_steps: -1
|
| 335 |
+
betas:
|
| 336 |
+
- 0.9
|
| 337 |
+
- 0.999
|
| 338 |
+
clip_grad: 1.0
|
| 339 |
+
min_lr_ratio: 0.0
|
| 340 |
+
num_cycles: 0.5
|
| 341 |
+
lr_scheduler_type: constant
|
| 342 |
+
warmup_style: null
|
| 343 |
+
override_optimizer_config: null
|
| 344 |
+
model:
|
| 345 |
+
fsdp_config:
|
| 346 |
+
_target_: verl.workers.config.FSDPEngineConfig
|
| 347 |
+
wrap_policy:
|
| 348 |
+
min_num_params: 0
|
| 349 |
+
param_offload: false
|
| 350 |
+
optimizer_offload: false
|
| 351 |
+
offload_policy: false
|
| 352 |
+
reshard_after_forward: true
|
| 353 |
+
fsdp_size: -1
|
| 354 |
+
forward_prefetch: false
|
| 355 |
+
model_dtype: fp32
|
| 356 |
+
use_orig_params: false
|
| 357 |
+
ulysses_sequence_parallel_size: 1
|
| 358 |
+
entropy_from_logits_with_chunking: false
|
| 359 |
+
use_torch_compile: true
|
| 360 |
+
entropy_checkpointing: false
|
| 361 |
+
forward_only: false
|
| 362 |
+
strategy: fsdp
|
| 363 |
+
dtype: bfloat16
|
| 364 |
+
path: ~/models/deepseek-llm-7b-chat
|
| 365 |
+
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
|
| 366 |
+
override_config: {}
|
| 367 |
+
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
|
| 368 |
+
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
|
| 369 |
+
_target_: verl.workers.config.FSDPCriticModelCfg
|
| 370 |
+
use_shm: false
|
| 371 |
+
enable_gradient_checkpointing: true
|
| 372 |
+
enable_activation_offload: false
|
| 373 |
+
use_remove_padding: false
|
| 374 |
+
lora_rank: 0
|
| 375 |
+
lora_alpha: 16
|
| 376 |
+
target_modules: all-linear
|
| 377 |
+
_target_: verl.workers.config.FSDPCriticConfig
|
| 378 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 379 |
+
strategy: fsdp
|
| 380 |
+
enable: null
|
| 381 |
+
ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
|
| 382 |
+
ppo_micro_batch_size: null
|
| 383 |
+
ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
|
| 384 |
+
use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 385 |
+
ppo_max_token_len_per_gpu: 32768
|
| 386 |
+
forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
|
| 387 |
+
ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}
|
| 388 |
+
shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}
|
| 389 |
+
cliprange_value: 0.5
|
| 390 |
+
loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
|
| 391 |
+
checkpoint:
|
| 392 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 393 |
+
save_contents:
|
| 394 |
+
- model
|
| 395 |
+
- optimizer
|
| 396 |
+
- extra
|
| 397 |
+
load_contents: ${.save_contents}
|
| 398 |
+
async_save: false
|
| 399 |
+
profiler:
|
| 400 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 401 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 402 |
+
enable: false
|
| 403 |
+
all_ranks: false
|
| 404 |
+
ranks: []
|
| 405 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 406 |
+
tool_config:
|
| 407 |
+
nsys:
|
| 408 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 409 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 410 |
+
npu:
|
| 411 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 412 |
+
contents: []
|
| 413 |
+
level: level1
|
| 414 |
+
analysis: true
|
| 415 |
+
discrete: false
|
| 416 |
+
torch:
|
| 417 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 418 |
+
step_start: 0
|
| 419 |
+
step_end: null
|
| 420 |
+
torch_memory:
|
| 421 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 422 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 423 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 424 |
+
forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}
|
| 425 |
+
forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}
|
| 426 |
+
ulysses_sequence_parallel_size: 1
|
| 427 |
+
grad_clip: 1.0
|
| 428 |
+
reward_model:
|
| 429 |
+
enable: false
|
| 430 |
+
enable_resource_pool: false
|
| 431 |
+
n_gpus_per_node: 0
|
| 432 |
+
nnodes: 0
|
| 433 |
+
strategy: fsdp
|
| 434 |
+
model:
|
| 435 |
+
input_tokenizer: ${actor_rollout_ref.model.path}
|
| 436 |
+
path: ~/models/FsfairX-LLaMA3-RM-v0.1
|
| 437 |
+
external_lib: ${actor_rollout_ref.model.external_lib}
|
| 438 |
+
trust_remote_code: false
|
| 439 |
+
use_shm: false
|
| 440 |
+
use_remove_padding: false
|
| 441 |
+
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
|
| 442 |
+
fsdp_config:
|
| 443 |
+
_target_: verl.workers.config.FSDPEngineConfig
|
| 444 |
+
wrap_policy:
|
| 445 |
+
min_num_params: 0
|
| 446 |
+
param_offload: false
|
| 447 |
+
reshard_after_forward: true
|
| 448 |
+
fsdp_size: -1
|
| 449 |
+
forward_prefetch: false
|
| 450 |
+
micro_batch_size: null
|
| 451 |
+
micro_batch_size_per_gpu: null
|
| 452 |
+
max_length: null
|
| 453 |
+
use_dynamic_bsz: ${critic.use_dynamic_bsz}
|
| 454 |
+
forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
|
| 455 |
+
reward_manager: naive
|
| 456 |
+
launch_reward_fn_async: false
|
| 457 |
+
sandbox_fusion:
|
| 458 |
+
url: null
|
| 459 |
+
max_concurrent: 64
|
| 460 |
+
memory_limit_mb: 1024
|
| 461 |
+
profiler:
|
| 462 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 463 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 464 |
+
enable: false
|
| 465 |
+
all_ranks: false
|
| 466 |
+
ranks: []
|
| 467 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 468 |
+
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
|
| 469 |
+
ulysses_sequence_parallel_size: 1
|
| 470 |
+
algorithm:
|
| 471 |
+
rollout_correction:
|
| 472 |
+
rollout_is: null
|
| 473 |
+
rollout_is_threshold: 2.0
|
| 474 |
+
rollout_rs: null
|
| 475 |
+
rollout_rs_threshold: null
|
| 476 |
+
rollout_rs_threshold_lower: null
|
| 477 |
+
rollout_token_veto_threshold: null
|
| 478 |
+
bypass_mode: false
|
| 479 |
+
use_policy_gradient: false
|
| 480 |
+
rollout_is_batch_normalize: false
|
| 481 |
+
_target_: verl.trainer.config.AlgoConfig
|
| 482 |
+
gamma: 1.0
|
| 483 |
+
lam: 1.0
|
| 484 |
+
adv_estimator: gae
|
| 485 |
+
norm_adv_by_std_in_grpo: true
|
| 486 |
+
use_kl_in_reward: false
|
| 487 |
+
kl_penalty: kl
|
| 488 |
+
kl_ctrl:
|
| 489 |
+
_target_: verl.trainer.config.KLControlConfig
|
| 490 |
+
type: fixed
|
| 491 |
+
kl_coef: 0.001
|
| 492 |
+
horizon: 10000
|
| 493 |
+
target_kl: 0.1
|
| 494 |
+
use_pf_ppo: false
|
| 495 |
+
pf_ppo:
|
| 496 |
+
reweight_method: pow
|
| 497 |
+
weight_pow: 2.0
|
| 498 |
+
custom_reward_function:
|
| 499 |
+
path: null
|
| 500 |
+
name: compute_score
|
| 501 |
+
trainer:
|
| 502 |
+
balance_batch: true
|
| 503 |
+
total_epochs: 30
|
| 504 |
+
total_training_steps: null
|
| 505 |
+
project_name: verl_examples
|
| 506 |
+
experiment_name: gsm8k
|
| 507 |
+
logger:
|
| 508 |
+
- console
|
| 509 |
+
- wandb
|
| 510 |
+
log_val_generations: 0
|
| 511 |
+
rollout_data_dir: null
|
| 512 |
+
validation_data_dir: null
|
| 513 |
+
nnodes: 1
|
| 514 |
+
n_gpus_per_node: 8
|
| 515 |
+
save_freq: -1
|
| 516 |
+
esi_redundant_time: 0
|
| 517 |
+
resume_mode: auto
|
| 518 |
+
resume_from_path: null
|
| 519 |
+
val_before_train: true
|
| 520 |
+
val_only: false
|
| 521 |
+
test_freq: -1
|
| 522 |
+
critic_warmup: 0
|
| 523 |
+
default_hdfs_dir: null
|
| 524 |
+
del_local_ckpt_after_load: false
|
| 525 |
+
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
|
| 526 |
+
max_actor_ckpt_to_keep: null
|
| 527 |
+
max_critic_ckpt_to_keep: null
|
| 528 |
+
ray_wait_register_center_timeout: 300
|
| 529 |
+
device: cuda
|
| 530 |
+
use_legacy_worker_impl: auto
|
| 531 |
+
global_profiler:
|
| 532 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 533 |
+
tool: null
|
| 534 |
+
steps: null
|
| 535 |
+
profile_continuous_steps: false
|
| 536 |
+
save_path: outputs/profile
|
| 537 |
+
global_tool_config:
|
| 538 |
+
nsys:
|
| 539 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 540 |
+
discrete: false
|
| 541 |
+
controller_nsight_options:
|
| 542 |
+
trace: cuda,nvtx,cublas,ucx
|
| 543 |
+
cuda-memory-usage: 'true'
|
| 544 |
+
cuda-graph-trace: graph
|
| 545 |
+
worker_nsight_options:
|
| 546 |
+
trace: cuda,nvtx,cublas,ucx
|
| 547 |
+
cuda-memory-usage: 'true'
|
| 548 |
+
cuda-graph-trace: graph
|
| 549 |
+
capture-range: cudaProfilerApi
|
| 550 |
+
capture-range-end: null
|
| 551 |
+
kill: none
|
| 552 |
+
torch_memory:
|
| 553 |
+
trace_alloc_max_entries: 100000
|
| 554 |
+
stack_depth: 32
|
| 555 |
+
context: all
|
| 556 |
+
stacks: all
|
| 557 |
+
kw_args: {}
|
| 558 |
+
transfer_queue:
|
| 559 |
+
enable: false
|
| 560 |
+
ray_kwargs:
|
| 561 |
+
ray_init:
|
| 562 |
+
num_cpus: null
|
| 563 |
+
timeline_json_file: null
|
ICL/DAPO/verl-recipe/rep_exp/config/actor/actor.yaml
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Format checks enforced on CI:
|
| 2 |
+
# 1. Comments must appear above each field.
|
| 3 |
+
# 2. There must be a blank line between each field.
|
| 4 |
+
# 3. Inline comments (after a field on the same line) are not allowed.
|
| 5 |
+
# 4. Indentation level is respected for nested fields.
|
| 6 |
+
|
| 7 |
+
# Target class for this configuration
|
| 8 |
+
_target_: verl.workers.config.ActorConfig
|
| 9 |
+
|
| 10 |
+
# Number of rollouts per update (mirrors actor rollout_n)
|
| 11 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 12 |
+
|
| 13 |
+
# the abstract actor configs
|
| 14 |
+
# fsdp, fsdp2 or megatron. must be set.
|
| 15 |
+
strategy: ???
|
| 16 |
+
|
| 17 |
+
# Split each sample into sub-batches of this size for PPO
|
| 18 |
+
ppo_mini_batch_size: 256
|
| 19 |
+
|
| 20 |
+
# [Deprecated] Global micro batch size
|
| 21 |
+
ppo_micro_batch_size: null
|
| 22 |
+
|
| 23 |
+
# Local per-GPU micro batch size
|
| 24 |
+
ppo_micro_batch_size_per_gpu: null
|
| 25 |
+
|
| 26 |
+
# Whether to automatically adjust batch size at runtime
|
| 27 |
+
# oc.select: the default val for ref.log_prob_use_dynamic_bsz
|
| 28 |
+
use_dynamic_bsz: false
|
| 29 |
+
|
| 30 |
+
# Max tokens per GPU in one PPO batch; affects gradient accumulation
|
| 31 |
+
# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length}
|
| 32 |
+
# oc.select: the default val for ref.log_prob_max_token_len_per_gpu
|
| 33 |
+
ppo_max_token_len_per_gpu: 16384
|
| 34 |
+
|
| 35 |
+
# PPO clip ratio
|
| 36 |
+
clip_ratio: 0.2
|
| 37 |
+
|
| 38 |
+
# Lower bound for asymmetric clipping (used in dual-clip PPO)
|
| 39 |
+
clip_ratio_low: 0.2
|
| 40 |
+
|
| 41 |
+
# Upper bound for asymmetric clipping (used in dual-clip PPO)
|
| 42 |
+
clip_ratio_high: 0.2
|
| 43 |
+
|
| 44 |
+
# Whether to freeze vision model, if set true, it will be freeze vision model
|
| 45 |
+
freeze_vision_tower: false
|
| 46 |
+
|
| 47 |
+
# policy loss config
|
| 48 |
+
policy_loss:
|
| 49 |
+
|
| 50 |
+
# # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 51 |
+
_target_: verl.workers.config.PolicyLossConfig
|
| 52 |
+
|
| 53 |
+
# Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617
|
| 54 |
+
loss_mode: "vanilla"
|
| 55 |
+
|
| 56 |
+
# Ratio of tokens to be clipped for clip-cov loss
|
| 57 |
+
clip_cov_ratio: 0.0002
|
| 58 |
+
|
| 59 |
+
# Lower bound for clip-cov loss
|
| 60 |
+
clip_cov_lb: 1.0
|
| 61 |
+
|
| 62 |
+
# Upper bound for clip-cov loss
|
| 63 |
+
clip_cov_ub: 5.0
|
| 64 |
+
|
| 65 |
+
# Ratio of tokens to be applied kl penalty for kl-cov loss
|
| 66 |
+
kl_cov_ratio: 0.0002
|
| 67 |
+
|
| 68 |
+
# KL divergence penalty coefficient
|
| 69 |
+
ppo_kl_coef: 0.1
|
| 70 |
+
|
| 71 |
+
# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C
|
| 72 |
+
clip_ratio_c: 3.0
|
| 73 |
+
|
| 74 |
+
# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean"
|
| 75 |
+
loss_agg_mode: token-mean
|
| 76 |
+
|
| 77 |
+
# Entropy regularization coefficient in PPO loss
|
| 78 |
+
entropy_coeff: 0
|
| 79 |
+
|
| 80 |
+
# Whether to use KL loss instead of KL reward penalty. True for GRPO
|
| 81 |
+
use_kl_loss: false
|
| 82 |
+
|
| 83 |
+
# Whether to use torch.compile()
|
| 84 |
+
# oc.select: the default val for ref.use_torch_compile
|
| 85 |
+
use_torch_compile: true
|
| 86 |
+
|
| 87 |
+
# KL loss coefficient when use_kl_loss is enabled. For GRPO
|
| 88 |
+
kl_loss_coef: 0.001
|
| 89 |
+
|
| 90 |
+
# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full"
|
| 91 |
+
kl_loss_type: low_var_kl
|
| 92 |
+
|
| 93 |
+
# Number of PPO epochs per batch
|
| 94 |
+
ppo_epochs: 1
|
| 95 |
+
|
| 96 |
+
# Shuffle training data across PPO epochs
|
| 97 |
+
shuffle: false
|
| 98 |
+
|
| 99 |
+
# checkpoint configs
|
| 100 |
+
checkpoint:
|
| 101 |
+
|
| 102 |
+
# Target dataclass for this configuration
|
| 103 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 104 |
+
|
| 105 |
+
# What to include in saved checkpoints
|
| 106 |
+
# with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
|
| 107 |
+
save_contents: ['model', 'optimizer', 'extra']
|
| 108 |
+
|
| 109 |
+
# For more flexibility, you can specify the contents to load from the checkpoint.
|
| 110 |
+
# .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg
|
| 111 |
+
load_contents: ${.save_contents}
|
| 112 |
+
|
| 113 |
+
# Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
|
| 114 |
+
async_save: False
|
| 115 |
+
|
| 116 |
+
# optimizer configs
|
| 117 |
+
optim:
|
| 118 |
+
|
| 119 |
+
# Learning rate
|
| 120 |
+
lr: 1e-6
|
| 121 |
+
|
| 122 |
+
# Warmup steps ratio (used if lr_warmup_steps is 0 or negative)
|
| 123 |
+
lr_warmup_steps_ratio: 0.0
|
| 124 |
+
|
| 125 |
+
# Total training steps (must be overridden at runtime)
|
| 126 |
+
total_training_steps: -1
|
| 127 |
+
|
| 128 |
+
# Weight decay
|
| 129 |
+
weight_decay: 0.01
|
| 130 |
+
|
| 131 |
+
# Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
|
| 132 |
+
lr_warmup_steps: -1
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP)
|
| 136 |
+
use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false}
|
| 137 |
+
|
| 138 |
+
# profile the actor model in `update_policy`
|
| 139 |
+
profiler:
|
| 140 |
+
|
| 141 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 142 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 143 |
+
|
| 144 |
+
# profiler tool, default same as profiler.tool in global config
|
| 145 |
+
# choices: nsys, npu, torch
|
| 146 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 147 |
+
|
| 148 |
+
# whether enable profile on Actor
|
| 149 |
+
enable: False
|
| 150 |
+
|
| 151 |
+
# Whether to profile all ranks.
|
| 152 |
+
all_ranks: False
|
| 153 |
+
|
| 154 |
+
# The ranks that will be profiled. [] or [0,1,...]
|
| 155 |
+
ranks: []
|
| 156 |
+
|
| 157 |
+
# profile results saving path
|
| 158 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 159 |
+
|
| 160 |
+
# specific tool config which only related to the role
|
| 161 |
+
tool_config:
|
| 162 |
+
|
| 163 |
+
# nsys tool config
|
| 164 |
+
nsys:
|
| 165 |
+
|
| 166 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 167 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 168 |
+
|
| 169 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 170 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 171 |
+
|
| 172 |
+
# npu config
|
| 173 |
+
npu:
|
| 174 |
+
|
| 175 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 176 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 177 |
+
|
| 178 |
+
# Contents to profile, can be empty
|
| 179 |
+
# options: npu, cpu, memory, shapes, module, stack
|
| 180 |
+
contents: []
|
| 181 |
+
|
| 182 |
+
# Collection level, optional values: level_none, level0, level1, level2.
|
| 183 |
+
level: "level1"
|
| 184 |
+
|
| 185 |
+
# Whether to automatically parse the data.
|
| 186 |
+
analysis: True
|
| 187 |
+
|
| 188 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 189 |
+
discrete: False
|
| 190 |
+
|
| 191 |
+
# torch profiler config
|
| 192 |
+
torch:
|
| 193 |
+
|
| 194 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 195 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 196 |
+
|
| 197 |
+
# start profile mini-batch in training
|
| 198 |
+
# NOTICE: different with global steps config which refers to iteration
|
| 199 |
+
# This field only related with mini-batch
|
| 200 |
+
step_start: 0
|
| 201 |
+
|
| 202 |
+
# stop profile mini-batch in training
|
| 203 |
+
step_end: null
|
| 204 |
+
|
| 205 |
+
# torch memory profiler config
|
| 206 |
+
torch_memory:
|
| 207 |
+
|
| 208 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 209 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 210 |
+
|
| 211 |
+
# Maximum number of memory allocation entries to track
|
| 212 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 213 |
+
|
| 214 |
+
# Stack trace depth for memory allocations
|
| 215 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
ICL/DAPO/verl-recipe/rep_exp/config/actor/dp_actor.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Format checks enforced on CI:
|
| 2 |
+
# 1. Comments must appear above each field.
|
| 3 |
+
# 2. There must be a blank line between each field.
|
| 4 |
+
# 3. Inline comments (after a field on the same line) are not allowed.
|
| 5 |
+
# 4. Indentation level is respected for nested fields.
|
| 6 |
+
|
| 7 |
+
# defaults specify the default config from each component
|
| 8 |
+
defaults:
|
| 9 |
+
|
| 10 |
+
# fsdp optimizer config
|
| 11 |
+
- ../optim@optim: fsdp
|
| 12 |
+
|
| 13 |
+
# fsdp engine config
|
| 14 |
+
- ../engine@fsdp_config: fsdp
|
| 15 |
+
|
| 16 |
+
# dp actor config, inheriting from trainer/config/actor/actor.yaml
|
| 17 |
+
- actor
|
| 18 |
+
|
| 19 |
+
# load the reference default config, then apply the fields in the current yaml
|
| 20 |
+
- _self_
|
| 21 |
+
|
| 22 |
+
# Target class for this configuration
|
| 23 |
+
_target_: verl.workers.config.FSDPActorConfig
|
| 24 |
+
|
| 25 |
+
# TODO(haibin.lin): switch to fsdp2
|
| 26 |
+
strategy: fsdp
|
| 27 |
+
|
| 28 |
+
# Gradient clipping for actor updates, specific to the strategy.
|
| 29 |
+
grad_clip: 1.0
|
| 30 |
+
|
| 31 |
+
# Sequence parallelism size for Ulysses-style model parallelism
|
| 32 |
+
# oc.select: the default val for ref.ulysses_sequence_parallel_size
|
| 33 |
+
# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead
|
| 34 |
+
ulysses_sequence_parallel_size: 1
|
| 35 |
+
|
| 36 |
+
# calculate entropy with chunking to reduce memory peak
|
| 37 |
+
entropy_from_logits_with_chunking: False
|
| 38 |
+
|
| 39 |
+
# recompute entropy
|
| 40 |
+
entropy_checkpointing: False
|
| 41 |
+
|
| 42 |
+
# Whether to remove padding tokens in inputs during training
|
| 43 |
+
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
|
ICL/DAPO/verl-recipe/rep_exp/config/actor/megatron_actor.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# megatron actor config, inheriting from trainer/config/actor/actor.yaml
|
| 2 |
+
defaults:
|
| 3 |
+
# megatron optimizer config
|
| 4 |
+
- ../optim@optim: megatron
|
| 5 |
+
|
| 6 |
+
# megatron engine config
|
| 7 |
+
- ../engine@megatron: megatron
|
| 8 |
+
|
| 9 |
+
- actor
|
| 10 |
+
|
| 11 |
+
# load the reference default config, then apply the fields in the current yaml
|
| 12 |
+
- _self_
|
| 13 |
+
|
| 14 |
+
_target_: verl.workers.config.McoreActorConfig
|
| 15 |
+
|
| 16 |
+
strategy: megatron
|
| 17 |
+
|
| 18 |
+
data_loader_seed: 42
|
| 19 |
+
|
| 20 |
+
load_weight: True
|
ICL/DAPO/verl-recipe/rep_exp/config/algorithm/rollout_correction.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Rollout Correction: corrects off-policy distribution shifts
|
| 2 |
+
# See documentation: docs/algo/rollout_corr.md
|
| 3 |
+
# Use presets: RolloutCorrectionConfig.decoupled_seq_is(), .pg_is(), etc.
|
| 4 |
+
|
| 5 |
+
# IS aggregation level: null (disabled), "token" (per-token), "sequence" (per-sequence)
|
| 6 |
+
rollout_is: null
|
| 7 |
+
|
| 8 |
+
# Upper threshold for IS weight truncation (typical: 2.0-5.0)
|
| 9 |
+
rollout_is_threshold: 2.0
|
| 10 |
+
|
| 11 |
+
# RS aggregation level: null (disabled), "token", "sequence", "geometric"
|
| 12 |
+
rollout_rs: null
|
| 13 |
+
|
| 14 |
+
# Upper threshold for rejection sampling (null = use rollout_is_threshold)
|
| 15 |
+
rollout_rs_threshold: null
|
| 16 |
+
|
| 17 |
+
# Lower threshold for rejection sampling (null = auto-compute as 1/upper)
|
| 18 |
+
rollout_rs_threshold_lower: null
|
| 19 |
+
|
| 20 |
+
# Per-token veto threshold for catastrophic outliers (null = disabled)
|
| 21 |
+
rollout_token_veto_threshold: null
|
| 22 |
+
|
| 23 |
+
# Operating mode: false = Decoupled (3 policies), true = Bypass (2 policies)
|
| 24 |
+
bypass_mode: false
|
| 25 |
+
|
| 26 |
+
# Loss function: false = PPO with clipping, true = Policy gradient (no clipping)
|
| 27 |
+
use_policy_gradient: false
|
| 28 |
+
|
| 29 |
+
# Batch normalize IS weights: false = raw weights, true = normalize to mean=1.0
|
| 30 |
+
rollout_is_batch_normalize: false
|
ICL/DAPO/verl-recipe/rep_exp/config/critic/critic.yaml
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 2 |
+
_target_: verl.workers.config.CriticConfig
|
| 3 |
+
|
| 4 |
+
# Number of rollouts per update (mirrors actor rollout_n)
|
| 5 |
+
rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1}
|
| 6 |
+
|
| 7 |
+
# fsdp or fsdp2 strategy used for critic model training
|
| 8 |
+
strategy: ???
|
| 9 |
+
|
| 10 |
+
# whether to enable the critic worker.
|
| 11 |
+
# by default it is only enabled if advantage estimator is gae
|
| 12 |
+
# set it to True manually if you always want to enable critic worker
|
| 13 |
+
enable: null
|
| 14 |
+
|
| 15 |
+
# optimizer configs
|
| 16 |
+
optim:
|
| 17 |
+
|
| 18 |
+
# Learning rate
|
| 19 |
+
lr: 1e-5
|
| 20 |
+
|
| 21 |
+
# Warmup steps ratio; total steps will be injected at runtime
|
| 22 |
+
lr_warmup_steps_ratio: 0.0
|
| 23 |
+
|
| 24 |
+
# Total training steps (must be overridden at runtime)
|
| 25 |
+
total_training_steps: -1
|
| 26 |
+
|
| 27 |
+
# Weight decay
|
| 28 |
+
weight_decay: 0.01
|
| 29 |
+
|
| 30 |
+
# Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio.
|
| 31 |
+
lr_warmup_steps: -1
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# model config for the critic
|
| 35 |
+
model:
|
| 36 |
+
|
| 37 |
+
# Path to pretrained model weights
|
| 38 |
+
path: ~/models/deepseek-llm-7b-chat
|
| 39 |
+
|
| 40 |
+
# Tokenizer path (defaults to actor's model path)
|
| 41 |
+
tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"}
|
| 42 |
+
|
| 43 |
+
# Hugging Face config override
|
| 44 |
+
override_config: {}
|
| 45 |
+
|
| 46 |
+
# External model implementation (optional)
|
| 47 |
+
external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null}
|
| 48 |
+
|
| 49 |
+
# Whether to trust remote code from Hugging Face models
|
| 50 |
+
trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false}
|
| 51 |
+
|
| 52 |
+
# PPO mini-batch size per update
|
| 53 |
+
ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256}
|
| 54 |
+
|
| 55 |
+
# [Deprecated] Global micro batch size
|
| 56 |
+
ppo_micro_batch_size: null
|
| 57 |
+
|
| 58 |
+
# Local per-GPU micro batch size
|
| 59 |
+
ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null}
|
| 60 |
+
|
| 61 |
+
# Whether to automatically adjust batch size at runtime
|
| 62 |
+
use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false}
|
| 63 |
+
|
| 64 |
+
# Max tokens per GPU in one PPO batch (doubled for critic)
|
| 65 |
+
ppo_max_token_len_per_gpu: 32768
|
| 66 |
+
|
| 67 |
+
# Max token length per GPU in forward pass
|
| 68 |
+
forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu}
|
| 69 |
+
|
| 70 |
+
# Number of PPO epochs per batch
|
| 71 |
+
ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1}
|
| 72 |
+
|
| 73 |
+
# Shuffle training data across PPO epochs
|
| 74 |
+
shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false}
|
| 75 |
+
|
| 76 |
+
# PPO value function clipping range
|
| 77 |
+
cliprange_value: 0.5
|
| 78 |
+
|
| 79 |
+
# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean"
|
| 80 |
+
loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean}
|
| 81 |
+
|
| 82 |
+
# checkpoint configs
|
| 83 |
+
checkpoint:
|
| 84 |
+
|
| 85 |
+
# Target dataclass for this configuration
|
| 86 |
+
_target_: verl.trainer.config.CheckpointConfig
|
| 87 |
+
|
| 88 |
+
# What to include in saved checkpoints
|
| 89 |
+
# with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
|
| 90 |
+
save_contents: ['model', 'optimizer', 'extra']
|
| 91 |
+
|
| 92 |
+
# What to include when loading checkpoints
|
| 93 |
+
load_contents: ${.save_contents}
|
| 94 |
+
|
| 95 |
+
# Whether to save checkpoints asynchronously. Only effective for Megatron as of now.
|
| 96 |
+
async_save: False
|
| 97 |
+
|
| 98 |
+
# profile the critic model in `update_critic`
|
| 99 |
+
profiler:
|
| 100 |
+
|
| 101 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 102 |
+
_target_: verl.utils.profiler.ProfilerConfig
|
| 103 |
+
|
| 104 |
+
# profiler tool, default same as profiler.tool in global config
|
| 105 |
+
# choices: nsys, npu, torch, torch_memory
|
| 106 |
+
tool: ${oc.select:global_profiler.tool,null}
|
| 107 |
+
|
| 108 |
+
# whether enable profile on Critic
|
| 109 |
+
enable: False
|
| 110 |
+
|
| 111 |
+
# Whether to profile all ranks.
|
| 112 |
+
all_ranks: False
|
| 113 |
+
|
| 114 |
+
# The ranks that will be profiled. [] or [0,1,...]
|
| 115 |
+
ranks: []
|
| 116 |
+
|
| 117 |
+
# profile results saving path
|
| 118 |
+
save_path: ${oc.select:global_profiler.save_path,null}
|
| 119 |
+
|
| 120 |
+
# specific tool config which only related to the role
|
| 121 |
+
tool_config:
|
| 122 |
+
|
| 123 |
+
# nsys tool config
|
| 124 |
+
nsys:
|
| 125 |
+
|
| 126 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 127 |
+
_target_: verl.utils.profiler.config.NsightToolConfig
|
| 128 |
+
|
| 129 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 130 |
+
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
|
| 131 |
+
|
| 132 |
+
# npu config
|
| 133 |
+
npu:
|
| 134 |
+
|
| 135 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 136 |
+
_target_: verl.utils.profiler.config.NPUToolConfig
|
| 137 |
+
|
| 138 |
+
# Contents to profile, can be empty
|
| 139 |
+
# options: npu, cpu, memory, shapes, module, stack
|
| 140 |
+
contents: []
|
| 141 |
+
|
| 142 |
+
# Collection level, optional values: level_none, level0, level1, level2.
|
| 143 |
+
level: "level1"
|
| 144 |
+
|
| 145 |
+
# Whether to automatically parse the data.
|
| 146 |
+
analysis: True
|
| 147 |
+
|
| 148 |
+
# True for each task has its own database, False for all tasks in one training step share one database.
|
| 149 |
+
discrete: False
|
| 150 |
+
|
| 151 |
+
# torch profiler config
|
| 152 |
+
torch:
|
| 153 |
+
|
| 154 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 155 |
+
_target_: verl.utils.profiler.config.TorchProfilerToolConfig
|
| 156 |
+
|
| 157 |
+
# start profile mini-batch in training
|
| 158 |
+
# NOTICE: different with global steps config which refers to iteration
|
| 159 |
+
# This field only related with mini-batch
|
| 160 |
+
step_start: 0
|
| 161 |
+
|
| 162 |
+
# stop profile mini-batch in training
|
| 163 |
+
step_end: null
|
| 164 |
+
|
| 165 |
+
# torch memory profiler config
|
| 166 |
+
torch_memory:
|
| 167 |
+
|
| 168 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 169 |
+
_target_: verl.utils.profiler.config.TorchMemoryToolConfig
|
| 170 |
+
|
| 171 |
+
# Maximum number of memory allocation entries to track
|
| 172 |
+
trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000}
|
| 173 |
+
|
| 174 |
+
# Stack trace depth for memory allocations
|
| 175 |
+
stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32}
|
| 176 |
+
|
ICL/DAPO/verl-recipe/rep_exp/config/critic/dp_critic.yaml
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Format checks enforced on CI:
|
| 2 |
+
# 1. Comments must appear above each field.
|
| 3 |
+
# 2. There must be a blank line between each field.
|
| 4 |
+
# 3. Inline comments (after a field on the same line) are not allowed.
|
| 5 |
+
# 4. Indentation level is respected for nested fields.
|
| 6 |
+
|
| 7 |
+
# defaults specify the default config from each component
|
| 8 |
+
defaults:
|
| 9 |
+
|
| 10 |
+
# fsdp optimizer config
|
| 11 |
+
- ../optim@optim: fsdp
|
| 12 |
+
|
| 13 |
+
# fsdp engine config
|
| 14 |
+
- ../engine@model.fsdp_config: fsdp
|
| 15 |
+
|
| 16 |
+
# dp actor config, inheriting from trainer/config/critic/critic.yaml
|
| 17 |
+
- critic
|
| 18 |
+
|
| 19 |
+
# load the reference default config, then apply the fields in the current yaml
|
| 20 |
+
- _self_
|
| 21 |
+
|
| 22 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 23 |
+
_target_: verl.workers.config.FSDPCriticConfig
|
| 24 |
+
|
| 25 |
+
# distribution strategy. Options: fsdp (deprecating), fsdp2
|
| 26 |
+
strategy: fsdp
|
| 27 |
+
|
| 28 |
+
# model config for the critic
|
| 29 |
+
model:
|
| 30 |
+
|
| 31 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 32 |
+
_target_: verl.workers.config.FSDPCriticModelCfg
|
| 33 |
+
|
| 34 |
+
# Whether to use shared memory for loading the model
|
| 35 |
+
use_shm: False
|
| 36 |
+
|
| 37 |
+
# Enable gradient checkpointing to save memory
|
| 38 |
+
enable_gradient_checkpointing: True
|
| 39 |
+
|
| 40 |
+
# Offload activations to CPU to reduce GPU memory usage
|
| 41 |
+
enable_activation_offload: False
|
| 42 |
+
|
| 43 |
+
# Use remove padding optimization (saves compute)
|
| 44 |
+
use_remove_padding: False
|
| 45 |
+
|
| 46 |
+
# Set to positive value to enable LoRA (e.g., 32)
|
| 47 |
+
lora_rank: 0
|
| 48 |
+
|
| 49 |
+
# LoRA scaling factor
|
| 50 |
+
lora_alpha: 16
|
| 51 |
+
|
| 52 |
+
# LoRA target modules: "all-linear" or list of linear projection layers
|
| 53 |
+
target_modules: all-linear
|
| 54 |
+
|
| 55 |
+
# Forward-only batch size during inference (global)
|
| 56 |
+
forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}
|
| 57 |
+
|
| 58 |
+
# Forward-only batch size during inference (per GPU)
|
| 59 |
+
forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}
|
| 60 |
+
|
| 61 |
+
# Sequence parallelism size for Ulysses-style model parallelism
|
| 62 |
+
# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead
|
| 63 |
+
ulysses_sequence_parallel_size: 1
|
| 64 |
+
|
| 65 |
+
# Gradient clipping for critic updates
|
| 66 |
+
grad_clip: 1.0
|
ICL/DAPO/verl-recipe/rep_exp/config/critic/megatron_critic.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# defaults specify the default config from each component
|
| 2 |
+
defaults:
|
| 3 |
+
|
| 4 |
+
# megatron optimizer config
|
| 5 |
+
- ../optim@optim: megatron
|
| 6 |
+
|
| 7 |
+
# megatron engine config
|
| 8 |
+
- ../engine@megatron: megatron
|
| 9 |
+
|
| 10 |
+
# dp critic config, inheriting from trainer/config/critic/critic.yaml
|
| 11 |
+
- critic
|
| 12 |
+
|
| 13 |
+
# load the reference default config, then apply the fields in the current yaml
|
| 14 |
+
- _self_
|
| 15 |
+
|
| 16 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 17 |
+
_target_: verl.workers.config.McoreCriticConfig
|
| 18 |
+
|
| 19 |
+
strategy: megatron
|
| 20 |
+
|
| 21 |
+
# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
|
| 22 |
+
nccl_timeout: 600
|
| 23 |
+
|
| 24 |
+
# model config for the critic
|
| 25 |
+
model:
|
| 26 |
+
|
| 27 |
+
# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs
|
| 28 |
+
_target_: verl.trainer.config.BaseModelConfig
|
| 29 |
+
|
| 30 |
+
# override default empty mapping
|
| 31 |
+
override_config:
|
| 32 |
+
|
| 33 |
+
model_config: {}
|
| 34 |
+
|
| 35 |
+
moe_config:
|
| 36 |
+
|
| 37 |
+
freeze_moe_router: False
|
| 38 |
+
|
| 39 |
+
# Whether to load initial weights
|
| 40 |
+
load_weight: True
|
| 41 |
+
|
| 42 |
+
# seed for data loader
|
| 43 |
+
data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null}
|
ICL/DAPO/verl-recipe/rep_exp/config/data/legacy_data.yaml
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tokenizer class or path. If null, it will be inferred from the model.
|
| 2 |
+
tokenizer: null
|
| 3 |
+
|
| 4 |
+
# Whether to use shared memory for data loading.
|
| 5 |
+
use_shm: False
|
| 6 |
+
|
| 7 |
+
# Training set parquet. Can be a list or a single file.
|
| 8 |
+
# The program will read all files into memory, so it can't be too large (< 100GB).
|
| 9 |
+
# The path can be either a local path or an HDFS path.
|
| 10 |
+
# For HDFS path, we provide utils to download it to DRAM and convert it to a local path.
|
| 11 |
+
train_files: ~/data/rlhf/gsm8k/train.parquet
|
| 12 |
+
|
| 13 |
+
# Validation parquet. Can be a list or a single file.
|
| 14 |
+
val_files: ~/data/rlhf/gsm8k/test.parquet
|
| 15 |
+
|
| 16 |
+
# Maximum number of training samples to be used.
|
| 17 |
+
# Set to -1 to use full dataset, otherwise, randomly
|
| 18 |
+
# select the specified number of samples from train dataset
|
| 19 |
+
train_max_samples: -1
|
| 20 |
+
|
| 21 |
+
# Maximum number of validation samples to be used.
|
| 22 |
+
# Set to -1 to use full dataset, otherwise, randomly
|
| 23 |
+
# select the specified number of samples from val dataset
|
| 24 |
+
val_max_samples: -1
|
| 25 |
+
|
| 26 |
+
# The field in the dataset where the prompt is located. Default is 'prompt'.
|
| 27 |
+
prompt_key: prompt
|
| 28 |
+
|
| 29 |
+
# The field used to select the reward function (if using different ones per example).
|
| 30 |
+
reward_fn_key: data_source
|
| 31 |
+
|
| 32 |
+
# Maximum prompt length. All prompts will be left-padded to this length.
|
| 33 |
+
# An error will be reported if the length is too long.
|
| 34 |
+
# oc.select: default val for rollout.prompt_length
|
| 35 |
+
max_prompt_length: 512
|
| 36 |
+
|
| 37 |
+
# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length.
|
| 38 |
+
# oc.select: default val for rollout.response_length
|
| 39 |
+
max_response_length: 512
|
| 40 |
+
|
| 41 |
+
# Batch size sampled for one training iteration of different RL algorithms.
|
| 42 |
+
train_batch_size: 1024
|
| 43 |
+
|
| 44 |
+
# Batch size used during validation. Can be null.
|
| 45 |
+
val_batch_size: null
|
| 46 |
+
|
| 47 |
+
# use tool config to calculate true prompt length
|
| 48 |
+
tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, null}
|
| 49 |
+
|
| 50 |
+
# Whether to return the original input_ids without adding chat template.
|
| 51 |
+
# This is used when the reward model's chat template differs from the policy.
|
| 52 |
+
# If using a model-based RM with different templates, this should be True.
|
| 53 |
+
return_raw_input_ids: False
|
| 54 |
+
|
| 55 |
+
# Whether to return the original chat (prompt) without applying chat template.
|
| 56 |
+
return_raw_chat: True
|
| 57 |
+
|
| 58 |
+
# Whether to return the full prompt with chat template.
|
| 59 |
+
return_full_prompt: False
|
| 60 |
+
|
| 61 |
+
# Whether to shuffle the data in the dataloader.
|
| 62 |
+
shuffle: True
|
| 63 |
+
|
| 64 |
+
# Seed to use when shuffling the data
|
| 65 |
+
seed: null
|
| 66 |
+
|
| 67 |
+
# num dataloader workers
|
| 68 |
+
dataloader_num_workers: 8
|
| 69 |
+
|
| 70 |
+
# image patch size
|
| 71 |
+
image_patch_size: 14
|
| 72 |
+
|
| 73 |
+
# Whether to shuffle the validation set.
|
| 74 |
+
validation_shuffle: False
|
| 75 |
+
|
| 76 |
+
# Whether to filter overlong prompts.
|
| 77 |
+
filter_overlong_prompts: False
|
| 78 |
+
|
| 79 |
+
# Number of workers for filtering overlong prompts.
|
| 80 |
+
# For large-scale datasets, filtering can be time-consuming.
|
| 81 |
+
# Use multiprocessing to speed up. Default is 1.
|
| 82 |
+
filter_overlong_prompts_workers: 1
|
| 83 |
+
|
| 84 |
+
# Truncate the input_ids or prompt if they exceed max_prompt_length.
|
| 85 |
+
# Options: 'error', 'left', 'right', 'middle'. Default is 'error'.
|
| 86 |
+
truncation: error
|
| 87 |
+
|
| 88 |
+
# The field in the multi-modal dataset where the image is located. Default is 'images'.
|
| 89 |
+
image_key: images
|
| 90 |
+
|
| 91 |
+
# The field in the multi-modal dataset where the video is located.
|
| 92 |
+
video_key: videos
|
| 93 |
+
|
| 94 |
+
# If the remote tokenizer has a Python file, this flag determines whether to allow using it.
|
| 95 |
+
trust_remote_code: False
|
| 96 |
+
|
| 97 |
+
# Optional: specify a custom dataset class path and name if overriding default loading behavior.
|
| 98 |
+
custom_cls:
|
| 99 |
+
|
| 100 |
+
# The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used.
|
| 101 |
+
path: null
|
| 102 |
+
|
| 103 |
+
# The name of the dataset class within the specified file.
|
| 104 |
+
name: null
|
| 105 |
+
|
| 106 |
+
# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs.
|
| 107 |
+
return_multi_modal_inputs: True
|
| 108 |
+
|
| 109 |
+
# settings related to data sampler
|
| 110 |
+
sampler:
|
| 111 |
+
|
| 112 |
+
# the path to the module containing a curriculum class which implements the
|
| 113 |
+
# AbstractSampler interface
|
| 114 |
+
class_path: null
|
| 115 |
+
|
| 116 |
+
# the name of the curriculum class like `MySampler`
|
| 117 |
+
class_name: null
|
| 118 |
+
|
| 119 |
+
# Data generation configuration for augmenting the dataset.
|
| 120 |
+
datagen:
|
| 121 |
+
|
| 122 |
+
# The path to the file containing your customized data generation class.
|
| 123 |
+
# E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset'
|
| 124 |
+
path: null
|
| 125 |
+
|
| 126 |
+
# The class name of the data generation class within the specified file.
|
| 127 |
+
# E.g. 'MockDataGenerator'
|
| 128 |
+
name: null
|
| 129 |
+
|
| 130 |
+
# Additional kwargs when calling tokenizer.apply_chat_template
|
| 131 |
+
apply_chat_template_kwargs: {}
|
ICL/DAPO/verl-recipe/rep_exp/config/engine/fsdp.yaml
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Target class for this configuration
|
| 2 |
+
_target_: verl.workers.config.FSDPEngineConfig
|
| 3 |
+
|
| 4 |
+
# policy for wrapping the model
|
| 5 |
+
wrap_policy:
|
| 6 |
+
|
| 7 |
+
# Minimum number of parameters to trigger wrapping a layer with FSDP
|
| 8 |
+
min_num_params: 0
|
| 9 |
+
|
| 10 |
+
# Whether to offload model parameters to CPU (trades speed for memory)
|
| 11 |
+
# Note that this differs from the offload_policy in FSDP
|
| 12 |
+
param_offload: false
|
| 13 |
+
|
| 14 |
+
# Whether to offload optimizer state to CPU
|
| 15 |
+
# Note that this differs from the offload_policy in FSDP
|
| 16 |
+
optimizer_offload: false
|
| 17 |
+
|
| 18 |
+
# Only for FSDP2: offload param/grad/optimizer during train
|
| 19 |
+
offload_policy: false
|
| 20 |
+
|
| 21 |
+
# Only for FSDP2: Reshard after forward pass to reduce memory footprint
|
| 22 |
+
reshard_after_forward: true
|
| 23 |
+
|
| 24 |
+
# Number of GPUs in each FSDP shard group; -1 means auto
|
| 25 |
+
fsdp_size: -1
|
| 26 |
+
|
| 27 |
+
# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
|
| 28 |
+
# before the current forward computation.
|
| 29 |
+
forward_prefetch: False
|
| 30 |
+
|
| 31 |
+
# model dtype of fsdp
|
| 32 |
+
model_dtype: fp32
|
| 33 |
+
|
| 34 |
+
# Whether to use original parameters in fsdp. Only available in fsdp1
|
| 35 |
+
use_orig_params: false
|
| 36 |
+
|
| 37 |
+
# ulysses sequence parallel size
|
| 38 |
+
ulysses_sequence_parallel_size: 1
|
| 39 |
+
|
| 40 |
+
# Whether to use entropy_from_logits_with_chunking in fsdp.
|
| 41 |
+
entropy_from_logits_with_chunking: false
|
| 42 |
+
|
| 43 |
+
# Whether to use torch compile in fsdp.
|
| 44 |
+
use_torch_compile: true
|
| 45 |
+
|
| 46 |
+
# Whether to use entropy checkpointing in fsdp.
|
| 47 |
+
entropy_checkpointing: false
|
| 48 |
+
|
| 49 |
+
# Whether to use forward only in fsdp.
|
| 50 |
+
forward_only: false
|
| 51 |
+
|
| 52 |
+
# fsdp or fsdp2
|
| 53 |
+
strategy: fsdp
|
| 54 |
+
|
| 55 |
+
# Mixed precision training param dtype
|
| 56 |
+
dtype: bfloat16 # ["bfloat16", "float16"]
|