File size: 10,190 Bytes
b43abc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import sys
import re
import os
from typing import List
# from .sonnet_eval import sonnet_errors
from .execute_code import execute_code_with_timeout


def clean_output_for_arithmetic(output: str) -> str:
    """
    Clean the output for arithmetic problems.

    Args:
        output (str): The output to clean.

    Returns:
        str: The cleaned output.
    """
    if "=" in output:
        output = output.split("=")[1].strip()
    if " is" in output:
        output = output.split(" is")[1].strip()
    if " equals" in output:
        output = output.split(" equals")[1].strip()
    if " evaluates to" in output:
        output = output.split(" evaluates to")[1].strip()
    if " is equal to" in output:
        output = output.split(" is equal to")[1].strip()
    return output


def clean_output_for_GameOf24(output: str) -> str:
    """
    Clean the output for GameOf24 problems.
    """
    if "=" in output:
        output = output.split("=")[0].strip()
    if "is" in output:
        output = output.split("is")[1].strip()
    if "equals" in output:
        output = output.split("equals")[0].strip()
    if "evaluates to" in output:
        output = output.split("evaluates to")[0].strip()
    return output


def eval_for_GameOf24(input: str, output: str) -> bool:
    """
    Given an input and output, check if the output is correct and follows the rules of the game.
    """
    clean_output = output

    clean_output = clean_output_for_GameOf24(output)
    clean_output = clean_output.replace("x", "*").strip()
    clean_output = clean_output.replace("×", "*").strip()
    clean_output = clean_output.replace("÷", "/").strip()
    
    try:
        # Get the value of the expression using eval
        value = eval(clean_output)
        if not (abs(value - 24) < 1e-3):
            return False
        # Split the input and output digits by space
        input_digits = input.split(" ")
        # Replace the following symbols with space
        replacements = ["+", "-", "*", "/", "÷", "(", ")"]
        for symbol in replacements:
            clean_output = clean_output.replace(symbol, " ")
        # Replace multiple spaces with single space
        clean_output = re.sub(" +", " ", clean_output)
        clean_output = clean_output.strip()
        output_digits = clean_output.split(" ")
        # Sort the digits
        input_digits.sort()
        output_digits.sort()
        # Check if the digits are the same
        if input_digits != output_digits:
            return False
        return True
    except Exception as e:
        return False


def remove_punctuation(output: str) -> str:
    """
    Remove punctuation from the output.
    """
    markers = [",", ";", ":", ".", '"']
    for marker in markers:
        output = output.replace(marker, "")
    return output


def convert_newline_to_space(output: str) -> str:
    """
    Convert newline to space.
    """
    output = output.replace("\n", " ")
    return output


def eval_for_exact_matching_with_no_punctuation(
    output: str, target: str
) -> bool:
    """
    Evaluate if the output is exactly the same as the target.
    """
    output = remove_punctuation(output)
    output = convert_newline_to_space(output)
    if target == output:
        return True
    return False


def eval_for_softmatch(input: str, output: str, target: str) -> bool:
    """
    Evaluate if the output is a soft match of the target.
    """
    output = remove_punctuation(output)
    if target in output:
        return True
    return False


def eval_for_CheckmateInOne(input: str, output: str, target: str) -> bool:
    """
    Evaluate if the output is a checkmate in one.
    """
    output = output.strip()
    if output[-1] == "#":
        output = output.split(" ")[-1].strip()
    # Based on the input, determine the number of the last move
    last_move = input.split(".")[-1].strip()
    move_idx = input.split(".")[-2].split(" ")[-1].strip()
    # If the last move is an empty string, then the last move is white; otherwise, it is black
    if last_move == "":
        last_move = "White"
    else:
        last_move = "Black"
    next_move_idx = str(int(move_idx) + 1)
    if not (next_move_idx in output):
        if target in output or (target[1] == 'x' and (target[0] + target[2:]) in output):
            return True
    else:
        output = output.split(next_move_idx)[0].strip()
        if target in output or (target[1] == 'x' and (target[0] + target[2:]) in output):
            return True
    return False


def eval_equation_balancer(input: str, output: str, target: str) -> bool:
    """
    Evaluate if the output is a valid equation balancer.
    """
    output = output.split("=")[0].strip()
    target_val = target.split("=")[1].strip()
    target = target.split("=")[0].strip()
    # First make sure that the output has the same format as the target (when operators (e.g., +, -, *, /) are removed)
    output_nums = output.replace("+", "").replace("-", "").replace("*", "").replace("/", "").replace(" ", "").strip()
    target_nums = target.replace("+", "").replace("-", "").replace("*", "").replace("/", "").replace(" ", "").strip()
    if output_nums != target_nums:
        return False
    # Now, evaluate the output and target
    try:
        output_value = eval(output)
        if abs(output_value - eval(target_val)) < 1e-6:
            return True
    except Exception as e:
        return False
    return False


