import re
from typing import Optional, Tuple


def extract_solution(solution_str: str) -> Tuple[Optional[str], Optional[str], str]:
    """Extracts the predicted action and action_argument from a raw model response."""
    # Split response to isolate assistant output
    if "</think>" in solution_str:
        processed_str = solution_str.split("</think>")[-1]
    elif "<answer>" in solution_str:
        processed_str = solution_str.split("<answer>")[-1]
    elif "{" in solution_str:
        processed_str = solution_str.split("{")[-1]
    else:
        print("[Error] Failed to locate model response header")
        return None, None, solution_str
    # Parse the action/argument pair, accepting either double- or single-quoted keys
    if '"action":' in processed_str and '"action_argument":' in processed_str:
        action = processed_str.split('"action":')[-1].split('"action_argument":')[0].replace('"', '').replace(',', '').strip()
        action_argument = processed_str.split('"action_argument":')[-1].split('}')[0].replace('"', '').replace(',', '').strip()
        return action, action_argument, processed_str
    elif "'action':" in processed_str and "'action_argument':" in processed_str:
        action = processed_str.split("'action':")[-1].split("'action_argument':")[0].replace("'", '').replace(',', '').strip()
        action_argument = processed_str.split("'action_argument':")[-1].split('}')[0].replace("'", '').strip()
        return action, action_argument, processed_str
    else:
        print("[Error] No valid answer tags found")
        return None, None, processed_str
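
# Example (illustrative): a response such as
#   '... </think>{"action": "click", "action_argument": "B0EXAMPLE"}'
# yields action="click" and action_argument="B0EXAMPLE". The action name and the
# ASIN-style argument are made-up placeholders, not values from a real dataset.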


def parse_ground_truth_text_format(ground_truth: str) -> Tuple[str, str]:
    """Extracts the expected action and argument from the tagged ground-truth string."""
    answer_pattern = r'<answer_action>(.*?)</answer_action>'
    argument_pattern = r'<answer_argument>(.*?)</answer_argument>'
    action_matches = list(re.finditer(answer_pattern, ground_truth, re.DOTALL))
    argument_matches = list(re.finditer(argument_pattern, ground_truth, re.DOTALL))
    if action_matches and argument_matches:
        return action_matches[-1].group(1).strip(), argument_matches[-1].group(1).strip()
    else:
        # Fall back to the raw string when the expected tags are missing
        return ground_truth, ground_truth
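
# Example (illustrative): a ground truth such as
#   '<answer_action>search</answer_action><answer_argument>wireless mouse</answer_argument>'
# parses to ("search", "wireless mouse"); the values are hypothetical and only
# show the expected tag layout.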


def validate_response_structure(processed_str: str) -> bool:
    """Performs comprehensive validation of response structure.

    Args:
        processed_str: Processed response string from the model

    Returns:
        Boolean indicating whether all formatting requirements are met
    """
    print("\n[Structure Validation]")
    validation_passed = True
    # Check required tags
    tags = {
        # 'think_start': ('<think>', 1),
        # 'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }
    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)
        print(f" {tag_str}: count={count}, position={pos}")
        if count != expected_count:
            print(f" [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False
    # Verify tag order
    # positions['think_start'] > positions['think_end'] or positions['think_end'] > positions['answer_start'] or
    if positions['answer_start'] > positions['answer_end']:
        print(" [Error] Incorrect tag order: Expected <answer>...</answer>")
        validation_passed = False
    else:
        print(" Tag sequence validation passed")
    return validation_passed
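
# Example (illustrative): '<answer>{"action": ...}</answer>' passes this check,
# while a string missing either tag, or with </answer> before <answer>, fails.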


def compute_score(solution_str: str, ground_truth: str, method='strict', format_reward: int = 1, action_reward: float = 1.0, argument_reward: float = 2.0):
    """Computes comprehensive score for model response.

    Args:
        solution_str: Raw model response string
        ground_truth: Ground-truth string containing <answer_action> and <answer_argument> tags
        method: the method to extract the solution, choices are 'strict' and 'flexible'
        format_reward: Points awarded/deducted for format correctness
        action_reward: Points awarded/deducted for action correctness
        argument_reward: Points awarded/deducted for argument correctness

    Returns:
        Total score (sum of the action and argument scores; the format score is logged but not added)
    """
    print("\n" + "=" * 80)
    print(" Processing New Sample ".center(80, '='))
    print(f"[Ground Truth]: {ground_truth}")
    action_ground_truth, argument_ground_truth = parse_ground_truth_text_format(ground_truth)

    # Extract model answer
    action, action_argument, processed_str = extract_solution(solution_str=solution_str)
    print(f"\n[Model Response]\n{processed_str}")
    print(f"\n[Processed Model Response Action]\n{action}")
    print(f"\n[Processed Model Response Action_Argument]\n{action_argument}")

    # Validate response structure
    format_correct = validate_response_structure(processed_str)
    format_score = format_reward if format_correct else -abs(format_reward)
    print(f"\n Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f" Format score: {format_score}")

    # Validate answer content; scores default to 0 when extraction failed
    action_score = 0
    argument_score = 0
    if action:
        print("\n[Content Validation]")
        print(f" Expected Action: {action_ground_truth}")
        print(f" Predicted Action: {action}")
        if action.casefold() == action_ground_truth.casefold():
            action_score = action_reward
            print(" Action validation: FULL MATCH")
        else:
            action_score = -abs(action_reward)
            print(" Action validation: MISMATCH")
    if action_argument:
        print("\n[Content Validation]")
        print(f" Expected Argument: {argument_ground_truth}")
        print(f" Predicted Argument: {action_argument}")
        if action_argument.casefold() == argument_ground_truth.casefold():
            argument_score = argument_reward
            print(" Argument validation: FULL MATCH")
        else:
            argument_score = -abs(argument_reward)
            print(" Argument validation: MISMATCH")

    total_score = action_score + argument_score
    print("\n" + "-" * 80)
    print(" Final Score ".center(80, '-'))
    print(f" Action: {action_score}")
    print(f" Argument: {argument_score}")
    print(f" Total: {total_score}")
    print("=" * 80 + "\n")
    return total_score
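

if __name__ == "__main__":
    # Minimal smoke test on made-up data; the action name ("click") and the
    # ASIN-style argument ("B0EXAMPLE") are hypothetical placeholders, not
    # values from any real dataset.
    sample_ground_truth = (
        "<answer_action>click</answer_action>"
        "<answer_argument>B0EXAMPLE</answer_argument>"
    )
    sample_solution = (
        "<think>The requested product matches this listing, so click it.</think>"
        '<answer>{"action": "click", "action_argument": "B0EXAMPLE"}</answer>'
    )
    print(compute_score(sample_solution, sample_ground_truth))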