import re
from typing import Optional, Tuple


def extract_solution(solution_str: str) -> Tuple[Optional[str], Optional[str], str]:
    """Extracts the predicted action and action argument from the model response."""
    # Isolate the assistant output. The response is expected to contain a
    # reasoning block followed by an <answer> block holding a JSON object with
    # "action" and "action_argument" fields. The split markers below
    # ("<answer>" and "</think>") are assumed from the tag checks in
    # validate_response_structure.
    if "<answer>" in solution_str:
        processed_str = solution_str.split("<answer>")[-1]
    elif "</think>" in solution_str:
        processed_str = solution_str.split("</think>")[-1]
    elif "{" in solution_str:
        processed_str = solution_str.split("{")[-1]
    else:
        print("[Error] Failed to locate model response header")
        return None, None, solution_str

    # Handle both double-quoted and single-quoted JSON-style keys.
    if '"action":' in processed_str and '"action_argument":' in processed_str:
        action = (processed_str.split('"action":')[-1]
                  .split('"action_argument":')[0]
                  .replace('"', '').replace(',', '').strip())
        action_argument = (processed_str.split('"action_argument":')[-1]
                           .split('}')[0]
                           .replace('"', '').replace(',', '').strip())
        return action, action_argument, processed_str
    elif "'action':" in processed_str and "'action_argument':" in processed_str:
        action = (processed_str.split("'action':")[-1]
                  .split("'action_argument':")[0]
                  .replace("'", '').replace(',', '').strip())
        action_argument = (processed_str.split("'action_argument':")[-1]
                           .split('}')[0]
                           .replace("'", '').strip())
        return action, action_argument, processed_str
    else:
        print("[Error] No valid answer tags found")
        return None, None, processed_str


def parse_ground_truth_text_format(ground_truth: str) -> Tuple[str, str]:
    """Parses the expected action and argument from a tagged ground-truth string.

    The tag names below (<action> and <action_argument>) are assumed; adjust
    them to match the actual ground-truth format.
    """
    action_pattern = r'<action>(.*?)</action>'
    argument_pattern = r'<action_argument>(.*?)</action_argument>'
    action_matches = list(re.finditer(action_pattern, ground_truth, re.DOTALL))
    argument_matches = list(re.finditer(argument_pattern, ground_truth, re.DOTALL))
    if action_matches and argument_matches:
        return action_matches[-1].group(1).strip(), argument_matches[-1].group(1).strip()
    else:
        return ground_truth, ground_truth


def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Check required tags
    tags = {
        # 'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)
        print(f"  {tag_str}: count={count}, position={pos}")
        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order
    # positions['think_start'] > positions['think_end'] or
    if (positions['think_end'] > positions['answer_start']
            or positions['answer_start'] > positions['answer_end']):
        print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed


def compute_score(solution_str: str,
                  ground_truth: str,
                  method: str = 'strict',
                  format_reward: int = 1,
                  action_reward: float = 1.0,
                  argument_reward: float = 2.0):
    """Computes comprehensive score for model response.

    Args:
        solution_str: Raw model response string
        ground_truth: Ground-truth string containing the expected action and argument
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_reward: Points awarded/deducted for format correctness
        action_reward: Points awarded/deducted for action correctness
        argument_reward: Points awarded/deducted for argument correctness

    Returns:
        Total score (sum of format, action, and argument rewards)
    """
    print("\n" + "=" * 80)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    action_ground_truth, argument_ground_truth = parse_ground_truth_text_format(ground_truth)

    # Extract model answer
    action, action_argument, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response Action]\n{action}")
    print(f"\n[Processed Model Response Action_Argument]\n{action_argument}")

    # Validate response structure
    format_correct = validate_response_structure(solution_str)
    format_score = format_reward if format_correct else -abs(format_reward)
    print(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"  Format score: {format_score}")

    # Validate answer content. Both content scores default to zero so that the
    # final sum is well defined when the format check fails or extraction
    # returns nothing.
    action_score = 0
    argument_score = 0
    if format_correct and action:
        print("\n[Content Validation]")
        print(f"  Expected Action: {action_ground_truth}")
        print(f"  Predicted Action: {action}")
        if action.casefold() == action_ground_truth.casefold():
            action_score = action_reward
            print("  Action validation: FULL MATCH")
        else:
            action_score = -abs(action_reward)
            print("  Action validation: MISMATCH")

    if format_correct and action_argument:
        print("\n[Content Validation]")
        print(f"  Expected Argument: {argument_ground_truth}")
        print(f"  Predicted Argument: {action_argument}")
        if action_argument.casefold() == argument_ground_truth.casefold():
            argument_score = argument_reward
            print("  Argument validation: FULL MATCH")
        else:
            argument_score = -abs(argument_reward)
            print("  Argument validation: MISMATCH")

    total_score = format_score + action_score + argument_score
    print("\n" + "-" * 80)
    print(" Final Score ".center(80, '-'))
    print(f"  Format: {format_score}")
    print(f"  Action: {action_score}")
    print(f"  Argument: {argument_score}")
    print(f"  Total: {total_score}")
    print("=" * 80 + "\n")

    return total_score
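

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the reward pipeline): running this module
# directly scores two hand-written samples so the extraction, structure
# validation, and scoring paths can be exercised end to end. The sample
# responses and the <action>/<action_argument> ground-truth tags below are
# illustrative assumptions, not real training data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_ground_truth = (
        "<action>click</action>"
        "<action_argument>settings_icon</action_argument>"
    )

    # Well-formed response: reasoning block, then a JSON answer. Expected to
    # earn the full format + action + argument reward.
    sample_response = (
        "<think>The user wants to open the settings page, so the correct "
        "action is to click the settings icon.</think>"
        '<answer>{"action": "click", "action_argument": "settings_icon"}</answer>'
    )
    score = compute_score(sample_response, sample_ground_truth)
    print(f"Returned score (well-formed): {score}")

    # Malformed response (missing tags): the format penalty applies and both
    # content scores stay at zero.
    malformed_response = '{"action": "click", "action_argument": "settings_icon"}'
    score = compute_score(malformed_response, sample_ground_truth)
    print(f"Returned score (malformed): {score}")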