| |
| import re |
|
|
| |
| import signal |
| import sys |
|
|
| |
| from io import StringIO |
| from typing import List, Tuple |
|
|
| |
| from unittest.mock import patch, mock_open |
|
|
| import numpy as np |
| from pyext import RuntimeModule |
| from wrapt_timeout_decorator import timeout as wrapt_timeout |
| import threading |
|
|
| from ..datasets.schema import assert_test_format_codeforces |
|
|
| from flows import logging |
|
|
# Module-level lock: evaluate_solution_for_problem holds it for the whole evaluation,
# so concurrent callers within this process are serialised.
lock = threading.Lock()
log = logging.get_logger(__name__)
|
|
|
|
def _failed_test_results(tests, message):
    """Return one failing result record per ``(input, expected_output)`` pair in *tests*."""
    return [
        {
            "status": False,
            "error_message": message,
            "generated_output": None,
            "input": test[0],
            "expected_output": test[1],
        }
        for test in tests
    ]


def evaluate_solution_for_problem(
    candidate_solution,
    hidden_tests_io=None,
    public_tests_io=None,
    timeout=10,
    debug=False,
    add_extra_imports=False,
    allow_truncated_io=False,
):
    """Evaluate a candidate program on a problem's hidden and public tests.

    See the readme for the output format of this function. In short, the returned dict has
    the keys "compilation_status", "compilation_error_message", "timeout_error",
    "hidden_tests_results" and "public_tests_results"; every entry of the two results lists
    carries "status", "error_message", "generated_output", "input" and "expected_output".

    The whole evaluation runs while holding the module-level ``lock``.

    Args:
        candidate_solution: the program source, or None when no code was produced.
        hidden_tests_io: list of ``(input, expected_output)`` pairs; None means no tests.
        public_tests_io: list of ``(input, expected_output)`` pairs; None means no tests.
        timeout: seconds allowed for running both test suites together (via
            ``wrapt_timeout``); also forwarded to ``check_correctness``.
        debug: forwarded to ``check_correctness``.
        add_extra_imports: forwarded to ``check_correctness``.
        allow_truncated_io: forwarded to ``check_correctness``.

    Returns:
        The results dict described above. Failures while running the suites are reported
        in the dict (all tests failing) rather than raised; "timeout_error" is True only
        when the run actually timed out.
    """
    with lock:
        if hidden_tests_io is None:
            hidden_tests_io = []
        if public_tests_io is None:
            public_tests_io = []

        if candidate_solution is None:
            no_code_message = "No code was provided."
            return {
                "compilation_status": False,
                "compilation_error_message": no_code_message,
                "timeout_error": False,
                "hidden_tests_results": _failed_test_results(hidden_tests_io, no_code_message),
                "public_tests_results": _failed_test_results(public_tests_io, no_code_message),
            }

        @wrapt_timeout(timeout, use_signals=False)
        def run_tests():
            hidden_tests_results = check_correctness(
                candidate_solution, hidden_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
            )
            public_tests_results = check_correctness(
                candidate_solution, public_tests_io, timeout, debug, add_extra_imports, allow_truncated_io
            )
            return hidden_tests_results, public_tests_results

        try:
            hidden_tests_results, public_tests_results = run_tests()
            timeout_error_occurred = False
        except BaseException as e:
            # Best effort: anything escaping the run (including exceptions raised by the
            # candidate program itself) marks every test as failed instead of propagating.
            # Only a genuine time-out is flagged as one: wrapt_timeout raises TimeoutError
            # by default when the budget expires, and TimeoutException comes from this
            # module's SIGALRM handler. Other errors keep their own description.
            log.info(e)
            timeout_error_occurred = isinstance(e, (TimeoutError, TimeoutException))
            error_message = "Timeout error." if timeout_error_occurred else repr(e)
            hidden_tests_results = {
                "compilation_status": True,
                "error_message": error_message,
                "results": _failed_test_results(hidden_tests_io, error_message),
            }
            public_tests_results = {
                "compilation_status": True,
                "error_message": error_message,
                "results": _failed_test_results(public_tests_io, error_message),
            }

        # Both suites compile the same program, so their compilation outcome must agree.
        assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]

        results_dict = {
            "compilation_status": hidden_tests_results["compilation_status"],
            "compilation_error_message": hidden_tests_results["error_message"],
            "timeout_error": timeout_error_occurred,
            "hidden_tests_results": hidden_tests_results["results"],
            "public_tests_results": public_tests_results["results"],
        }

        return results_dict
