|
|
|
import re |
|
|
|
|
|
import signal |
|
import sys |
|
|
|
|
|
from io import StringIO |
|
from typing import List, Tuple |
|
|
|
|
|
from unittest.mock import patch, mock_open |
|
|
|
import numpy as np |
|
from pyext import RuntimeModule |
|
from wrapt_timeout_decorator import timeout as wrapt_timeout |
|
import threading |
|
|
|
from src.datasets.schema import assert_test_format_codeforces |
|
|
|
import logging |
|
|
|
# Module-level logger. Named after the module (instead of grabbing the root
# logger) so records can be filtered per module; records still propagate to
# whatever handlers are configured on the root logger.
log = logging.getLogger(__name__)

# Held by evaluate_solution_for_problem for the whole evaluation, so at most
# one evaluation runs at a time within a process.
lock = threading.Lock()
|
|
|
|
|
def evaluate_solution_for_problem(
    candidate_solution,
    hidden_tests_io=None,
    public_tests_io=None,
    timeout=10,
    debug=False,
    add_extra_imports=False,
    allow_truncated_io=False,
):
    """Run a candidate solution against the hidden and the public tests.

    See the readme for the output format of this function.

    Args:
        candidate_solution: Source code of the solution, or ``None`` when no
            code was provided.
        hidden_tests_io: Sequence of ``(input, expected_output)`` pairs;
            ``None`` is treated as "no tests".
        public_tests_io: Same shape as ``hidden_tests_io``.
        timeout: Time budget (seconds) for running both test sets together;
            also forwarded to ``check_correctness``.
        debug: Forwarded to ``check_correctness``.
        add_extra_imports: Forwarded to ``check_correctness``.
        allow_truncated_io: Forwarded to ``check_correctness``.

    Returns:
        A dict with the keys ``compilation_status``,
        ``compilation_error_message``, ``timeout_error``,
        ``hidden_tests_results`` and ``public_tests_results``. The two
        ``*_tests_results`` entries are lists with one record per test, each
        holding ``status``, ``error_message``, ``generated_output``, ``input``
        and ``expected_output``.
    """
    with lock:
        if hidden_tests_io is None:
            hidden_tests_io = []
        if public_tests_io is None:
            public_tests_io = []

        if candidate_solution is None:
            no_code_message = "No code was provided."
            return {
                "compilation_status": False,
                "compilation_error_message": no_code_message,
                "timeout_error": False,
                "hidden_tests_results": _failed_test_records(hidden_tests_io, no_code_message),
                "public_tests_results": _failed_test_records(public_tests_io, no_code_message),
            }

        # Both test sets share a single time budget enforced by the decorator.
        @wrapt_timeout(timeout, use_signals=False)
        def run_tests():
            hidden = check_correctness(
                candidate_solution, hidden_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
            )
            public = check_correctness(
                candidate_solution, public_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
            )
            return hidden, public

        try:
            hidden_tests_results, public_tests_results = run_tests()
            timeout_error_occurred = False
        except BaseException as e:
            # NOTE: every failure of the guarded run (not only a real timeout)
            # ends up here and is reported as a timeout, with the compilation
            # status set to True. Kept as-is for output compatibility.
            log.info("running the tests failed or timed out: %r", e)
            timeout_error_occurred = True
            timeout_message = "Timeout error."
            hidden_tests_results = {
                "compilation_status": True,
                "error_message": timeout_message,
                "results": _failed_test_records(hidden_tests_io, timeout_message),
            }
            public_tests_results = {
                "compilation_status": True,
                "results": _failed_test_records(public_tests_io, timeout_message),
            }

        # The same code was compiled for both sets, so the verdicts must agree.
        assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]

        return {
            "compilation_status": hidden_tests_results["compilation_status"],
            "compilation_error_message": hidden_tests_results["error_message"],
            "timeout_error": timeout_error_occurred,
            "hidden_tests_results": hidden_tests_results["results"],
            "public_tests_results": public_tests_results["results"],
        }


def _failed_test_records(tests_io, error_message):
    """Build one failing result record per ``(input, expected_output)`` test."""
    return [
        {
            "status": False,
            "error_message": error_message,
            "generated_output": None,
            "input": test[0],
            "expected_output": test[1],
        }
        for test in tests_io
    ]
