File size: 15,291 Bytes
3c7fd6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a44c20f
3c7fd6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
# This is based heavily on the huggingface APPS metric
import re

# to run the solution files we're using a timing based approach
import signal
import sys

# for capturing the stdout
from io import StringIO
from typing import List, Tuple

# used for testing the code that reads from input
from unittest.mock import patch, mock_open

import numpy as np
from pyext import RuntimeModule
from wrapt_timeout_decorator import timeout as wrapt_timeout
import threading

from ..datasets.schema import assert_test_format_codeforces

from aiflows import logging

# Module-level logger obtained from the aiflows logging wrapper.
log = logging.get_logger(__name__)
# Held by evaluate_solution_for_problem for the whole duration of a test run,
# so that evaluations started from different threads execute one at a time.
lock = threading.Lock()


def _failed_test_results(tests_io, error_message):
    """Build one failing result entry per test, all carrying the same `error_message`."""
    return [
        {
            "status": False,
            "error_message": error_message,
            "generated_output": None,
            "input": test[0],
            "expected_output": test[1],
        }
        for test in tests_io
    ]


def evaluate_solution_for_problem(
    candidate_solution,
    hidden_tests_io=None,
    public_tests_io=None,
    timeout=10,
    debug=False,
    add_extra_imports=False,
    allow_truncated_io=False,
):
    """Evaluate a candidate solution on the hidden and the public tests of a problem.

    See the readme for the output format of this function.

    Args:
        candidate_solution: the solution source code, or None when no code is available.
        hidden_tests_io: list of (input, expected_output) pairs; defaults to no tests.
        public_tests_io: list of (input, expected_output) pairs; defaults to no tests.
        timeout: time limit handed both to the timeout decorator that wraps the
            whole run (hidden + public tests) and to `check_correctness`.
        debug: forwarded to `check_correctness`.
        add_extra_imports: forwarded to `check_correctness`.
        allow_truncated_io: forwarded to `check_correctness`.

    Returns:
        A dict with the keys "compilation_status", "compilation_error_message",
        "timeout_error", "hidden_tests_results" and "public_tests_results".
        When `candidate_solution` is None, or when the run raises, every test is
        reported as failed with the corresponding message.
    """
    if hidden_tests_io is None:
        hidden_tests_io = []
    if public_tests_io is None:
        public_tests_io = []

    if candidate_solution is None:
        no_code_message = "No code was provided."
        return {
            "compilation_status": False,
            "compilation_error_message": no_code_message,
            "timeout_error": False,
            "hidden_tests_results": _failed_test_results(hidden_tests_io, no_code_message),
            "public_tests_results": _failed_test_results(public_tests_io, no_code_message),
        }

    @wrapt_timeout(timeout, use_signals=False)
    def run_tests():
        hidden_tests_results = check_correctness(
            candidate_solution, hidden_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
        )
        public_tests_results = check_correctness(
            candidate_solution, public_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
        )

        return hidden_tests_results, public_tests_results

    try:
        # The context manager releases the lock on every exit path and never
        # tries to release a lock that was not successfully acquired.
        with lock:
            hidden_tests_results, public_tests_results = run_tests()
    except BaseException as e:
        # NOTE(review): anything raised by the run (not only a timeout) is reported
        # as a timeout, and BaseException also covers KeyboardInterrupt/SystemExit.
        # Kept as-is because callers may rely on this best-effort behaviour.
        log.info(e)
        timeout_message = "Timeout error."
        return {
            "compilation_status": True,
            "compilation_error_message": timeout_message,
            "timeout_error": True,
            "hidden_tests_results": _failed_test_results(hidden_tests_io, timeout_message),
            "public_tests_results": _failed_test_results(public_tests_io, timeout_message),
        }

    # the compilation status shouldn't depend on the tests
    assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]

    return {
        "compilation_status": hidden_tests_results["compilation_status"],
        "compilation_error_message": hidden_tests_results["error_message"],
        "timeout_error": False,
        "hidden_tests_results": hidden_tests_results["results"],
        "public_tests_results": public_tests_results["results"],
    }