|
|
|
|
def check_correctness(
    candidate_solution: str,
    tests: List[Tuple[List[str], str]],
    timeout: int = 6000,
    debug=True,
    add_extra_imports=False,
    allow_truncated_io=True,
):
    """Run *candidate_solution* on *tests* and package the outcome (adapted from HuggingFace code).

    This function imposes no timeout of its own; *timeout* is only forwarded to ``run_test``.

    Args:
        candidate_solution: the program source.
        tests: ``(input_lines, expected_output)`` pairs, checked by
            ``assert_test_format_codeforces``.
        timeout: forwarded to ``run_test``.
        debug: forwarded to ``run_test``.
        add_extra_imports: forwarded to ``run_test``.
        allow_truncated_io: forwarded to ``run_test``.

    Returns:
        dict with "compilation_status" (True when run_test reported no compilation error),
        "error_message" (None on success, otherwise run_test's compilation error string) and
        "results" (run_test's per-test records).

    Raises:
        AssertionError: when run_test's results have the wrong count or field types
            (assert_test_format_codeforces may also reject malformed *tests*).
    """
    assert_test_format_codeforces(tests)
    if len(tests) == 0:
        inputs, outputs = [], []
    else:
        inputs, outputs = zip(*tests)

    compilation_error, results = run_test(
        candidate_solution, inputs, outputs, timeout, debug, add_extra_imports, allow_truncated_io
    )

    # Sanity-check the shape of every record before handing it back.
    assert len(results) == len(inputs)
    for record in results:
        generated = record["generated_output"]
        assert generated is None or isinstance(generated, str)
        assert isinstance(record["status"], bool)
        message = record["error_message"]
        assert message is None or isinstance(message, str)
        assert isinstance(record["input"], list)
        assert isinstance(record["expected_output"], str)

    compiled = compilation_error == ""
    return {
        "compilation_status": compiled,
        "error_message": None if compiled else compilation_error,
        "results": results,
    }
|
|
|
|
class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when a SIGALRM alarm fires."""
|
|
|
|
def timeout_handler(signum, frame):
    """Handle SIGALRM: log that the alarm fired, then raise TimeoutException.

    ``signum`` and ``frame`` are the standard signal-handler arguments; both are unused.
    """
    log.info("alarm went off")
    raise TimeoutException()
|
|
|
|
# Install timeout_handler for SIGALRM at import time, so an alarm armed with
# signal.alarm(...) (as run_test does while building the solution module) surfaces
# as a TimeoutException in the interrupted code.
# NOTE(review): signal.signal only works from the main thread and SIGALRM does not
# exist on Windows, so importing this module in those settings fails — confirm that
# is acceptable for all deployments.
signal.signal(signal.SIGALRM, timeout_handler)
|
|
|
|
| |
| |
| |
class Capturing(list):
    """Context manager that captures everything printed to ``sys.stdout``.

    While active, ``sys.stdout`` is replaced by a ``StringIO``. On exit the captured text
    is split into lines and appended to this list, so after
    ``with Capturing() as out: ...`` the variable ``out`` holds the printed lines.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Programs under test sometimes call sys.stdout.close(); turn that into a no-op
        # so the buffer is still readable in __exit__. The replacement must accept a
        # call with no arguments (a one-parameter lambda raises TypeError on close()).
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Always restore the real stdout, even if reading the buffer fails,
            # so later prints are not sent to a dead StringIO.
            del self._stringio
            sys.stdout = self._stdout