|
|
|
|
|
def check_correctness(
    candidate_solution: str,
    tests: List[Tuple[List[str], str]],
    timeout: int = 6000,
    debug=True,
    add_extra_imports=False,
    allow_truncated_io=True,
):
    """Validate the tests, run the solution on them and sanity-check the results.

    The tests are checked with ``assert_test_format_codeforces``, split into
    inputs and expected outputs, and handed to ``run_test`` together with the
    remaining arguments.

    Returns:
        A dict with ``compilation_status`` (True when ``run_test`` reported an
        empty compilation error), ``error_message`` (the compilation error, or
        ``None`` when compilation succeeded) and ``results`` (the per-test
        records produced by ``run_test``).
    """
    assert_test_format_codeforces(tests)

    # zip(*tests) of an empty sequence cannot be unpacked into two names,
    # so the empty case is handled explicitly.
    if len(tests) == 0:
        inputs, outputs = [], []
    else:
        inputs, outputs = zip(*tests)

    compilation_error, results = run_test(
        candidate_solution, inputs, outputs, timeout, debug, add_extra_imports, allow_truncated_io
    )

    # One record per test, each with the expected field types.
    assert len(results) == len(inputs)
    for record in results:
        assert record["generated_output"] is None or isinstance(record["generated_output"], str)
        assert isinstance(record["status"], bool)
        assert record["error_message"] is None or isinstance(record["error_message"], str)
        assert isinstance(record["input"], list)
        assert isinstance(record["expected_output"], str)

    compiled = compilation_error == ""
    return {
        "compilation_status": compiled,
        "error_message": None if compiled else compilation_error,
        "results": results,
    }
|
|
|
|
|
class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when the alarm signal goes off."""
|
|
|
|
|
def timeout_handler(signum, frame):
    """Signal handler: log that the alarm fired, then raise ``TimeoutException``."""
    log.info("alarm went off")
    raise TimeoutException()


# Installed at import time so that alarms armed with signal.alarm() surface
# as TimeoutException in the running code.
signal.signal(signal.SIGALRM, timeout_handler)
|
|
|
|
|
|
|
|
|
|
|
class Capturing(list):
    """Context manager that captures everything written to ``sys.stdout``.

    While the block is active ``sys.stdout`` is an in-memory ``StringIO``.
    On exit the captured text is split into lines and appended to this list,
    and the previous ``sys.stdout`` is put back.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # The code under test may call sys.stdout.close(); make that a no-op
        # so the captured text is still readable on exit. The stub is stored
        # on the instance (hence never bound), so it must accept being called
        # without arguments.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Restore stdout even if reading the buffer failed.
            del self._stringio
            sys.stdout = self._stdout