def check_correctness(
    candidate_solution: str,
    tests: List[Tuple[List[str], str]],
    timeout: int = 6000,
    debug=True,
    add_extra_imports=False,
    allow_truncated_io=True,
):
    """
    Validate the test format, run the candidate on every test through `run_test`
    and package the outcome (based on huggingface code).

    Returns a dict with
      - "compilation_status": True when `run_test` reported an empty compilation error
      - "error_message": the compilation error text, or None when compilation succeeded
      - "results": the per-test result dicts produced by `run_test`
    """
    assert_test_format_codeforces(tests)

    # split the (input, expected_output) pairs into two parallel sequences
    if len(tests) > 0:
        inputs, outputs = zip(*tests)
    else:
        inputs, outputs = [], []

    compilation_error, results = run_test(
        candidate_solution,
        inputs,
        outputs,
        timeout,
        debug,
        add_extra_imports,
        allow_truncated_io,
    )

    # sanity-check the shape of what run_test handed back
    assert len(results) == len(inputs)
    for entry in results:
        assert entry["generated_output"] is None or isinstance(entry["generated_output"], str)
        assert isinstance(entry["status"], bool)
        assert entry["error_message"] is None or isinstance(entry["error_message"], str)
        assert isinstance(entry["input"], list)
        assert isinstance(entry["expected_output"], str)

    compiled = compilation_error == ""
    return {
        "compilation_status": compiled,
        "error_message": None if compiled else compilation_error,
        "results": results,
    }


class TimeoutException(Exception):
    """Raised by `timeout_handler` when the SIGALRM alarm goes off."""


def timeout_handler(signum, frame):
    """Signal handler: log that the alarm fired, then raise TimeoutException.

    `signum` and `frame` are the arguments every signal handler receives;
    neither is used here.
    """
    log.info("Alarm went off")
    raise TimeoutException()


# Register the handler when this module is imported, so that an alarm armed with
# signal.alarm() (see run_test) is delivered to timeout_handler, which raises
# TimeoutException.
signal.signal(signal.SIGALRM, timeout_handler)