|
|
|
|
def _build_solution_source(source_lines, add_extra_imports):
    """Assemble module source that exposes the candidate program as a ``code()`` function.

    Lines starting with ``import``/``from`` are hoisted to module level (``__future__``
    imports first, as Python requires); every other line is indented by one tab so it
    becomes the body of ``code()``.
    """
    future_import_lines = []
    import_lines = []
    body_lines = []
    for line in source_lines:
        if line.startswith("from ") or line.startswith("import "):
            if "__future__" in line:
                future_import_lines.append(line + "\n")
            else:
                import_lines.append(line + "\n")
        else:
            body_lines.append("\t" + line + "\n")

    # Module-level stdin/stdout aliases, and a module-global __name__ so that
    # `if __name__ == "__main__":` checks inside code() see "__main__".
    wrapper = "stdin = sys.stdin\nstdout = sys.stdout\n"
    wrapper += '__name__="__main__"\n'
    wrapper += "def code():\n"
    wrapper += "".join(body_lines)

    sol = "\n".join(future_import_lines)
    sol += "import sys\n"
    if add_extra_imports:
        sol += "import time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
    sol += "\n".join(import_lines) + "\n" + wrapper
    return sol


def _failed_run_results(inputs, outputs, message):
    """Return one failing result per test, used when the program cannot be built or loaded."""
    return [
        {
            "status": False,
            "input": inp,
            "expected_output": out,
            "generated_output": None,
            "error_message": message,
        }
        for inp, out in zip(inputs, outputs)
    ]


