import re
from typing import Dict, Tuple, Optional
def extract_solution(solution_str):
    """Extracts the predicted action and action argument from the raw model output."""
    # Split the response to isolate the assistant's answer section
    if "</think>" in solution_str:
        processed_str = solution_str.split("</think>")[-1]
    elif "<answer>" in solution_str:
        processed_str = solution_str.split("<answer>")[-1]
    elif "{" in solution_str:
        processed_str = solution_str.split("{")[-1]
    else:
        print("[Error] Failed to locate model response header")
        return None, None, solution_str

    # Parse the JSON-like answer, accepting either double- or single-quoted keys
    if '"action":' in processed_str and '"action_argument":' in processed_str:
        action = processed_str.split('"action":')[-1].split('"action_argument":')[0].replace('"', '').replace(',', '').strip()
        action_argument = processed_str.split('"action_argument":')[-1].split('}')[0].replace('"', '').replace(',', '').strip()
        return action, action_argument, processed_str
    elif "'action':" in processed_str and "'action_argument':" in processed_str:
        action = processed_str.split("'action':")[-1].split("'action_argument':")[0].replace("'", '').replace(',', '').strip()
        action_argument = processed_str.split("'action_argument':")[-1].split('}')[0].replace("'", '').strip()
        return action, action_argument, processed_str
    else:
        print("[Error] No valid answer tags found")
        return None, None, processed_str
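
# Illustrative example (not from the original file; the sample strings are hypothetical
# and only demonstrate the format this parser expects):
#   extract_solution('<think>reasoning</think>{"action": "click", "action_argument": "login_button"}')
#   -> ('click', 'login_button', '{"action": "click", "action_argument": "login_button"}')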
def parse_ground_truth_text_format(ground_truth):
    """Extracts the reference action and argument from <answer_action>/<answer_argument> tags.

    Falls back to returning the raw ground-truth string for both fields if the tags are absent.
    """
    answer_pattern = r'<answer_action>(.*?)</answer_action>'
    argument_pattern = r'<answer_argument>(.*?)</answer_argument>'
    action_matches = list(re.finditer(answer_pattern, ground_truth, re.DOTALL))
    argument_matches = list(re.finditer(argument_pattern, ground_truth, re.DOTALL))
    if action_matches and argument_matches:
        return action_matches[-1].group(1).strip(), argument_matches[-1].group(1).strip()
    else:
        return ground_truth, ground_truth
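
# Illustrative example (hypothetical sample, shown only to document the expected tag format):
#   parse_ground_truth_text_format(
#       "<answer_action>click</answer_action><answer_argument>login_button</answer_argument>")
#   -> ('click', 'login_button')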
def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Check required tags
    tags = {
        # 'think_start': ('<think>', 1),
        # 'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }
    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)
        print(f"  {tag_str}: count={count}, position={pos}")
        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order
    # positions['think_start'] > positions['think_end'] or positions['think_end'] > positions['answer_start'] or
    if positions['answer_start'] > positions['answer_end']:
        print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed
def compute_score(solution_str: str, ground_truth: str, method='strict', format_reward: int = 1, action_reward: float = 1.0, argument_reward: float = 2.0):
    """Computes comprehensive score for model response.

    Args:
        solution_str: Raw model response string
        ground_truth: Ground-truth string containing <answer_action>/<answer_argument> tags
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_reward: Points awarded/deducted for format correctness
        action_reward: Points awarded/deducted for action correctness
        argument_reward: Points awarded/deducted for argument correctness

    Returns:
        Total score (sum of action and argument rewards)
    """
    print("\n" + "="*80)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    action_ground_truth, argument_ground_truth = parse_ground_truth_text_format(ground_truth)

    # Extract model answer
    action, action_argument, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response Action]\n{action}")
    print(f"\n[Processed Model Response Action_Argument]\n{action_argument}")

    # Validate response structure
    format_correct = validate_response_structure(processed_str)
    format_score = format_reward if format_correct else -abs(format_reward)
    print(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"  Format score: {format_score}")

    # Validate answer content; scores default to 0 when nothing could be extracted
    action_score = 0
    argument_score = 0
    if action:
        print("\n[Content Validation]")
        print(f"  Expected Action: {action_ground_truth}")
        print(f"  Predicted Action: {action}")
        if action.casefold() == action_ground_truth.casefold():
            action_score = action_reward
            print("  Action validation: FULL MATCH")
        else:
            action_score = -action_reward
            print("  Action validation: MISMATCH")
    if action_argument:
        print("\n[Content Validation]")
        print(f"  Expected Argument: {argument_ground_truth}")
        print(f"  Predicted Argument: {action_argument}")
        if action_argument.casefold() == argument_ground_truth.casefold():
            argument_score = argument_reward
            print("  Argument validation: FULL MATCH")
        else:
            argument_score = -argument_reward
            print("  Argument validation: MISMATCH")

    total_score = action_score + argument_score
    print("\n" + "-"*80)
    print(" Final Score ".center(80, '-'))
    print(f"  Action: {action_score}")
    print(f"  Argument: {argument_score}")
    print(f"  Total: {total_score}")
    print("="*80 + "\n")

    return total_score
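

# Minimal smoke-test sketch (not part of the original reward script). The sample response
# and ground-truth strings below are hypothetical; they only illustrate the formats the
# functions above expect and let the scoring path be exercised end to end.
if __name__ == "__main__":
    sample_response = (
        "<think>The user wants to log in.</think>"
        '<answer>{"action": "click", "action_argument": "login_button"}</answer>'
    )
    sample_ground_truth = (
        "<answer_action>click</answer_action>"
        "<answer_argument>login_button</answer_argument>"
    )
    score = compute_score(sample_response, sample_ground_truth)
    print(f"Smoke-test total score: {score}")  # expected: 1.0 + 2.0 = 3.0 with default rewards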