|
|
|
|
|
def run_test(code, inputs, outputs, timeout: int = 6000, debug=True, add_extra_imports=False, allow_truncated_io=True):
    """Compile the solution and run it on every test, comparing the outputs.

    The scraped testcases may be incomplete. If ``allow_truncated_io`` is
    True and a test's input or expected output ends with ``"..."``, an
    ``EOFError``/``ValueError`` raised by the solution is ignored and only
    the available prefix of the output is compared.

    Args:
        code: The solution, as a list of lines or a string with ``\\n``
            line breaks.
        inputs: Test inputs, parallel to ``outputs``.
        outputs: Expected outputs (strings).
        timeout: Seconds for the alarm armed while the solution module is
            being built.
        debug: Log the assembled source and any errors.
        add_extra_imports: Prepend a block of commonly used imports to the
            solution.
        allow_truncated_io: See above.

    Returns:
        ``(compilation_error, results)``. ``compilation_error`` is ``""``
        when the module was built, otherwise the repr of the error.
        ``results`` has one dict per test with the keys ``input``,
        ``expected_output``, ``status``, ``generated_output`` and
        ``error_message``.

    Raises:
        AssertionError: If ``code`` is neither a list nor a string.
    """
    if isinstance(code, list):
        source_lines = code
    elif isinstance(code, str):
        source_lines = code.split("\n")
    else:
        raise AssertionError("code must be provided as list of lines or string with \\n linebreaks.")

    sol = _build_solution_source(source_lines, add_extra_imports)
    if debug:
        log.info(f"sol = {sol}")

    method_name = "code"

    # Building the module executes the solution's top-level imports, so it is
    # guarded by the alarm.
    signal.alarm(timeout)
    sol_module = None
    try:
        sol_module = RuntimeModule.from_string("tmp_sol", "", sol)
        signal.alarm(0)
    except Exception as e:
        signal.alarm(0)
        if debug:
            log.info(f"type 1 compilation error = {e}")
        return repr(e), _error_results(inputs, outputs, repr(e))

    assert sol_module is not None
    signal.alarm(0)

    try:
        method = getattr(sol_module, method_name)
    except Exception as e:
        signal.alarm(0)
        log.info(f"unable to get function error = {e}")
        return repr(e), _error_results(inputs, outputs, repr(e))

    results = []
    for test_input, reference_output in zip(inputs, outputs):
        # The record always reports the test as it was given, even when a
        # truncated input is shortened below.
        result_object = {
            "input": test_input,
            "expected_output": reference_output,
        }

        input_truncated = False
        if "".join(test_input).strip().endswith("...") and allow_truncated_io:
            # Drop the trailing "..." marker line before feeding the input.
            test_input = test_input[:-1]
            input_truncated = True

        error_code = None
        with Capturing() as generated_output:
            # NOTE: no alarm is armed in this loop; the alarm(0) calls only
            # make sure none is left pending.
            try:
                call_method(method, test_input)
                signal.alarm(0)
            except Exception as e:
                signal.alarm(0)
                error_code = e
                if debug:
                    log.info(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
            signal.alarm(0)

        truncated_case = (
            (input_truncated or reference_output.strip().endswith("..."))
            and allow_truncated_io
            and (error_code is None or isinstance(error_code, (EOFError, ValueError)))
        )

        if truncated_case:
            # The last generated line may be incomplete, so it is dropped.
            generated_output = generated_output[:-1]
            # NOTE: rstrip("...") strips every trailing "." character, not
            # just one "..." marker.
            reference_output = reference_output.rstrip("...")
            if len(generated_output) == 0:
                # Nothing left to compare: counted as a pass.
                status = True
            else:
                status = string_compare(generated_output, reference_output, True)
            result_object.update(
                status=status,
                generated_output="\n".join(generated_output),
                error_message=None,
            )
        elif error_code is not None:
            result_object.update(status=False, generated_output=None, error_message=repr(error_code))
        else:
            result_object.update(
                status=string_compare(generated_output, reference_output, False),
                generated_output="\n".join(generated_output),
                error_message=None,
            )
        results.append(result_object)

    return "", results


def _build_solution_source(source_lines, add_extra_imports):
    """Wrap the solution body in ``def code():`` and hoist its imports to module level."""
    import_lines = []
    future_import_lines = []
    code_lines = []
    for line in source_lines:
        if not line.startswith(("from ", "import ")):
            # Everything that is not a column-0 import becomes the body of code().
            code_lines.append("\t" + line + "\n")
        elif "__future__" in line:
            # __future__ imports must stay at the very top of the module.
            future_import_lines.append(line + "\n")
        else:
            import_lines.append(line + "\n")

    wrapped_body = "".join(
        [
            "stdin = sys.stdin\nstdout = sys.stdout\n",
            '__name__="__main__"\n',
            "def code():\n",
            *code_lines,
        ]
    )

    sol = "\n".join(future_import_lines)
    sol += "import sys\n"
    if add_extra_imports:
        sol += (
            "import time\n"
            "import itertools\n"
            "from itertools import accumulate, product, permutations, combinations\n"
            "import collections\n"
            "from collections import Counter, OrderedDict, deque, defaultdict, ChainMap\n"
            "from functools import lru_cache\n"
            "import math\n"
            "from math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\n"
            "import fractions\n"
            "from typing import List, Tuple\n"
            "import numpy as np\n"
            "import random\n"
            "import heapq\n"
            "from heapq import *\n"
        )
    sol += "\n".join(import_lines) + "\n" + wrapped_body
    return sol


def _error_results(inputs, outputs, error_message):
    """Build one failing record per test, all carrying the same error message."""
    return [
        {
            "status": False,
            "input": inp,
            "expected_output": out,
            "generated_output": None,
            "error_message": error_message,
        }
        for inp, out in zip(inputs, outputs)
    ]
|
|
|
|
|
def string_compare(candidate, correct, truncate_output=False, floating_point_accuracy=0.01):
    """Compare generated output lines with the expected output, token by token.

    Both sides are lower-cased, whitespace is normalised and the text is
    split into space-separated tokens. Two tokens match when they are equal
    as strings, equal as integers, or when both parse as floats that differ
    by less than ``floating_point_accuracy``.

    Args:
        candidate: Iterable of generated output lines.
        correct: Expected output as a single string.
        truncate_output: When True the expected output is treated as cut
            off: its last token is ignored and the token counts need not
            match (only the common prefix is compared).
        floating_point_accuracy: Absolute tolerance for float tokens.

    Returns:
        True when all compared tokens match, False otherwise.
    """
    whitespace = re.compile(r"\s+")

    candidate_text = "\n".join(line.strip().lower() for line in candidate)
    candidate_tokens = whitespace.sub(" ", candidate_text).strip().split(" ")
    correct_tokens = whitespace.sub(" ", correct.strip().lower()).strip().split(" ")

    if truncate_output:
        # The last expected token may be incomplete (or the "..." marker).
        correct_tokens = correct_tokens[:-1]
    elif len(candidate_tokens) != len(correct_tokens):
        return False

    for left, right in zip(candidate_tokens, correct_tokens):
        if left == right:
            continue

        # Integers written differently, e.g. "01" vs "1".
        try:
            if int(left) == int(right):
                continue
        except ValueError:
            pass

        # Floats within the absolute tolerance.
        try:
            if abs(float(left) - float(right)) < floating_point_accuracy:
                continue
        except ValueError:
            pass

        return False

    return True
|
|
|
|
|
def call_method(method, inputs):
    """Call ``method`` with stdin and ``open`` replaced by in-memory stand-ins.

    ``inputs`` may be a string or a list of lines (joined with newlines).
    For the duration of the call ``builtins.open`` and ``sys.stdin`` (along
    with the ``readline``/``readlines``/``read`` attributes named below) are
    patched with objects built from ``inputs``.

    Returns:
        Whatever ``method`` returns, or ``None`` when it exits through
        ``SystemExit``.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    remaining_lines = iter(inputs.split("\n"))

    # The decorator order is significant and deliberately left untouched.
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(remaining_lines))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _run_patched(target):
        try:
            return target()
        except SystemExit:
            # A solution calling exit() just ends that test run.
            return None

    return _run_patched(method)
|
|