def eval_for_multiple_choice(input_text: str, final_answer: str, target: str) -> bool:
    """
    Evaluates if the final answer matches the target using pattern matching.
    
    Args:
        input_text (str): The original question text including options
        final_answer (str): The model's answer
        target (str): The correct answer
    
    Returns:
        bool: True if answer is correct, False otherwise
    """
    # Handle empty or None inputs
    if not final_answer or not target:
        return False
    
    def clean_text(text: str) -> str:
        if not text:
            return ""
        return text.lower().strip().replace('`', '').replace('(', '').replace(')', '')
    
    def extract_option_text(input_text: str, option_letter: str) -> str:
        try:
            # Try different formats of options sections
            options_section = ""
            if 'options:' in input_text.lower():
                options_section = input_text.lower().split('options:')[1].strip()
            elif 'choices:' in input_text.lower():
                options_section = input_text.lower().split('choices:')[1].strip()
            
            if not options_section:
                # Try to find options in the format (A) text, (B) text
                lines = input_text.lower().split('\n')
                for i, line in enumerate(lines):
                    if line.strip().startswith(f'({option_letter})') or line.strip().startswith(f'{option_letter})'):
                        return line.split(')', 1)[1].strip()
                
            # Process the options section if found
            for line in options_section.split('\n'):
                line = line.strip()
                if line.startswith(f'({option_letter})') or line.startswith(f'{option_letter})'):
                    return line.split(')', 1)[1].strip()
                # Handle options like "A. text" format
                if line.startswith(f'{option_letter}.'):
                    return line.split('.', 1)[1].strip()
        except:
            return ''
        return ''

    # Full option match (A), (B), etc. (e.g., (A) == (A))
    if final_answer == target:
        return True

    # Clean and normalize inputs
    clean_answer = clean_text(final_answer)
    clean_target = clean_text(target)
    
    # Handle target formats: (A), A), A, etc.
    target_letter = ""
    if len(clean_target) == 1:
        target_letter = clean_target
    elif clean_target.endswith(')'):
        target_letter = clean_target[-2]
    else:
        # Extract the last character if it's a letter a-d or A-D
        last_char = clean_target[-1]
        if last_char in 'abcd':
            target_letter = last_char
    
    # Direct letter match (a, b, c, d)
    if len(clean_answer) == 1 and clean_answer in 'abcd' and clean_answer == target_letter:
        return True
    
    # Handle answer formats like "A" or "A."
    if clean_answer.startswith(target_letter) and (len(clean_answer) == 1 or 
                                                  (len(clean_answer) == 2 and clean_answer[1] == '.')):
        return True
    
    # Handle answer formats like "Option A" or "Answer is A"
    if clean_answer.endswith(target_letter) and (clean_answer[-2:] == f" {target_letter}" or 
                                               clean_answer[-3:] == f" {target_letter}."):
        return True
    
    # Text content match - check if the target option text is in the answer
    target_text = extract_option_text(input_text, target_letter)
    
    if target_text and target_text in clean_answer:
        return True
    
    # Handle numerical answers (if target is a number and answer contains that number)
    if target_letter.isdigit() and target_letter in clean_answer:
        return True
        
    return False


def eval_for_pyton_programming_puzzles(input: str, output: str) -> bool:
    """
    Evaluate if the output is a valid Python programming puzzle solution.
    """
    if "```python" in output:
        output = output.split("```python")[-1].strip()
        output = output.split("```")[0].strip()

    if "def sat" in output:
        if "from typing" not in output:
            output = f"from typing import *\n{output}"
        code = f"{output}\nanswer = solution()\nprint(sat(answer))"
    else:
        code = f"from typing import *\n{input}\n{output}\nanswer = solution()\nprint(sat(answer))"
    
    code = code.replace("List[", "list[")
    eval_bool = execute_code_with_timeout(code, timeout=3)

    if "NameError: name 'answer' is not defined" in eval_bool:
        print(f"Eval bool: {eval_bool}")
        print(f"Code:\n{code}")
        print("*" * 100)
    if "True" in eval_bool:
        return True
    return False