import re
from typing import Optional, Tuple


def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Extracts and normalizes the final answer from a raw model response.

    Args:
        solution_str: Raw response string from the model

    Returns:
        Tuple of (normalized answer or None, processed response string)
    """
    # Split the response to isolate the assistant's output
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1]
    elif "[/INST]" in solution_str:
        processed_str = solution_str.split("[/INST]", 1)[1]
    else:
        print("[Error] Failed to locate model response header")
        return None, solution_str

    # Extract the final answer using XML-style tags
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL))
    if not matches:
        print("[Error] No valid answer tags found")
        return None, processed_str
    final_answer = matches[-1].group(1).strip()

    # Normalize the answer across several expected formats:
    # option letters followed by ':' or '.', yes/no answers, or bare capital letters
    format1_match = re.findall(r'(?:^|\s)([A-Za-z])(?=[:.])', final_answer, re.MULTILINE)
    format2_match = re.findall(r'(yes|no)', final_answer.lower())
    format3_match = re.findall(r'[A-Z]', final_answer.upper())
    if format1_match:
        return ','.join(format1_match), processed_str
    elif format2_match:
        return format2_match[-1].strip(), processed_str
    elif format3_match:
        return ','.join(format3_match), processed_str
    return final_answer.replace(", ", ","), processed_str
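
# Illustrative example (hypothetical response string, not from any dataset):
#   extract_solution("Assistant: <think>...</think><answer>A: first option</answer>")
#   returns ("A", " <think>...</think><answer>A: first option</answer>")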


def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Check required tags: each must appear exactly once
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1),
    }
    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)
        print(f"  {tag_str}: count={count}, position={pos}")
        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order
    if (positions['think_start'] > positions['think_end'] or
            positions['think_end'] > positions['answer_start'] or
            positions['answer_start'] > positions['answer_end']):
        print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed
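
# Illustrative example: one correctly ordered pair of each tag passes, e.g.
#   validate_response_structure("<think>...</think><answer>A</answer>")  -> True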


def parse_ground_truth_text_format(ground_truth: str) -> str:
    """Normalizes a ground-truth string into the same format that
    extract_solution produces for model answers."""
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, ground_truth, re.DOTALL))
    if ":" in ground_truth and "[{" in ground_truth:
        # Structured (JSON-like) ground truth is compared verbatim
        return ground_truth
    elif ":" in ground_truth:
        # "A: ..." style: keep only the option letter
        format1_match = re.search(r"([a-zA-Z]):", ground_truth)
        if format1_match:
            return format1_match.group(1).strip()
        return ground_truth
    elif matches:
        # <answer>...</answer> style: keep the last tagged answer
        return matches[-1].group(1).strip()
    return ground_truth
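
# Illustrative examples (hypothetical ground-truth strings):
#   parse_ground_truth_text_format("B: 42")                -> "B"
#   parse_ground_truth_text_format("<answer>yes</answer>") -> "yes"
#   parse_ground_truth_text_format('[{"id": 1}]: x')       -> returned verbatim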


def compute_score(solution_str: str, ground_truth: str, method: str = 'strict',
                  format_reward: int = 1, answer_reward: float = 1.0):
    """Computes comprehensive score for model response.

    Args:
        solution_str: Raw model response string
        ground_truth: Ground truth answer string
        method: Extraction method, either 'strict' or 'flexible' (currently unused)
        format_reward: Points awarded/deducted for format correctness
        answer_reward: Points awarded/deducted for answer correctness
            (currently unused; the answer score values are hardcoded below)

    Returns:
        Total score (sum of format and answer scores)
    """
    print("\n" + "=" * 80)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    ground_truth = parse_ground_truth_text_format(ground_truth)

    # Extract the model's answer
    answer_text, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response]\n{answer_text}")

    # Validate response structure
    format_correct = validate_response_structure(processed_str)
    format_score = format_reward if format_correct else -abs(format_reward)
    print(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"  Format score: {format_score}")

    # Validate answer content (case-insensitive exact match)
    answer_score = 0
    if format_correct and answer_text:
        print("\n[Content Validation]")
        print(f"  Expected: {ground_truth}")
        print(f"  Predicted: {answer_text}")
        if answer_text.casefold() == ground_truth.casefold():
            answer_score = 2
            print("  Content validation: FULL MATCH")
        else:
            answer_score = -1.5
            print("  Content validation: MISMATCH")
    else:
        answer_score = -2
        print("\n[Content Validation] Skipped due to format errors or missing answer")

    total_score = format_score + answer_score
    print("\n" + "-" * 80)
    print(" Final Score ".center(80, '-'))
    print(f"  Format: {format_score}")
    print(f"  Answer: {answer_score}")
    print(f"  Total: {total_score}")
    print("=" * 80 + "\n")
    return total_score
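
if __name__ == "__main__":
    # Minimal smoke test with a hypothetical, well-formed response
    # (the sample strings are illustrative, not from any dataset).
    sample_response = (
        "Assistant: <think>Option A matches the premise.</think>"
        "<answer>A: the first option</answer>"
    )
    total = compute_score(sample_response, ground_truth="A: the first option")
    print(f"Smoke test total score: {total}")  # expected: 1 (format) + 2 (answer) = 3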