import re
from typing import Dict, Tuple, Optional


def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Extract the normalized final answer from a raw model response.

    The assistant's part of the response is isolated by splitting on the
    first recognised chat header. The content of the last
    ``<answer>...</answer>`` pair is then normalized by trying, in order:

    1. single letters followed by ``:`` or ``.`` at the start of the text
       or after whitespace (e.g. ``"A: ..."``), joined with commas;
    2. the last ``yes`` / ``no`` substring (case-insensitive);
    3. every ASCII letter of the upper-cased answer, joined with commas;
    4. the raw answer with ``", "`` collapsed to ``","``.

    Args:
        solution_str: Raw model response string, including the prompt.

    Returns:
        A ``(answer, processed_str)`` tuple. ``answer`` is ``None`` when no
        chat header or no answer tags are found. ``processed_str`` is the
        text after the header, or the unmodified input when no header is
        found.
    """
    # Split response to isolate assistant output
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1]
    elif "[/INST]" in solution_str:
        processed_str = solution_str.split("[/INST]", 1)[1]
    else:
        print("[Error] Failed to locate model response header")
        return None, solution_str

    # Extract final answer using XML-style tags. The literal tags are
    # required: a bare r'(.*?)' matches the empty string at every position
    # and would always yield an empty answer.
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL))
    if not matches:
        print("[Error] No valid answer tags found")
        return None, processed_str

    # Only the last answer block counts.
    final_answer = matches[-1].group(1).strip()

    format1_match = re.findall(
        r'(?:^|\s)([A-Za-z])(?=[:.])', final_answer, re.MULTILINE
    )
    format2_match = re.findall(r'(yes|no)', final_answer.lower())
    format3_match = re.findall(r'[A-Z]', final_answer.upper())

    if format1_match:
        return ','.join(format1_match), processed_str
    if format2_match:
        return format2_match[-1].strip(), processed_str
    if format3_match:
        return ",".join(format3_match), processed_str
    return final_answer.replace(", ", ","), processed_str


def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Each of ``<think>``, ``</think>``, ``<answer>`` and ``</answer>`` must
    appear exactly once, and their first occurrences must be in that order.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Check required tags: name -> (literal tag, expected occurrence count)
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1),
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)

        print(f" {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            print(f" [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order. A missing tag has position -1, so this comparison
    # alone is not sufficient; the count check above already rejects it.
    if (positions['think_start'] > positions['think_end'] or
            positions['think_end'] > positions['answer_start'] or
            positions['answer_start'] > positions['answer_end']):
        print(" [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print(" Tag sequence validation passed")

    return validation_passed


def validate_response_BDI(processed_str: str) -> bool:
    """Checks whether the response contains a full Beliefs/Desires/Intentions set.

    One of the following marker sets must be present in its entirety
    (plain, case-sensitive substring checks):

    * ``"Beliefs:"``, ``"Desires:"``, ``"Intentions:"``
    * ``"beliefs:"``, ``"desires:"``, ``"intentions:"``
    * ``"Customer believes"``, ``"desires"``, ``"Customer intends"``
    * ``"customer believes"``, ``"desires"``, ``"customer intends"``

    Args:
        processed_str: Processed response string from the model

    Returns:
        True only if every marker of at least one set occurs in
        ``processed_str``.
    """
    accepted_marker_sets = (
        ("Beliefs:", "Desires:", "Intentions:"),
        ("beliefs:", "desires:", "intentions:"),
        ("Customer believes", "desires", "Customer intends"),
        ("customer believes", "desires", "customer intends"),
    )
    return any(
        all(marker in processed_str for marker in marker_set)
        for marker_set in accepted_marker_sets
    )


def parse_ground_truth_text_format(ground_truth):
    """Normalize a ground-truth string to the form answers are compared in.

    Rules, applied in order:

    * contains both ``":"`` and ``"[{"``: returned unchanged;
    * contains ``":"``: the first ASCII letter immediately before a colon
      is returned, or the input unchanged when there is none;
    * contains ``<answer>...</answer>``: the stripped content of the last
      such pair is returned;
    * otherwise: returned unchanged.

    Args:
        ground_truth: Ground-truth text.

    Returns:
        The normalized ground-truth string.
    """
    if ":" in ground_truth:
        if "[{" in ground_truth:
            return ground_truth
        format1_match = re.search(r"([a-zA-Z]):", ground_truth)
        if format1_match:
            return format1_match.group(1).strip()
        return ground_truth

    # The literal tags are required: a bare r'(.*?)' always matches (with
    # empty content) and would turn every colon-free ground truth into ''.
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, ground_truth, re.DOTALL))
    if matches:
        return matches[-1].group(1).strip()
    return ground_truth


def compute_score(solution_str: str,
                  ground_truth: str,
                  method='strict',
                  format_reward: int = 1,
                  answer_reward: float = 1.0):
    """Computes comprehensive score for model response.

    Format score: ``format_reward`` when both the tag structure and the BDI
    markers are valid, ``format_reward / 2`` when only the tag structure is
    valid, ``-abs(format_reward)`` otherwise.

    Answer score: ``2`` on a case-insensitive exact match with the
    normalized ground truth, ``-1.5`` on a mismatch, and ``-2`` when the
    answer is skipped (format/BDI failure or no extracted answer).

    Args:
        solution_str: Raw model response string
        ground_truth: Ground-truth text; normalized with
            ``parse_ground_truth_text_format`` before comparison
        method: Accepted for interface compatibility; not used by this
            function
        format_reward: Points awarded/deducted for format correctness
        answer_reward: Accepted for interface compatibility; not used by
            this function (answer scores are the fixed values above)

    Returns:
        Total score (sum of format and answer rewards)
    """
    print("\n" + "="*80)
    print(" Processing New Sample ".center(80, '='))

    print(f"[Ground Truth]: {ground_truth}")
    ground_truth = parse_ground_truth_text_format(ground_truth)

    # Extract model answer
    answer_text, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response]\n{answer_text}")

    # Validate response structure
    format_correct = validate_response_structure(processed_str)
    BDI_correct = validate_response_BDI(processed_str)

    if format_correct and BDI_correct:
        format_score = format_reward
    elif format_correct:
        format_score = format_reward / 2
    else:
        format_score = -abs(format_reward)

    print(f"\n Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"\n BDI Format validation: {'PASS' if BDI_correct else 'FAIL'}")
    print(f" Format score: {format_score}")

    # Validate answer content; only scored when the format is fully valid
    # and an answer was actually extracted.
    if format_correct and BDI_correct and answer_text:
        print("\n[Content Validation]")
        print(f" Expected: {ground_truth}")
        print(f" Predicted: {answer_text}")

        if answer_text.casefold() == ground_truth.casefold():
            answer_score = 2
            print(" Content validation: FULL MATCH")
        else:
            answer_score = -1.5
            print(" Content validation: MISMATCH")
    else:
        answer_score = -2
        print("\n[Content Validation] Skipped due to format errors or missing answer")

    total_score = format_score + answer_score
    print("\n" + "-"*80)
    print(" Final Score ".center(80, '-'))
    print(f" Format: {format_score}")
    print(f" Answer: {answer_score}")
    print(f" Total: {total_score}")
    print("="*80 + "\n")

    return total_score