def run_test(code, inputs, outputs, timeout: int = 6000, debug=True, add_extra_imports=False, allow_truncated_io=True):
    """Build *code* into a module and run it once per (input, expected output) pair.

    The scraped test cases may be incomplete. If ``allow_truncated_io`` is True, an input or
    expected output ending in "..." is treated as truncated. In that case an EOFError or
    ValueError raised by the program is ignored, and the last generated line is excluded
    from the comparison.

    Args:
        code: the program, as a string with "\\n" line breaks or as a list of lines.
        inputs: per-test program input (lines fed to stdin).
        outputs: per-test expected output strings.
        timeout: seconds passed to ``signal.alarm`` while the solution module is built.
        debug: log the generated source and runtime errors.
        add_extra_imports: prepend a fixed set of common imports to the program.
        allow_truncated_io: enable the truncated-I/O leniency described above.

    Returns:
        ``(compilation_error, results)``. ``compilation_error`` is "" on success, or the
        repr of the exception that prevented building/loading the program. ``results``
        holds one dict per test with keys "input", "expected_output", "status",
        "generated_output" and "error_message".

    Raises:
        AssertionError: if *code* is neither a list nor a str.
    """
    if isinstance(code, list):
        source_lines = code
    elif isinstance(code, str):
        source_lines = code.split("\n")
    else:
        raise AssertionError("code must be provided as list of lines or string with \\n linebreaks.")

    sol = _build_solution_source(source_lines, add_extra_imports)
    if debug:
        log.info(f"sol = {sol}")
    method_name = "code"

    # Guard module construction (which runs the hoisted imports) with the alarm.
    signal.alarm(timeout)
    try:
        sol_module = RuntimeModule.from_string("tmp_sol", "", sol)
        signal.alarm(0)
    except Exception as e:
        signal.alarm(0)
        if debug:
            log.info(f"type 1 compilation error = {e}")
        return repr(e), _failed_run_results(inputs, outputs, repr(e))

    try:
        method = getattr(sol_module, method_name)
    except Exception as e:
        # Narrow catch: KeyboardInterrupt/SystemExit are not swallowed here. The message is
        # the exception's repr (a sys.exc_info() tuple repr would embed a traceback address).
        log.info(f"unable to get function error = {e}")
        return repr(e), _failed_run_results(inputs, outputs, repr(e))

    results = []
    for test_input, reference_output in zip(inputs, outputs):
        result_object = {
            "input": test_input,
            "expected_output": reference_output,
        }

        # An input ending in "..." is incomplete: drop its (partial) last line.
        input_truncated = False
        if "".join(test_input).strip().endswith("...") and allow_truncated_io:
            test_input = test_input[:-1]
            input_truncated = True

        # NOTE(review): the only alarm armed in this function is the one around module
        # construction; nothing re-arms it per test. The alarm(0) calls below therefore only
        # cancel alarms the program itself may have set. Per-test time limits presumably
        # come from the caller's timeout wrapper — confirm.
        error_code = None
        with Capturing() as generated_output:
            try:
                call_method(method, test_input)
                signal.alarm(0)
            except Exception as e:
                signal.alarm(0)
                error_code = e
                if debug:
                    log.info(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
            signal.alarm(0)

        truncated_io = (input_truncated or reference_output.strip().endswith("...")) and allow_truncated_io
        if truncated_io and (error_code is None or isinstance(error_code, (EOFError, ValueError))):
            # Truncated data: EOFError/ValueError are tolerated (presumably from reading past
            # the cut-off input or parsing a cut-off value). The last generated line is
            # excluded because it may correspond to the missing part.
            generated_lines = generated_output[:-1]
            # str.rstrip takes a *set* of characters, so this removes every trailing dot.
            # Strip whitespace first so a trailing newline does not hide the "..." marker.
            reference_output = reference_output.rstrip().rstrip(".")
            if len(generated_lines) == 0:
                status = True
            else:
                status = string_compare(generated_lines, reference_output, True)
            result_object.update(
                status=status,
                generated_output="\n".join(generated_lines),
                error_message=None,
            )
        elif error_code is not None:
            result_object.update(status=False, generated_output=None, error_message=repr(error_code))
        else:
            result_object.update(
                status=string_compare(generated_output, reference_output, False),
                generated_output="\n".join(generated_output),
                error_message=None,
            )
        results.append(result_object)

    return "", results
|
|
|
|
def string_compare(candidate, correct, truncate_output=False, floating_point_accuracy=0.01):
    """Compare a program's printed lines against the expected output, token by token.

    Both sides are lower-cased, and every whitespace run (newlines included) is collapsed
    to a single space, so only the token sequence matters. Two tokens match if they are
    equal strings, equal integers, or floats within ``floating_point_accuracy`` of each
    other (absolute difference).

    Args:
        candidate: iterable of output lines produced by the program.
        correct: expected output as a single string.
        truncate_output: the expected output is known to be truncated. The token-count
            check is skipped, the last (possibly cut-off) expected token is ignored, and
            only the common prefix is compared.
        floating_point_accuracy: absolute tolerance for numeric tokens.

    Returns:
        True if every compared token pair matches, otherwise False.
    """
    candidate_text = "\n".join(line.strip().lower() for line in candidate)
    # Raw-string patterns: "\s" in a plain string literal is an invalid escape sequence.
    candidate_tokens = re.sub(r"\s+", " ", candidate_text).strip().split(" ")
    correct_tokens = re.sub(r"\s+", " ", correct.strip().lower()).strip().split(" ")

    if truncate_output:
        correct_tokens = correct_tokens[:-1]
    elif len(candidate_tokens) != len(correct_tokens):
        return False

    for left, right in zip(candidate_tokens, correct_tokens):
        if left == right:
            continue

        try:
            if int(left) == int(right):
                continue
        except ValueError:
            pass

        try:
            if abs(float(left) - float(right)) < floating_point_accuracy:
                continue
        except ValueError:
            pass

        return False

    return True
|
|
|
|
def call_method(method, inputs):
    """Call ``method()`` with stdin and ``open()`` redirected to serve *inputs*.

    Args:
        method: zero-argument callable to run.
        inputs: the program input, either one string or a list of lines (joined with "\\n").

    Returns:
        Whatever ``method`` returns, or None if it stops via ``SystemExit``.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    inputs_line_iterator = iter(inputs.split("\n"))

    # Stacked patch decorators are entered bottom-up. The three stdin-method patches are
    # therefore installed on the stdin object that is current *before* "sys.stdin" is
    # swapped for a StringIO. They serve code holding an earlier reference to that object
    # (e.g. a module-level `stdin = sys.stdin`). Code that looks up sys.stdin during the
    # call reads from the StringIO instead.
    # readline returns "" once the lines are exhausted, mirroring a real file at EOF,
    # instead of letting StopIteration escape into the program.
    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator, ""))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _inner_call_method(_method):
        try:
            return _method()
        except SystemExit:
            # exit()/sys.exit() in the program counts as normal termination.
            return None

    return _inner_call_method(method)
|
|