# All_updated_file / amazonmix.py
# Last commit: ee576a1 "Create amazonmix.py" by JackyChunKit (verified)
import re
from typing import Dict, Tuple, Optional
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Extract the normalised final answer from a raw model response.

    The assistant turn is located by the first header found, checked in this
    order: ``"Assistant:"``, ``"<|im_start|>assistant"``, ``"[/INST]"``.
    The content of the LAST ``<answer>...</answer>`` block in that turn is
    stripped and normalised by the first rule that applies:

    1. Option letters: every standalone ASCII letter (at a line start or after
       whitespace) followed by ``:`` or ``.``, joined with commas
       (lines ``A. red`` and ``C. blue`` -> ``"A,C"``).
    2. Yes/no: the last whole-word ``yes``/``no`` (case-insensitive),
       returned lower-cased.
    3. Any ASCII letters: every letter of the upper-cased answer, joined with
       commas (``"b, d"`` -> ``"B,D"``).
    4. Otherwise: the stripped answer with ``", "`` collapsed to ``","``.

    Args:
        solution_str: Raw model output including the prompt/role header.

    Returns:
        ``(answer, processed_str)``. ``answer`` is ``None`` when no response
        header or no ``<answer>`` block is found. ``processed_str`` is the text
        after the header, or the whole input when no header was found.
    """
    # Split response to isolate assistant output
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1]
    elif "[/INST]" in solution_str:
        processed_str = solution_str.split("[/INST]", 1)[1]
    else:
        print("[Error] Failed to locate model response header")
        return None, solution_str

    # Extract final answer using XML-style tags; the last block wins.
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL))
    if not matches:
        print("[Error] No valid answer tags found")
        return None, processed_str
    final_answer = matches[-1].group(1).strip()

    # 1. Option letters such as "A." or "B:" at a line start / after whitespace.
    option_letters = re.findall(r'(?:^|\s)([A-Za-z])(?=[:.])', final_answer, re.MULTILINE)
    if option_letters:
        return ','.join(option_letters), processed_str

    # 2. Whole-word yes/no. The \b anchors stop substrings of other words
    #    ("know", "nobody", "none", "eyes") from being read as an answer.
    yes_no = re.findall(r'\b(yes|no)\b', final_answer.lower())
    if yes_no:
        return yes_no[-1], processed_str

    # 3. Any ASCII letters, upper-cased and comma-joined.
    letters = re.findall(r'[A-Z]', final_answer.upper())
    if letters:
        return ",".join(letters), processed_str

    # 4. No letters at all: return the raw answer with ", " normalised to ",".
    return final_answer.replace(", ", ","), processed_str
def validate_response_structure(processed_str: str) -> bool:
    """Check that the response contains exactly one of each think/answer tag, in order.

    For each of ``<think>``, ``</think>``, ``<answer>``, ``</answer>`` the
    occurrence count and first position (``-1`` if absent) are printed, followed
    by the outcome of the ordering check.

    Args:
        processed_str: Model response text to inspect.

    Returns:
        True only if every tag occurs exactly once AND the first positions are
        non-decreasing in the order <think>, </think>, <answer>, </answer>.
    """
    print("\n[Structure Validation]")
    ordered_tags = ('<think>', '</think>', '<answer>', '</answer>')
    expected_count = 1  # every tag must appear exactly once

    ok = True
    first_index = []
    for tag in ordered_tags:
        occurrences = processed_str.count(tag)
        where = processed_str.find(tag)
        first_index.append(where)
        print(f" {tag}: count={occurrences}, position={where}")
        if occurrences != expected_count:
            print(f" [Error] {tag} appears {occurrences} times (expected {expected_count})")
            ok = False

    think_open, think_close, answer_open, answer_close = first_index
    if think_open <= think_close <= answer_open <= answer_close:
        print(" Tag sequence validation passed")
    else:
        print(" [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        ok = False
    return ok
def validate_response_BDI(processed_str: str) -> bool:
    """Check whether a response spells out a Beliefs/Desires/Intentions (BDI) breakdown.

    Matching is case-insensitive. The check passes when either marker set is
    fully present:

    * section labels: ``beliefs:``, ``desires:`` and ``intentions:``; or
    * customer phrasing: ``customer believes``, ``desires`` and
      ``customer intends``.

    Case-insensitive matching accepts everything the former per-case checks
    accepted. It also accepts mixed casing such as "Customer believes ...
    the customer intends ...".

    Args:
        processed_str: Model response text.

    Returns:
        True if at least one complete marker set is found, else False.
    """
    text = processed_str.casefold()
    has_labels = all(marker in text for marker in ("beliefs:", "desires:", "intentions:"))
    has_customer_phrasing = all(
        marker in text for marker in ("customer believes", "desires", "customer intends")
    )
    return has_labels or has_customer_phrasing
def parse_ground_truth_text_format(ground_truth):
    """Normalise a ground-truth string into the form compared against model answers.

    Rules, applied in order:

    1. Text containing both ``":"`` and ``"[{"`` (presumably a serialised list
       of dicts) is returned unchanged.
    2. Otherwise, text containing ``":"`` yields the first ASCII letter that
       immediately precedes a colon (``"B: red shoes"`` -> ``"B"``). If no
       letter precedes any colon, the text is returned unchanged. This rule
       wins even if ``<answer>`` tags are also present.
    3. Otherwise, if ``<answer>...</answer>`` blocks exist, the stripped
       content of the last one is returned.
    4. Otherwise the text is returned unchanged.
    """
    tagged_bodies = re.findall(r'<answer>(.*?)</answer>', ground_truth, re.DOTALL)

    if ":" in ground_truth:
        if "[{" in ground_truth:
            return ground_truth
        letter = re.search(r"([a-zA-Z]):", ground_truth)
        return letter.group(1).strip() if letter else ground_truth

    if tagged_bodies:
        return tagged_bodies[-1].strip()
    return ground_truth
def compute_score(solution_str: str, ground_truth: str, method='strict', format_reward: int = 1, answer_reward: float = 1.0):
    """Score one model response against its ground truth and print a breakdown.

    The total is the sum of two components:

    * format score: ``format_reward`` when both the tag structure and the
      BDI markers validate; ``format_reward / 2`` when only the tag structure
      validates; ``-abs(format_reward)`` otherwise.
    * answer score: ``2`` for a case-insensitive (``casefold``) exact match
      with the parsed ground truth; ``-1.5`` for a mismatch; ``-2`` when the
      comparison is skipped (structure or BDI check failed, or no answer was
      extracted).

    Args:
        solution_str: Raw model output, including the prompt/role header.
        ground_truth: Ground-truth string; normalised with
            ``parse_ground_truth_text_format`` before comparison.
        method: Kept for interface compatibility; not read by this function.
        format_reward: Magnitude of the format component.
        answer_reward: Kept for interface compatibility; the answer component
            uses the fixed values above and does not read this argument.

    Returns:
        The total score (format score + answer score).
    """
    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    expected = parse_ground_truth_text_format(ground_truth)

    # Pull the normalised answer and the assistant-turn text out of the raw output.
    predicted, response_body = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{response_body}")
    print(f"\n[Processed Model Response]\n{predicted}")

    # Structure check runs first: it is the one that prints its own diagnostics.
    structure_ok = validate_response_structure(response_body)
    bdi_ok = validate_response_BDI(response_body)

    if not structure_ok:
        format_score = -abs(format_reward)
    elif bdi_ok:
        format_score = format_reward
    else:
        format_score = format_reward / 2

    print(f"\n Format validation: {'PASS' if structure_ok else 'FAIL'}")
    print(f"\n BDI Format validation: {'PASS' if bdi_ok else 'FAIL'}")
    print(f" Format score: {format_score}")

    # The answer is only compared when both format checks passed and something was extracted.
    if structure_ok and bdi_ok and predicted:
        print("\n[Content Validation]")
        print(f" Expected: {expected}")
        print(f" Predicted: {predicted}")
        matched = predicted.casefold() == expected.casefold()
        answer_score = 2 if matched else -1.5
        print(" Content validation: FULL MATCH" if matched else " Content validation: MISMATCH")
    else:
        answer_score = -2
        print("\n[Content Validation] Skipped due to format errors or missing answer")

    total_score = format_score + answer_score
    print("\n" + "-" * 80)
    print(" Final Score ".center(80, '-'))
    print(f" Format: {format_score}")
    print(f" Answer: {answer_score}")
    print(f" Total: {total_score}")
    print(heavy_rule + "\n")
    return total_score