import re
import xml.etree.ElementTree as ET
import math
from typing import Any, Dict, List, Tuple
class SmoothXMLRewardEvaluator:
"""
Smooth, differentiable XML reward function.
Returns continuous scores between 0.0 and 1.0 for all components.
"""
def __init__(self):
self.composite_weights = {
'structure': 0.30,
'xml_valid': 0.25,
'order': 0.25,
'confidence': 0.18,
'distribution': 0.02
}

    def evaluate_structure(self, output: str) -> float:
        """
        Check structure (tags present).
        """
        required_tags = [
            r"<query_analysis>.*</query_analysis>",
            r"<domain ambiguous=\"(true|false)\">.*</domain>",
            r"<intent ambiguous=\"(true|false)\">.*</intent>",
            r"<candidate confidence=\"(?:0\.\d|1\.0)\">.*?</candidate>",
            r"<insufficient_context>(true|false)</insufficient_context>",
            r"<rephrased>(true|false)</rephrased>",
            r"<rephrased_query>.*?</rephrased_query>"
        ]
        hits = sum(bool(re.search(tag, output, re.DOTALL)) for tag in required_tags)
        return hits / len(required_tags)

    def evaluate_xml_validity(self, output: str) -> float:
        """
        Check XML validity.
        """
        try:
            ET.fromstring(output.strip())
            return 1.0
        except ET.ParseError:
            return 0.0

    def evaluate_order(self, output: str) -> float:
        """
        Check element order: the fraction of expected tags that appear in
        the correct relative sequence.
        """
        sequence = ["<query_analysis>", "<domain", "<intent",
                    "<insufficient_context>", "<rephrased>", "<rephrased_query>"]
        last_index = -1
        correct_count = 0
        for tag in sequence:
            match = re.search(tag, output)
            if match:
                idx = match.start()
                if idx > last_index:
                    correct_count += 1
                    last_index = idx
        return correct_count / len(sequence)

    def evaluate_confidence(self, output: str) -> float:
        """
        Check confidence correctness: ambiguous blocks should have candidate
        confidences summing to 1.0; unambiguous blocks should have a single
        candidate with confidence 1.0.
        """
        score = 0.0
        blocks = ['domain', 'intent']
        for block_name in blocks:
            block_match = re.search(
                f"<{block_name} ambiguous=\"(true|false)\">.*?</{block_name}>",
                output, re.DOTALL)
            if not block_match:
                continue
            try:
                is_ambiguous = block_match.group(1) == "true"
                confidences = [float(c) for c in re.findall(
                    r"<candidate confidence=\"(0\.\d|1\.0)\">", block_match.group(0))]
                if not confidences:
                    continue
                if is_ambiguous:
                    target_sum = 1.0
                    actual_sum = sum(confidences)
                    # Continuous score: the closer the sum is to 1.0, the higher the reward.
                    score += max(0, 1 - abs(actual_sum - target_sum))
                else:
                    # Unambiguous: expect exactly one candidate, at full confidence.
                    if len(confidences) == 1:
                        score += 1.0 - abs(confidences[0] - 1.0)
            except (ValueError, AttributeError):
                continue
        return score / len(blocks) if blocks else 0.0

    def evaluate_distribution(self, output: str) -> float:
        """
        Check confidence distribution (entropy-based, normalized): ambiguous
        blocks are rewarded for spreading confidence across candidates.
        """
        total_score = 0.0
        blocks_evaluated = 0
        for block_name in ['domain', 'intent']:
            block_match = re.search(
                f'<{block_name} ambiguous="(true|false)".*?</{block_name}>',
                output, re.DOTALL)
            if not block_match:
                continue
            is_ambiguous = block_match.group(1) == "true"
            confidences = [float(c) for c in re.findall(
                r'confidence="([01]\.\d)"', block_match.group(0))]
            if not confidences:
                continue
            blocks_evaluated += 1
            if is_ambiguous and len(confidences) > 1:
                # Normalized entropy: approaches 1.0 for a uniform spread,
                # lower for skewed spreads (the epsilon guards against log(0)).
                entropy = -sum(p * math.log(p + 1e-8) for p in confidences)
                max_entropy = math.log(len(confidences))
                total_score += entropy / max_entropy if max_entropy > 0 else 0
            else:
                total_score += 1.0
        return total_score / blocks_evaluated if blocks_evaluated > 0 else 0.0

    def structural_penalty(self, output: str) -> float:
        """
        Compute structural penalty (soft, subtractive); unparseable XML
        receives the maximum penalty.
        """
        penalty = 0.0
        try:
            root = ET.fromstring(output.strip())
        except ET.ParseError:
            return 1.0
        for tag in root.findall(".//domain") + root.findall(".//intent"):
            if "ambiguous" not in tag.attrib:
                penalty += 0.1
        for cand in root.findall(".//candidate"):
            if "confidence" not in cand.attrib:
                penalty += 0.05
        return min(1.0, penalty)

    def answering_penalty(self, output: str) -> float:
        """
        Compute answering penalty: outputs that are not a single
        <query_analysis> block (i.e. the model answered the query directly)
        are fully penalized.
        """
        stripped = output.strip()
        if stripped.startswith('<query_analysis>') and stripped.endswith('</query_analysis>'):
            return 0.0
        return 1.0

    def evaluate(self, output: str) -> Tuple[float, Dict[str, float]]:
        """
        Return a composite evaluation, using the intermediary checks.
        """
        ap = self.answering_penalty(output)
        if ap > 0:
            # Hard gate: a direct answer scores zero regardless of the other components.
            return 0.0, {"answering_penalty": ap}
        components = {
            "structure": self.evaluate_structure(output),
            "xml_valid": self.evaluate_xml_validity(output),
            "order": self.evaluate_order(output),
            "confidence": self.evaluate_confidence(output),
            "distribution": self.evaluate_distribution(output)
        }
        sp = self.structural_penalty(output)
        components["structural_penalty"] = sp
        # Subtractive penalty instead of multiplicative, so a small structural
        # flaw dampens the reward without zeroing it out.
        reward = sum(self.composite_weights[k] * v
                     for k, v in components.items() if k != "structural_penalty")
        final_score = max(0.0, reward - 0.5 * sp)
        components["final_score"] = final_score
        return final_score, components

    def get_detailed_analysis(self, output: str) -> Dict[str, Any]:
        """
        Provide a detailed analysis of the XML output, including all
        component scores and improvement recommendations.
        """
        final_score, component_scores = self.evaluate(output)
        return {
            'final_score': final_score,
            'component_scores': component_scores,
            'weights_used': self.composite_weights,
            'recommendations': self._get_recommendations(component_scores)
        }

    def _get_recommendations(self, component_scores: Dict[str, float]) -> List[str]:
        """
        Generate improvement recommendations based on component scores.
        """
        recommendations = []
        if component_scores.get('structure', 1.0) < 0.85:
            recommendations.append("Improve XML structure - ensure all required tags are present")
        if component_scores.get('xml_valid', 1.0) < 1.0:
            recommendations.append("Fix XML syntax errors - ensure proper tag closing and nesting")
        if component_scores.get('order', 1.0) < 0.9:
            recommendations.append("Reorder XML elements to match expected structure")
        if component_scores.get('confidence', 1.0) < 0.8:
            recommendations.append("Fix confidence values - ensure they sum to 1.0 for ambiguous cases")
        if component_scores.get('distribution', 1.0) < 0.8:
            recommendations.append("Improve confidence distribution balance for ambiguous classifications")
        if component_scores.get('structural_penalty', 0.0) > 0.2:
            recommendations.append("Address structural issues - missing attributes or malformed tags")
        return recommendations
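

# Usage sketch (illustrative, not part of the evaluator itself): the sample
# XML below is an assumed example of the expected output format, inferred
# from the regexes above rather than taken from a real model transcript.
if __name__ == "__main__":
    sample_output = (
        '<query_analysis>'
        '<domain ambiguous="true">'
        '<candidate confidence="0.6">finance</candidate>'
        '<candidate confidence="0.4">legal</candidate>'
        '</domain>'
        '<intent ambiguous="false">'
        '<candidate confidence="1.0">lookup</candidate>'
        '</intent>'
        '<insufficient_context>false</insufficient_context>'
        '<rephrased>true</rephrased>'
        '<rephrased_query>What were the 2023 capital gains tax rates?</rephrased_query>'
        '</query_analysis>'
    )
    evaluator = SmoothXMLRewardEvaluator()
    score, components = evaluator.evaluate(sample_output)
    print(f"final score: {score:.3f}")
    for name, value in components.items():
        print(f"  {name}: {value:.3f}")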