# used to capture stdout as a list
# from https://stackoverflow.com/a/16571630/6416660
# alternative use redirect_stdout() from contextlib
class Capturing(list):
    """Context manager that collects everything written to ``sys.stdout`` as a list of lines.

    On entry ``sys.stdout`` is swapped for an in-memory ``StringIO``. On exit the
    captured text is split into lines and appended to this list, and the previous
    ``sys.stdout`` is put back.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op, so that code under capture which calls
        # sys.stdout.close() cannot discard what it already printed. The stub is an
        # instance attribute and is therefore called WITHOUT `self`: it has to
        # accept zero arguments, otherwise close() raises TypeError.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # always put the real stdout back, even if reading the buffer failed
            del self._stringio  # free up some memory
            sys.stdout = self._stdout


def run_test(code, inputs, outputs, timeout: int = 10, debug=True, add_extra_imports=False, allow_truncated_io=True):
    """
    Runs the code and tries to match inputs and outputs.

    The solution is re-assembled into a module in which its import lines stay at
    module level and all remaining lines become the body of a function `code()`.
    That function is then called once per test through `call_method`, with stdout
    captured by `Capturing`, and the captured output is compared with the expected
    output by `string_compare`.

    The scraped testcases may be incomplete (marked by a trailing "...").
    If allow_truncated_io==True, then for such tests an EOFError or ValueError
    raised by the solution is ignored, the last line of the generated output is
    dropped and only the leading tokens are compared.

    Args:
        code: the solution, either a newline-separated string or a list of lines.
        inputs: sequence of test inputs, one per test.
        outputs: sequence of expected output strings, parallel to `inputs`.
        timeout: seconds for the SIGALRM alarm armed while the module is built.
        debug: when True, log the assembled source and the errors encountered.
        add_extra_imports: when True, prepend a fixed set of commonly used imports.
        allow_truncated_io: tolerate truncated tests as described above.

    Returns:
        A pair (compilation_error, results). `compilation_error` is "" when the
        module could be built, otherwise the repr of the exception. `results`
        holds one dict per test with the keys "input", "expected_output",
        "status", "generated_output" and "error_message".

    Raises:
        AssertionError: if `code` is neither a list nor a str.
    """
    # NOTE(review): nothing in this function sandboxes the solution; it is executed
    # with the full privileges of the current process.

    def _all_tests_failed(error_message):
        # the solution could not even be loaded: every test fails with the same message
        return [
            {
                "status": False,
                "input": inp,
                "expected_output": out,
                "generated_output": None,
                "error_message": error_message,
            }
            for inp, out in zip(inputs, outputs)
        ]

    if isinstance(code, list):
        source_lines = code
    elif isinstance(code, str):
        source_lines = code.split("\n")
    else:
        raise AssertionError("code must be provided as list of lines or string with \\n linebreaks.")

    # parse the code into code and imports; only imports that start at column 0
    # are hoisted, everything else is indented to become the body of code()
    import_lines = []
    future_import_lines = []
    code_lines = []
    for line in source_lines:
        if line.startswith(("from ", "import ")):
            if "__future__" in line:
                future_import_lines.append(line + "\n")
            else:
                import_lines.append(line + "\n")
        else:
            code_lines.append("\t" + line + "\n")

    # assemble a new solution snippet which wraps the generated solution in a function code()
    new_test = (
        "stdin = sys.stdin\nstdout = sys.stdout\n"
        '__name__="__main__"\n'
        "def code():\n"
        "\tstdin = sys.stdin\n\tstdout = sys.stdout\n"
    ) + "".join(code_lines)

    # __future__ imports have to come first in the generated module
    sol = "\n".join(future_import_lines)
    sol += "import sys\n"
    if add_extra_imports:
        sol += (
            "import time\n"
            "import itertools\n"
            "from itertools import accumulate, product, permutations, combinations\n"
            "import collections\n"
            "from collections import Counter, OrderedDict, deque, defaultdict, ChainMap\n"
            "from functools import lru_cache\n"
            "import math\n"
            "from math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\n"
            "import fractions\n"
            "from typing import List, Tuple\n"
            "import numpy as np\n"
            "import random\n"
            "import heapq\n"
            "from heapq import *\n"
        )
    sol += "\n".join(import_lines) + "\n" + new_test

    if debug:
        log.info(f"sol = {sol}")
    method_name = "code"
    signal.alarm(timeout)

    # convert the solution snippet into a pyext runtime module
    sol_module = None
    try:
        sol_module = RuntimeModule.from_string("tmp_sol", "", sol)
        signal.alarm(0)
    except Exception as e:
        signal.alarm(0)
        if debug:
            log.info(f"type 1 compilation error = {e}")
        error_message = repr(e)
        # consider all inputs failed
        return error_message, _all_tests_failed(error_message)

    assert sol_module is not None
    signal.alarm(0)

    try:
        method = getattr(sol_module, method_name)  # get_attr second arg must be str
    except Exception as e:
        signal.alarm(0)
        log.info(f"unable to get function error = {e}")
        # report the exception itself (as the compilation branch does) rather than
        # the repr of a sys.exc_info() tuple, which embeds a traceback address
        error_message = repr(e)
        # consider all inputs failed
        return error_message, _all_tests_failed(error_message)

    # go through all tests, call our runtime module with the inputs
    # then compare with the reference output
    results = []
    for test_input, reference_output in zip(inputs, outputs):
        # the report always carries the test exactly as it was handed in
        result_object = {
            "input": test_input,
            "expected_output": reference_output,
        }

        # if the last token of the input is truncated and marked with "..." we delete it
        input_truncated = False
        if "".join(test_input).strip().endswith("...") and allow_truncated_io:
            test_input = test_input[:-1]
            input_truncated = True

        # NOTE(review): no alarm is armed around the call below; the alarm(0) calls
        # only cancel. Time limits for the tests themselves appear to rely on the
        # timeout wrapper used by the caller -- confirm this is intended.
        error_code = None
        with Capturing() as generated_output:
            try:
                call_method(method, test_input)
                # reset the alarm
                signal.alarm(0)
            except Exception as e:
                # runtime error or took too long
                signal.alarm(0)
                error_code = e
                if debug:
                    log.info(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
            signal.alarm(0)

        # in some cases we run into truncated tests
        # in such cases we expect the error code to be None, EOFError or ValueError
        truncated = input_truncated or reference_output.strip().endswith("...")
        error_tolerated = error_code is None or isinstance(error_code, (EOFError, ValueError))

        if truncated and allow_truncated_io and error_tolerated:
            # the last generated line may stem from the cut-off part of the test
            generated_output = generated_output[:-1]
            # str.rstrip takes a character set, so this removes every trailing dot
            reference_output = reference_output.rstrip("...")
            # no output left means we pass by default
            status = len(generated_output) == 0 or string_compare(generated_output, reference_output, True)
            generated_text = "\n".join(generated_output)
            error_message = None
        elif error_code is not None:
            # if the input and output are not truncated, we don't allow any errors
            status = False
            generated_text = None
            error_message = repr(error_code)
        else:
            # the execution went well, let's compare the outputs
            status = string_compare(generated_output, reference_output, False)
            generated_text = "\n".join(generated_output)
            error_message = None

        result_object.update(status=status, generated_output=generated_text, error_message=error_message)
        results.append(result_object)

    return "", results


def string_compare(candidate, correct, truncate_output=False, floating_point_accuracy=0.01):
    """Compare generated output with the expected output, token by token.

    Both sides are lower-cased and split on whitespace. Two tokens match when
    they are equal as strings, when both parse as integers with the same value,
    or when both parse as floats that differ by less than
    `floating_point_accuracy` (an absolute tolerance). Two integer tokens are
    never compared as floats, because the conversion would silently equate
    distinct large integers.

    Args:
        candidate: iterable of generated output lines.
        correct: the expected output as a single string.
        truncate_output: when True the expected output is treated as cut off:
            its last token is ignored and the token counts need not match, so
            only the common leading tokens are compared.
        floating_point_accuracy: absolute tolerance for float tokens.

    Returns:
        True when every compared token pair matches, False otherwise.
    """
    # normalize case and whitespace, then split into individual tokens
    candidate_text = "\n".join(line.strip().lower() for line in candidate)
    candidate_tokens = re.sub(r"\s+", " ", candidate_text).strip().split(" ")
    correct_tokens = re.sub(r"\s+", " ", correct.strip().lower()).strip().split(" ")

    if truncate_output:
        # the last token of a truncated output may have been corrupted
        correct_tokens = correct_tokens[:-1]
    elif len(candidate_tokens) != len(correct_tokens):
        return False

    # zip stops at the shorter list, which is what we want for truncated outputs
    for left, right in zip(candidate_tokens, correct_tokens):
        if left == right:
            continue

        try:
            left_int, right_int = int(left), int(right)
        except ValueError:
            pass
        else:
            if left_int == right_int:
                continue
            # both are integers and they differ: the verdict is final, a float
            # comparison would lose precision above 2**53
            return False

        try:
            if abs(float(left) - float(right)) < floating_point_accuracy:
                continue
        except ValueError:
            pass

        return False

    return True


def call_method(method, inputs):
    """Call `method()` with `inputs` served as its standard input.

    `inputs` is either the full input text or a list of lines, which is joined
    with newlines. While `method` runs, `sys.stdin` is a StringIO over that text
    and `builtins.open` is a `mock_open` whose read data is the same text.

    Returns whatever `method` returns, or None when it raises SystemExit.
    Any other exception propagates to the caller.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    line_source = iter(inputs.split("\n"))

    # The patches are entered top to bottom, so the three attribute patches are
    # entered before sys.stdin itself is swapped.
    # NOTE(review): because of that order they seem to be set on the previous
    # stdin object rather than on the StringIO the solution reads from -- confirm
    # whether that is intended before reordering.
    with patch("sys.stdin.read", lambda *args: inputs):
        with patch("sys.stdin.readlines", lambda *args: inputs.split("\n")):
            with patch("sys.stdin.readline", lambda *args: next(line_source)):
                with patch("sys.stdin", StringIO(inputs)):
                    with patch("builtins.open", mock_open(read_data=inputs)):
                        try:
                            return method()
                        except SystemExit:
                            # a solution calling exit() simply ends its run
                            return None