|
|
import re |
|
|
import torch |
|
|
# NOTE(review): emptying the CUDA cache is a module-level side effect that runs
# at import time -- presumably intended to free GPU memory before training
# starts; confirm this belongs here rather than in the training entry point.
torch.cuda.empty_cache()
|
|
from typing import List |
|
|
from copy import deepcopy |
|
|
|
|
|
from swift.plugin import ORM, orms |
|
|
from swift.utils import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
""" |
|
|
Step 1: Define a Reward Class |
|
|
Implement your custom reward calculation logic within the __call__ method. |
|
|
The method accepts the model's output completions and dataset columns (passed as kwargs) as input parameters. |
|
|
|
|
|
Step 2: Register the Reward Class in orms |
|
|
For example: |
|
|
    orms['external_math_acc'] = MathAccuracy
|
|
|
|
|
Step 3: Configure the Arguments |
|
|
Use the following arguments when running the script: |
|
|
    --external_plugins /path/to/plugin.py --reward_funcs external_math_acc
|
|
""" |
|
|
|
|
|
def count_xml(text) -> float:
    """
    Score a response for well-formed reasoning tags.

    Awards 0.5 for exactly one opening ``<think>`` tag and 0.5 for exactly
    one closing ``</think>`` tag, so a well-formed response scores 1.0.

    Args:
        text: Input text

    Returns:
        Score between 0.0 and 1.0 based on tag presence
    """
    score = 0.0
    for tag in ("<think>", "</think>"):
        # Exactly one occurrence is required; zero or duplicates earn nothing.
        if text.count(tag) == 1:
            score += 0.5
    return score
|
|
|
|
|
def extract_xml_answer(text: str) -> str:
    """
    Extract the answer segment that follows the first ``</think>`` tag.

    Note: with multiple closing tags, ``split`` yields the text between the
    first and second tag (original behavior, preserved here).

    Args:
        text: Input text with XML tags

    Returns:
        Extracted answer text, stripped of surrounding whitespace, or an
        empty string when no ``</think>`` tag is present.
    """
    try:
        answer = text.split("</think>")[1]
    # BUGFIX: the original used a bare ``except:``, which also swallows
    # KeyboardInterrupt/SystemExit. Only IndexError (no closing tag) is
    # the expected failure here.
    except IndexError:
        return ""
    return answer.strip()
|
|
|
|
|
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward each completion for proper reasoning-tag usage.

    Args:
        completions: Model completions

    Returns:
        List of reward scores, one per completion
    """
    # Delegate the per-completion scoring to count_xml.
    return [count_xml(completion) for completion in completions]
|
|
|
|
|
def int_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward function that checks if responses contain valid direction tokens.

    A response earns 1.0 when, after stripping all whitespace, its answer
    section is a non-empty concatenation of the tokens ``<|up|>``,
    ``<|down|>``, ``<|right|>`` and ``<|left|>``; otherwise it earns 0.0.

    Args:
        completions: Model completions

    Returns:
        List of reward scores, one per completion
    """
    # Compile once per call instead of re-parsing both patterns for every
    # completion inside the validity check.
    whitespace_re = re.compile(r'\s+')
    token_re = re.compile(r'<\|(?:up|down|right|left)\|>')

    def is_valid_sequence(seq: str) -> bool:
        seq_no_whitespace = whitespace_re.sub('', seq)
        if not seq_no_whitespace:
            return False
        # The pattern can only match the four allowed tokens, so verifying
        # that the matches reconstruct the full string is sufficient; the
        # original's extra membership test against an allowed set was
        # redundant and has been dropped.
        found_tokens = token_re.findall(seq_no_whitespace)
        return ''.join(found_tokens) == seq_no_whitespace

    extracted_responses = [extract_xml_answer(r) for r in completions]
    return [1.0 if is_valid_sequence(r) else 0.0 for r in extracted_responses]
|
|
|
|
|
def count_turns(steps):
    """
    Parse move tokens and count direction changes.

    Args:
        steps: String containing move tokens of the form ``<|direction|>``

    Returns:
        Tuple of (list of move names, number of adjacent moves that differ)
    """
    moves = re.findall(r"<\|(.*?)\|>", steps)
    turn_count = 0
    # A "turn" is any position where the move differs from its predecessor.
    for previous, current in zip(moves, moves[1:]):
        if current != previous:
            turn_count += 1
    return moves, turn_count
|
|
|
|
|
def correctness_reward_func(completions, answer, **kwargs) -> List[float]:
    """
    Reward function that scores how closely each response matches the
    ground-truth move sequence.

    An exact match earns ``len(steps) * 2 * (turns + 1)``. Otherwise the
    reward is the length of the longest common move prefix, scaled by the
    number of direction changes within that prefix plus one.

    Args:
        completions: Model completions
        answer: Ground truth answers

    Returns:
        List of reward scores, one per completion
    """
    rewards = []
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]
    logger.debug('-' * 20)
    # Guard the sample logging so an empty batch cannot raise IndexError.
    if responses and answer:
        logger.debug(f"\nAnswer:\n{answer[0]}")
        logger.debug(f"\nResponse:\n{responses[0]}")
        logger.debug(f"\nExtracted:\n{extracted_responses[0]}")
    for r, a in zip(extracted_responses, answer):
        r_steps, r_turns = count_turns(r)
        a_steps, a_turns = count_turns(a)
        if r == a:
            # Full match: double credit per step, boosted by the turn count.
            reward = len(r_steps) * 2 * (r_turns + 1)
        else:
            # Partial credit: length of the longest common prefix of moves.
            k = 0
            for r_s, a_s in zip(r_steps, a_steps):
                if r_s != a_s:
                    break
                k += 1
            prefix = r_steps[:k]
            # BUGFIX: the original re-joined the bare move names (e.g.
            # "upup") and re-parsed them with count_turns, whose regex only
            # matches <|...|>-wrapped tokens -- so the prefix turn bonus was
            # always 0. Count direction changes on the prefix list directly,
            # mirroring the full-match branch.
            turns = sum(1 for prev, cur in zip(prefix, prefix[1:]) if cur != prev)
            reward = k * 1 * (turns + 1)
        rewards.append(reward)
    return rewards
|
|
|
|
|
class MazeReward(ORM):
    """Accuracy reward: scores completions against the maze solution."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # All scoring logic lives in correctness_reward_func.
        return correctness_reward_func(completions, solution)
|
|
|
|
|
class MazeFormat(ORM):
    """Format reward: checks answers are valid direction-token sequences."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # ``solution`` is accepted for interface parity but not used here.
        return int_reward_func(completions)
|
|
|
|
|
class Format(ORM):
    """Structural reward: checks proper <think> tag usage."""

    def __call__(self, completions, **kwargs) -> List[float]:
        return xmlcount_reward_func(completions)
|
|
|
|
|
# Register the reward classes in the global ORM registry; the keys are the
# names passed to ``--reward_funcs`` on the command line.
orms['external_r1v_acc'] = MazeReward
orms['external_r1v_format'] = MazeFormat
orms['format'] = Format