Note: This repo differs from Qwen's PRM. Our PRM is trained from Qwen2.5-Math-7B, whereas Qwen's PRM is based on Qwen2.5-Math-7B-Instruct.
PURE's PRM based on Qwen2.5-Math-7B
Introduction
Our PRM is used to fine-tune LLMs for better mathematical reasoning. See our PURE GitHub repo for more details. It is obtained by fine-tuning Qwen2.5-Math-7B on the training set of the open-source PRM800K dataset. We chose Qwen2.5-Math-7B instead of Qwen2.5-Math-7B-Instruct so that the base model stays consistent with our baselines. We treat the original 1 and 0 labels in PRM800K as positive labels and the -1 labels as negative ones. To eliminate test-data contamination, we also remove PRM800K training samples whose math query also appears in the MATH test set.
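To make the label mapping and decontamination concrete, here is a minimal Python sketch. The record fields question and ratings, the helpers to_binary_label and decontaminate, and the use of exact string matching (after stripping whitespace) to detect "the same math query" are all our assumptions for illustration; the actual preprocessing in the PURE repo may differ.

```python
from typing import Dict, List


def to_binary_label(rating: int) -> int:
    """Map a PRM800K step rating to a binary PRM label: 1 and 0 -> positive (1), -1 -> negative (0)."""
    if rating not in (-1, 0, 1):
        raise ValueError(f"unexpected rating: {rating}")
    return 0 if rating == -1 else 1


def decontaminate(samples: List[Dict], test_queries: List[str]) -> List[Dict]:
    """Drop training samples whose question also appears in the test set.

    Assumption: "same query" means an exact match after stripping surrounding whitespace.
    """
    banned = {q.strip() for q in test_queries}
    return [s for s in samples if s["question"].strip() not in banned]


# Toy, made-up records (hypothetical layout: one rating per reasoning step)
train = [
    {"question": "What is 1+1?", "ratings": [1, 0, -1]},
    {"question": "Compute 2*3.", "ratings": [1, 1]},
]
test_queries = ["What is 1+1?"]

kept = decontaminate(train, test_queries)
labels = [[to_binary_label(r) for r in s["ratings"]] for s in kept]
print(len(kept), labels)  # 1 [[1, 1]]
```

Label 1 is used for the positive class here because the Quick Start code reads the positive probability from index 1 of the two-class output.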
Requirements
- transformers>=4.40.0 for Qwen2.5-Math models. The latest version is recommended.
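If you want to check the installed version programmatically, the following is a small stdlib-only sketch. The helpers version_tuple and meets_minimum are hypothetical names of ours, and the simplified parser ignores pre-release ordering (e.g. "4.40.0rc1" compares equal to "4.40.0"); packaging.version is the more robust choice when it is available.

```python
from importlib import metadata

MIN_TRANSFORMERS = "4.40.0"


def version_tuple(version: str, width: int = 3) -> tuple:
    """Turn '4.41.0.dev0' into (4, 41, 0): keep leading integer parts, pad with zeros."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:  # stop at suffixes such as 'dev0'
            break
        parts.append(int(digits))
    parts += [0] * (width - len(parts))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = MIN_TRANSFORMERS) -> bool:
    return version_tuple(installed) >= version_tuple(minimum)


try:
    installed = metadata.version("transformers")
    print(f"transformers {installed}: meets >= {MIN_TRANSFORMERS}? {meets_minimum(installed)}")
except metadata.PackageNotFoundError:
    print("transformers is not installed")
```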
Quick Start
PURE's PRM is a process reward model: it is typically used to provide feedback on the quality of the reasoning and its intermediate steps, not to generate text.
Prerequisites
- Step Separation: We recommend using double line breaks ("\n\n") to separate individual steps within the solution.
- Reward Computation: After each step, we insert the token "\n". For reward calculation, we take the model's two-class (negative/positive) probabilities at the position of this token and subtract the negative probability from the positive probability, which yields a reward between -1 and 1. Steps with reward > 0 are regarded as correct; all others are regarded as incorrect.
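The two prerequisites above can be sketched in plain Python without loading the model. The helper names split_steps, step_reward and judge are ours, and the (negative, positive) logits below are made up; in practice they come from the PRM's token-classification head at each separator position.

```python
import math
from typing import List, Sequence


def split_steps(solution: str) -> List[str]:
    """Split a solution on double line breaks and strip trailing whitespace from each step."""
    return [s.rstrip() for s in solution.split("\n\n")]


def step_reward(logit_neg: float, logit_pos: float) -> float:
    """Two-class softmax, then P(positive) - P(negative); the result lies in [-1, 1]."""
    m = max(logit_neg, logit_pos)  # subtract the max for numerical stability
    e_neg = math.exp(logit_neg - m)
    e_pos = math.exp(logit_pos - m)
    return (e_pos - e_neg) / (e_pos + e_neg)


def judge(rewards: Sequence[float]) -> List[bool]:
    """A step counts as correct only if its reward is strictly positive."""
    return [r > 0 for r in rewards]


solution = "Step one.\n\nStep two.  \n\nStep three."
steps = split_steps(solution)
# Hypothetical (negative, positive) logits, one pair per step
toy_logits = [(-1.0, 2.0), (0.5, 0.5), (1.5, -0.5)]
rewards = [step_reward(n, p) for n, p in toy_logits]
print(steps)
print([round(r, 3) for r in rewards], judge(rewards))
```

For two classes, P(positive) - P(negative) equals tanh((z_pos - z_neg) / 2), so the reward depends only on the logit gap, and a step with equal logits sits exactly at the threshold 0 and is counted as incorrect.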
Hugging Face Transformers
- Here is a code snippet showing how to use our PRM with transformers:
```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


def make_step_rewards(logits, token_masks):
    all_scores_res = []
    for sample, token_mask in zip(logits, token_masks):
        # sample: (seq_len, num_labels)
        probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
        process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
        # Softmax-weighted sum to approximate the min; highly recommended for BoN eval and LLM fine-tuning
        # weight = torch.softmax(
        #     -process_reward / 0.1,
        #     dim=-1,
        # )
        # process_reward = weight * process_reward
        all_scores_res.append(process_reward.cpu().tolist())
    return all_scores_res


model_name = "jinachris/PURE-PRM-7B"
device = "auto"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

question = "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?"
steps = [
    "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
    "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
    "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
    "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30})."
]

# The separator token is appended after every step; the PRM scores that position
step_separator = "\n"
step_separator_token = tokenizer(
    step_separator,
    add_special_tokens=False,
    return_tensors='pt',
)['input_ids']

input_ids = tokenizer(
    question,
    add_special_tokens=False,
    return_tensors='pt',
)['input_ids']

score_ids = []  # positions of the separator token that closes each step
for step in steps:
    step_ids = tokenizer(
        step,
        add_special_tokens=False,
        return_tensors='pt',
    )['input_ids']
    input_ids = torch.cat(
        [input_ids, step_ids, step_separator_token],
        dim=-1,
    )
    score_ids.append(input_ids.size(-1) - 1)

input_ids = input_ids.to(model.device)
token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
token_masks[0, score_ids] = True
assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)

with torch.no_grad():  # inference only, no gradients needed
    logits = model(input_ids).logits
step_reward = make_step_rewards(logits, token_masks)
print(step_reward)  # [[0.796875, 0.185546875, -0.0625, 0.078125]]

# For BoN eval,
# uncomment the weighted sum part in `make_step_rewards`,
# then sum the rewards to get the final score (outcome reward):
# torch.tensor(step_reward).sum(dim=-1)
```
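The commented-out weighting in make_step_rewards, followed by a sum over steps, is a softmax-weighted average that, as the code comment says, approximates the minimum step reward. The plain-Python sketch below (soft_min is our own hypothetical helper) reproduces that computation on the step rewards printed above and shows how the temperature (0.1 in the snippet) moves the result between the mean and the min.

```python
import math
from typing import Sequence


def soft_min(rewards: Sequence[float], temperature: float = 0.1) -> float:
    """sum_i softmax(-r / T)_i * r_i, which approaches min(rewards) as T -> 0."""
    scaled = [-r / temperature for r in rewards]
    m = max(scaled)  # stabilise the exponentials
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    return sum(w / total * r for w, r in zip(weights, rewards))


rewards = [0.796875, 0.185546875, -0.0625, 0.078125]  # step rewards from the snippet above
print("plain sum:", sum(rewards))
print("min      :", min(rewards))
for t in (1.0, 0.1, 0.01):
    print(f"T={t}: {soft_min(rewards, t):.4f}")
```

A plain sum rewards long solutions with many confident steps, while the soft minimum lets a single bad step dominate the outcome score, which is why the weighted form is recommended for BoN selection.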
- Additionally, we share the code for Best-of-N (BoN) evaluation on RLHFlow's data:
```python
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer

ds_names = ["GSM8K", "MATH500"]
ds = [
    load_dataset(
        f"RLHFlow/Deepseek-{ds_name}-Test"
    )['test'] for ds_name in ds_names
]


def make_step_rewards(logits, token_masks):
    all_scores_res = []
    for sample, token_mask in zip(logits, token_masks):
        # sample: (seq_len, num_labels)
        probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
        process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
        # Softmax-weighted sum to approximate the min; highly recommended for BoN eval and LLM fine-tuning
        weight = torch.softmax(
            -process_reward / 0.1,
            dim=-1,
        )
        process_reward = weight * process_reward
        all_scores_res.append(process_reward.cpu().tolist())
    return all_scores_res


model_name = "jinachris/PURE-PRM-7B"
device = "auto"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

step_separator = "\n"
step_separator_token = tokenizer(
    step_separator,
    add_special_tokens=False,
    return_tensors='pt',
)['input_ids']

for ds_item, ds_name in zip(ds, ds_names):
    # sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
    correct = 0
    total = 0
    for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
        example = ds_item[idx]  # fetch one row instead of a whole column per access
        question = example['prompt']
        answers = example['answers']
        labels = example['label']
        outcome_scores = []
        question_ids = tokenizer(
            question,
            add_special_tokens=False,
            return_tensors='pt',
        )['input_ids']
        for answer in tqdm(answers, desc="Processing answers"):
            # Candidate answers separate their steps with double line breaks
            steps = [i.rstrip() for i in answer.split("\n\n")]
            input_ids = question_ids.clone()
            score_ids = []
            for step in steps:
                step_ids = tokenizer(
                    step,
                    add_special_tokens=False,
                    return_tensors='pt',
                )['input_ids']
                input_ids = torch.cat(
                    [input_ids, step_ids, step_separator_token],
                    dim=-1,
                )
                score_ids.append(input_ids.size(-1) - 1)
            input_ids = input_ids.to(model.device, dtype=torch.long)
            token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
            token_masks[0, score_ids] = True
            assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)
            with torch.no_grad():
                logits = model(input_ids).logits
            step_reward = make_step_rewards(logits, token_masks)
            # Summing the softmax-weighted step rewards gives the outcome reward
            outcome_reward = torch.tensor(step_reward).sum(dim=-1)
            # TODO: batch input & output
            outcome_scores.append(outcome_reward.item())
        # Best-of-N: keep the highest-scoring answer and check its label
        best_idx = np.argmax(outcome_scores)
        if labels[best_idx] == 1:
            correct += 1
        total += 1
    print(f"Accuracy on {ds_name}: {correct / total}")
```
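The selection and accuracy logic at the end of the BoN script does not depend on the model, so it can be checked in isolation. The sketch below is a plain-Python restatement with hypothetical helper names (best_of_n_correct, bon_accuracy) and made-up scores; like np.argmax, Python's max returns the first index on ties.

```python
from typing import List, Sequence


def best_of_n_correct(outcome_scores: Sequence[float], labels: Sequence[int]) -> bool:
    """Pick the highest-scoring candidate (first one on ties) and report whether its label is 1."""
    best_idx = max(range(len(outcome_scores)), key=lambda i: outcome_scores[i])
    return labels[best_idx] == 1


def bon_accuracy(all_scores: List[Sequence[float]], all_labels: List[Sequence[int]]) -> float:
    """Fraction of questions whose selected answer is labelled correct."""
    hits = sum(best_of_n_correct(s, l) for s, l in zip(all_scores, all_labels))
    return hits / len(all_scores)


# Made-up outcome scores for two questions with three candidate answers each
scores = [[0.2, 0.9, -0.1], [0.4, 0.4, 0.1]]
labels = [[0, 1, 1], [0, 1, 1]]
print(bon_accuracy(scores, labels))  # 0.5
```

In the second toy question the tie is broken toward the first candidate, which is labelled incorrect, so only one of the two questions counts as a hit.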
Citation
If you find our work useful, we would appreciate it if you could cite it:
```bibtex
@article{cheng2025stop,
  title={Stop Summation: Min-Form Credit Assignment Is All Process Reward Model Needs for Reasoning},
  author={Cheng, Jie and Qiao, Ruixi and Li, Lijun and Guo, Chao and Wang, Junle and Xiong, Gang and Lv, Yisheng and Wang, Fei-Yue},
  journal={arXiv preprint arXiv:2504.15275},
  year={2025}
}
```