import re
from typing import Optional, Tuple


def extract_solution(solution_str: str) -> Tuple[Optional[str], Optional[str], str]:
    """Extracts the predicted action and action argument from the raw model response.

    Returns an (action, action_argument, processed_str) tuple; the first two entries
    are None when extraction fails.
    """
    # Keep only the text that follows the reasoning block or the answer header.
    if "</think>" in solution_str:
        processed_str = solution_str.split("</think>")[-1]
    elif "<answer>" in solution_str:
        processed_str = solution_str.split("<answer>")[-1]
    elif "{" in solution_str:
        processed_str = solution_str.split("{")[-1]
    else:
        print("[Error] Failed to locate model response header")
        return None, None, solution_str

    # Parse the JSON-like payload, accepting either double- or single-quoted keys.
    if '"action":' in processed_str and '"action_argument":' in processed_str:
        action = processed_str.split('"action":')[-1].split('"action_argument":')[0].replace('"', '').replace(',', '').strip()
        action_argument = processed_str.split('"action_argument":')[-1].split('}')[0].replace('"', '').replace(',', '').strip()
        return action, action_argument, processed_str
    elif "'action':" in processed_str and "'action_argument':" in processed_str:
        action = processed_str.split("'action':")[-1].split("'action_argument':")[0].replace("'", '').replace(',', '').strip()
        action_argument = processed_str.split("'action_argument':")[-1].split('}')[0].replace("'", '').strip()
        return action, action_argument, processed_str
    else:
        print("[Error] No valid answer tags found")
        return None, None, processed_str


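# Illustrative note (an assumption, not taken from the original pipeline): for a response
# such as '<think>reasoning</think><answer>{"action": "click", "action_argument": "ok"}</answer>',
# extract_solution returns ("click", "ok", '<answer>{"action": ...}</answer>').

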
def parse_ground_truth_text_format(ground_truth: str) -> Tuple[str, str]:
    """Parses the ground-truth action and argument out of their
    <answer_action>...</answer_action> and <answer_argument>...</answer_argument> tags,
    falling back to the raw string for both fields when the tags are absent."""
    answer_pattern = r'<answer_action>(.*?)</answer_action>'
    argument_pattern = r'<answer_argument>(.*?)</answer_argument>'

    action_matches = list(re.finditer(answer_pattern, ground_truth, re.DOTALL))
    argument_matches = list(re.finditer(argument_pattern, ground_truth, re.DOTALL))

    # Use the last occurrence of each tag in case the string contains several.
    if action_matches and argument_matches:
        return action_matches[-1].group(1).strip(), argument_matches[-1].group(1).strip()
    else:
        return ground_truth, ground_truth


def validate_response_structure(processed_str: str) -> bool:
    """Validates that the processed response contains exactly one <answer>...</answer> block.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    tags = {
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)

        print(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    if positions['answer_start'] > positions['answer_end']:
        print("  [Error] Incorrect tag order: expected <answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed


def compute_score(solution_str: str, ground_truth: str, method='strict', format_reward: int = 1, action_reward: float = 1.0, argument_reward: float = 2.0):
    """Computes the overall score for a model response.

    Args:
        solution_str: Raw model response string
        ground_truth: Ground-truth string containing <answer_action> and <answer_argument> tags
        method: The method used to extract the solution; choices are 'strict' and 'flexible'
        format_reward: Points awarded/deducted for format correctness (logged only)
        action_reward: Points awarded/deducted for action correctness
        argument_reward: Points awarded/deducted for argument correctness

    Returns:
        Total score (action score plus argument score; the format score is logged
        but not included in the total)
    """
    print("\n" + "="*80)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    action_ground_truth, argument_ground_truth = parse_ground_truth_text_format(ground_truth)

    action, action_argument, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response Action]\n{action}")
    print(f"\n[Processed Model Response Action_Argument]\n{action_argument}")

    # Format validation is logged for debugging but does not contribute to the total score.
    format_correct = validate_response_structure(processed_str)

    if format_correct:
        format_score = format_reward
    else:
        format_score = -abs(format_reward)
    print(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"  Format score: {format_score}")

    # Default both scores to zero so a failed extraction does not crash the summation below.
    action_score = 0
    argument_score = 0
    if action:
        print("\n[Content Validation]")
        print(f"  Expected Action: {action_ground_truth}")
        print(f"  Predicted Action: {action}")
        if action.casefold() == action_ground_truth.casefold():
            action_score = action_reward
            print("  Action validation: FULL MATCH")
        else:
            action_score = -action_reward
            print("  Action validation: MISMATCH")

    if action_argument:
        print("\n[Content Validation]")
        print(f"  Expected Argument: {argument_ground_truth}")
        print(f"  Predicted Argument: {action_argument}")
        if action_argument.casefold() == argument_ground_truth.casefold():
            argument_score = argument_reward
            print("  Argument validation: FULL MATCH")
        else:
            argument_score = -argument_reward
            print("  Argument validation: MISMATCH")

    total_score = action_score + argument_score
    print("\n" + "-"*80)
    print(" Final Score ".center(80, '-'))
    print(f"  Action: {action_score}")
    print(f"  Argument: {argument_score}")
    print(f"  Total: {total_score}")
    print("="*80 + "\n")

    return total_score
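

# Hypothetical smoke test (an illustration, not part of the original training pipeline):
# runs a made-up response and ground-truth pair through compute_score so the extraction,
# format-validation, and scoring paths can be exercised in isolation. A fully matching
# pair should score action_reward + argument_reward.
if __name__ == "__main__":
    example_response = (
        "<think>The user wants to open the settings page.</think>"
        '<answer>{"action": "click", "action_argument": "settings_button"}</answer>'
    )
    example_ground_truth = (
        "<answer_action>click</answer_action>"
        "<answer_argument>settings_button</answer_argument>"
    )
    print(f"Example total score: {compute_score(example_response, example_ground_truth)}")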