import re
import xml.etree.ElementTree as ET
import math
from typing import Any, Dict, List, Tuple


class SmoothXMLRewardEvaluator:
    """

    Smooth, differentiable XML reward function.

    Returns continuous scores between 0.0 and 1.0 for all components.

    """
    
    def __init__(self):
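        # Component weights; they sum to 1.0 so the weighted composite stays in [0, 1].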
        self.composite_weights = {
            'structure': 0.30,
            'xml_valid': 0.25,
            'order': 0.25,
            'confidence': 0.18,
            'distribution': 0.02
        }

    def evaluate_structure(self, output: str) -> float:
        """

        Check structure (tags present).

        """
        required_tags = [
            r"<query_analysis>.*</query_analysis>",
            r"<domain ambiguous=\"(true|false)\">.*</domain>",
            r"<intent ambiguous=\"(true|false)\">.*</intent>",
            r"<candidate confidence=\"(?:0\.\d|1\.0)\">.*?</candidate>",
            r"<insufficient_context>(true|false)</insufficient_context>",
            r"<rephrased>(true|false)</rephrased>",
            r"<rephrased_query>.*?</rephrased_query>"
        ]
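        # Each matched pattern contributes equal fractional credit.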
        hits = sum(bool(re.search(tag, output, re.DOTALL)) for tag in required_tags)
        return hits / len(required_tags)

    def evaluate_xml_validity(self, output: str) -> float:
        """

        Check XML validity.

        """
        try:
            ET.fromstring(output.strip())
            return 1.0
        except ET.ParseError:
            return 0.0

    def evaluate_order(self, output: str) -> float:
        """

        Check order of elements (fraction of correct sequence for continuity).

        """
        sequence = ["<query_analysis>", "<domain", "<intent", "<insufficient_context>", "<rephrased>", "<rephrased_query>"]
        last_index = -1
        correct_count = 0
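        # A tag earns credit only if it appears strictly after the previous matched tag.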
        for tag in sequence:
            match = re.search(tag, output)
            if match:
                idx = match.start()
                if idx > last_index:
                    correct_count += 1
                    last_index = idx
        return correct_count / len(sequence)

    def evaluate_confidence(self, output: str) -> float:
        """

        Check confidence correctness.

        """
        score = 0.0
        blocks = ['domain', 'intent']
        for block_name in blocks:
            block_match = re.search(f"<{block_name} ambiguous=\"(true|false)\">.*?</{block_name}>", output, re.DOTALL)
            if not block_match:
                continue
            try:
                is_ambiguous = block_match.group(1) == "true"
                confidences = [float(c) for c in re.findall(r"<candidate confidence=\"(0\.\d|1\.0)\">", block_match.group(0))]
                if not confidences:
                    continue
                if is_ambiguous:
                    target_sum = 1.0
                    actual_sum = sum(confidences)
                    # Continuous score: closer to 1.0 sum → higher reward
                    score += max(0, 1 - abs(actual_sum - target_sum))
                else:
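                    # Unambiguous: expect a single candidate with confidence close to 1.0.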
                    if len(confidences) == 1:
                        score += 1.0 - abs(confidences[0] - 1.0)
            except (ValueError, AttributeError):
                continue
        return score / len(blocks) if blocks else 0.0

    def evaluate_distribution(self, output: str) -> float:
        """

        Check confidence distribution (entropy-based, normalized).

        """
        total_score = 0.0
        blocks_evaluated = 0
        for block_name in ['domain', 'intent']:
            block_match = re.search(f'<{block_name} ambiguous="(true|false)".*?</{block_name}>', output, re.DOTALL)
            if not block_match:
                continue
            is_ambiguous = block_match.group(1) == "true"
            confidences = [float(c) for c in re.findall(r'confidence="([01]\.\d)"', block_match.group(0))]
            if not confidences:
                continue
            blocks_evaluated += 1
            if is_ambiguous and len(confidences) > 1:
                # Normalized entropy: equals 1.0 for an even spread summing to 1, lower when peaked.
                entropy = -sum(p * math.log(p + 1e-8) for p in confidences)
                max_entropy = math.log(len(confidences))
                total_score += entropy / max_entropy if max_entropy > 0 else 0
            else:
                total_score += 1.0
        return total_score / blocks_evaluated if blocks_evaluated > 0 else 0.0

    def structural_penalty(self, output: str) -> float:
        """

        Compute structural penalty (soft, subtractive).

        """
        penalty = 0.0
        try:
            root = ET.fromstring(output.strip())
        except ET.ParseError:
            return 1.0
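        # Additive penalties: 0.1 per missing "ambiguous" attribute, 0.05 per missing "confidence", capped at 1.0.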
        for tag in root.findall(".//domain") + root.findall(".//intent"):
            if "ambiguous" not in tag.attrib:
                penalty += 0.1
        for cand in root.findall(".//candidate"):
            if "confidence" not in cand.attrib:
                penalty += 0.05
        return min(1.0, penalty)

    def answering_penalty(self, output: str) -> float:
        """

        Compute answering penalty.

        """
        stripped = output.strip()
        if stripped.startswith('<query_analysis>') and stripped.endswith('</query_analysis>'):
            return 0.0
        return 1.0

    def evaluate(self, output: str) -> Tuple[float, Dict[str, float]]:
        """

        Return a composite evaluation, using intermediary checks.

        """
        ap = self.answering_penalty(output)
        if ap > 0:
            return 0.0, {"answering_penalty": ap}

        components = {
            "structure": self.evaluate_structure(output),
            "xml_valid": self.evaluate_xml_validity(output),
            "order": self.evaluate_order(output),
            "confidence": self.evaluate_confidence(output),
            "distribution": self.evaluate_distribution(output)
        }

        sp = self.structural_penalty(output)
        components["structural_penalty"] = sp

        # Subtractive penalty instead of multiplicative
        reward = sum(self.composite_weights[k] * v for k, v in components.items() if k != "structural_penalty")
        final_score = max(0.0, reward - 0.5 * sp)

        components["final_score"] = final_score
        return final_score, components

    def get_detailed_analysis(self, output: str) -> Dict[str, Any]:
        """

        Provides detailed analysis of the XML output including all component scores.

        """
        final_score, component_scores = self.evaluate(output)
        
        return {
            'final_score': final_score,
            'component_scores': component_scores,
            'weights_used': self.composite_weights,
            'recommendations': self._get_recommendations(component_scores)
        }
    
    def _get_recommendations(self, component_scores: Dict[str, float]) -> List[str]:
        """

        Generate improvement recommendations based on component scores.

        """
        recommendations = []

        if component_scores.get('structure', 1.0) < 0.85:
            recommendations.append("Improve XML structure - ensure all required tags are present")

        if component_scores.get('xml_valid', 1.0) < 1.0:
            recommendations.append("Fix XML syntax errors - ensure proper tag closing and nesting")

        if component_scores.get('order', 1.0) < 0.9:
            recommendations.append("Reorder XML elements to match expected structure")

        if component_scores.get('confidence', 1.0) < 0.8:
            recommendations.append("Fix confidence values - ensure they sum to 1.0 for ambiguous cases")

        if component_scores.get('distribution', 1.0) < 0.8:
            recommendations.append("Improve confidence distribution balance for ambiguous classifications")

        if component_scores.get('structural_penalty', 0.0) > 0.2:
            recommendations.append("Address structural issues - missing attributes or malformed tags")

        return recommendations
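

if __name__ == "__main__":
    # Minimal usage sketch. The sample below is a hypothetical well-formed
    # response invented for illustration (the domain/intent labels are made up);
    # it is not drawn from any real dataset.
    sample = (
        '<query_analysis>'
        '<domain ambiguous="true">'
        '<candidate confidence="0.6">billing</candidate>'
        '<candidate confidence="0.4">support</candidate>'
        '</domain>'
        '<intent ambiguous="false">'
        '<candidate confidence="1.0">refund_request</candidate>'
        '</intent>'
        '<insufficient_context>false</insufficient_context>'
        '<rephrased>true</rephrased>'
        '<rephrased_query>How do I request a refund?</rephrased_query>'
        '</query_analysis>'
    )
    evaluator = SmoothXMLRewardEvaluator()
    score, components = evaluator.evaluate(sample)
    print(f"final score: {score:.3f}")
    for name, value in components.items():
        print(f"  {name}: {value:.3f}")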