# Web-page residue from the repository file viewer (not part of the program):
# CC_flows/src/evaluation/testing_utils_leetcode.py
# martinjosifoski's picture
# add LC_Code (#3)
# 17ab974
# raw history blame
# No virus
# 8.59 kB
# This is based heavily on the Hugging Face APPS metric.
# To run the solution files, we use a timing-based approach
# for capturing the stdout,
# which is used for testing the code that reads from input.
import logging
import re
from subprocess import Popen, PIPE, TimeoutExpired
from typing import List, Tuple
import threading
# Module-level logger; one_test uses it to report a timeout when debug is enabled.
log = logging.getLogger(__name__)
# Global lock held by evaluate_solution_for_problem while it runs the tests.
lock = threading.Lock()
def _no_code_results(tests_io):
    """Return one failed result entry per test for the case where no code was submitted."""
    return [
        {
            "status": False,
            "error_message": "No code was provided.",
            "generated_output": None,
            "input": test[0],
            "expected_output": test[1],
        }
        for test in tests_io
    ]


def evaluate_solution_for_problem(
    candidate_solution,
    python_stub,
    hidden_tests_io=None,
    public_tests_io=None,
    timeout=10,
    debug=False,
    add_extra_imports=False,
):
    """Evaluate a candidate solution on the hidden and the public tests.

    See the readme for the output format of this function.

    Args:
        candidate_solution: Source code of the solution, or None when no code
            was produced.
        python_stub: Class/method stub forwarded to check_correctness.
        hidden_tests_io: Hidden tests; each test is indexable with the input
            at position 0 and the expected output at position 1. Defaults to
            no tests.
        public_tests_io: Public tests, same layout as hidden_tests_io.
        timeout: Timeout forwarded to check_correctness.
        debug: Debug flag forwarded to check_correctness.
        add_extra_imports: Flag forwarded to check_correctness.

    Returns:
        A dict with the keys "compilation_status",
        "compilation_error_message", "timeout_error", "hidden_tests_results"
        and "public_tests_results".

    Raises:
        AssertionError: If both test sets are non-empty and their reported
            compilation statuses differ.
    """
    if hidden_tests_io is None:
        hidden_tests_io = []
    if public_tests_io is None:
        public_tests_io = []

    # Nothing is executed when there is no code, so this path does not need the lock.
    if candidate_solution is None:
        return {
            "compilation_status": False,
            "compilation_error_message": "No code was provided.",
            "timeout_error": False,
            "hidden_tests_results": _no_code_results(hidden_tests_io),
            "public_tests_results": _no_code_results(public_tests_io),
        }

    # Test execution is serialized across threads through the module-level lock.
    with lock:
        hidden_tests_results = check_correctness(
            candidate_solution, python_stub, hidden_tests_io, timeout, debug, add_extra_imports
        )
        public_tests_results = check_correctness(
            candidate_solution, python_stub, public_tests_io, timeout, debug, add_extra_imports
        )

    # the compilation status shouldn't depend on the tests
    if hidden_tests_io and public_tests_io:
        assert hidden_tests_results["compilation_status"] == public_tests_results["compilation_status"]

    compilation_status = True
    error_message = None
    timeout_error = False
    # Only non-empty test sets contribute. The public set is processed last,
    # so its error message is the one reported when both sets were run.
    for tests_io, outcome in (
        (hidden_tests_io, hidden_tests_results),
        (public_tests_io, public_tests_results),
    ):
        if not tests_io:
            continue
        compilation_status = compilation_status and outcome["compilation_status"]
        error_message = outcome["error_message"]
        timeout_error = timeout_error or outcome["timeout_error"]

    return {
        "compilation_status": compilation_status,
        "compilation_error_message": error_message,
        "timeout_error": timeout_error,
        "hidden_tests_results": hidden_tests_results["results"],
        "public_tests_results": public_tests_results["results"],
    }
def check_correctness(
    candidate_solution: str,
    python_stub: str,
    tests: List[Tuple[List[str], str]],
    timeout: int = 6000,
    debug=True,
    add_extra_imports=False,
):
    """Run the candidate solution on each test and summarize the outcome.

    Tests are executed in order with one_test. Execution stops at the first
    test that times out; the results collected so far (including the
    timed-out one) are still returned.

    Args:
        candidate_solution: Source code of the solution.
        python_stub: Class/method stub forwarded to one_test.
        tests: Sequence of tests. Position 0 of a test is the input and
            position 1 the expected output; any further items (for example an
            explanation) are ignored.
        timeout: Per-test timeout forwarded to one_test.
        debug: Debug flag forwarded to one_test.
        add_extra_imports: Flag forwarded to one_test.

    Returns:
        A dict with the keys "compilation_status", "timeout_error",
        "error_message" and "results" (the list of per-test result dicts).
        On a timeout, "compilation_status" is True and "error_message" is
        "Timeout error."; otherwise "error_message" is the last error message
        that contained "syntaxerror", or None.
    """
    # Sentinel emitted by one_test on a timeout. Real stderr is lower-cased by
    # one_test, so an exact match cannot be confused with a traceback that
    # merely mentions the word "timeout".
    timeout_sentinel = "Timeout error."

    compilation_status = True
    compilation_error = None
    results = []

    for test in tests:
        # Index access instead of 3-way unpacking: accepts the (input, output)
        # pairs promised by the type hint as well as longer tuples.
        inp, out = test[0], test[1]
        result = one_test(
            candidate_solution, python_stub, inp, out, timeout=timeout, debug=debug, add_extra_imports=add_extra_imports
        )
        results.append(result)

        error_message = result["error_message"]
        if error_message is None:
            continue

        if error_message == timeout_sentinel:
            # Skip the remaining tests once one of them has timed out.
            return {
                "compilation_status": True,
                "timeout_error": True,
                "error_message": "Timeout error.",
                "results": results,
            }

        if "syntaxerror" in error_message.lower():
            compilation_status = False
            compilation_error = error_message

    return {
        "compilation_status": compilation_status,
        "timeout_error": False,
        "error_message": compilation_error,
        "results": results,
    }
def one_test(candidate_solution, python_stub, inp, out, timeout=10, debug=False, add_extra_imports=False):
    """Run the candidate solution on a single test in a separate Python process.

    A script is assembled from the candidate code plus a small driver that
    instantiates the class named in the stub, calls the stub's method with the
    test input, compares the return value with the expected output and prints
    both the value and the verdict between marker lines. The script is run
    with ``python3 -c`` and its stdout/stderr are parsed.

    Args:
        candidate_solution: Source code defining the class named in the stub.
        python_stub: Stub of the form ``class <Name>: ... def <method>(...)``;
            it is used only to extract the class and method names.
        inp: Test input as text. ``name = value`` pairs have the ``name = ``
            part removed so that the remainder can be used as positional
            arguments.
        out: Expected output as text; "null"/"true"/"false" are rewritten to
            "None"/"True"/"False" before it is embedded in the script.
        timeout: Seconds to wait for the child process.
        debug: When true, a timeout is logged.
        add_extra_imports: When true, a set of common imports is prepended to
            the script.

    Returns:
        A dict with the keys "input", "expected_output", "status",
        "error_message" and "generated_output". "error_message" is
        "Timeout error." on a timeout, the lower-cased stderr when the child
        wrote anything to stderr, and None otherwise.

    Raises:
        ValueError: If the stub cannot be split into a class part and a method
            signature.
        Exception: If the child's stdout does not contain the expected markers.
    """
    python_stub = python_stub.strip()
    candidate_solution = candidate_solution.strip()

    out = out.replace("null", "None").replace("true", "True").replace("false", "False")

    # reformat the solution and parse class and method name
    class_def, signature = python_stub.split(" def ")
    class_name = class_def.split("class ")[1].strip().rstrip(":")
    # Only the text before the first "(" is the method name; the signature may
    # contain further parentheses (e.g. a default value of ``()``).
    func_name, paren, _ = signature.partition("(")
    if not paren:
        raise ValueError(f"Could not find a parameter list in the stub signature: {signature}")
    func_name = func_name.strip()

    # reformatting the input
    first_param = r"^\w+\s\=\s"
    later_params = r",\s\w+\s\=\s"

    inp = re.sub(first_param, "", inp)
    inp = re.sub(later_params, ", ", inp)

    # we add custom code to invoke the solution
    before_output = "AFTER THIS COMES OUR OWN GENERATED OUTPUT !@#!@!"
    after_output = "AFTER THIS COMES OUR VERDICT !@#!@!"

    if add_extra_imports:
        sol = f"""
from collections import *
from math import *
import math
from functools import *
from heapq import *
import heapq
import itertools
from itertools import *
import bisect
from bisect import *
"""
    else:
        sol = ""

    # NOTE(security): candidate_solution, inp and out are interpolated verbatim
    # into the script below and executed without any sandboxing. Only use this
    # on trusted input or inside an isolated environment.
    sol += f"""
from typing import List, Tuple, Optional
{candidate_solution}
sfohsdfdsfjhsdkfjhsdkjfh = {class_name}()
res = sfohsdfdsfjhsdkfjhsdkjfh.{func_name}({inp})
def nested_list_convert(inp):
    try:
        try:
            inp = list(inp)
        except BaseException as e:
            return inp
        out = []
        for i in inp:
            out.append(nested_list_convert(i))
    except BaseException as e:
        return inp
    return out
matching = False
matching = matching or res == {out}
matching = matching or nested_list_convert(res) == {out}
matching = matching or nested_list_convert(res) == nested_list_convert({out})
matching = matching or str({out})==str(res).replace("{{","[").replace("(","[").replace("}}","]").replace(")","]")
matching = matching or str({out})==str(res).replace("{{","[").replace("(","[").replace("}}","]").replace(")","]")
print("res: ", res)
print("out: ", {out})
print("{before_output}")
print(res)
print("{after_output}")
print(matching)
"""

    cmd = "python3"
    result_object = {"input": inp, "expected_output": out.strip('"')}

    # The context manager closes the pipes and waits for the child on exit, so
    # no zombie process or open file descriptor is left behind on any path.
    with Popen([cmd, "-c", sol], stdin=PIPE, stdout=PIPE, stderr=PIPE) as proc:
        try:
            stdout, stderr = proc.communicate(timeout=timeout)
        except TimeoutExpired:
            # communicate() does not kill the child on timeout: kill it and
            # finish the communication so that it is reaped.
            proc.kill()
            proc.communicate()
            if debug:
                log.info(f"Timeout error, timeout={timeout}")
            result_object.update({"status": False, "error_message": "Timeout error.", "generated_output": None})
            return result_object
        except BaseException:
            # Never leave the child running (e.g. on KeyboardInterrupt).
            proc.kill()
            raise

    # Undecodable bytes printed by the candidate must not crash the harness.
    stdout = stdout.decode(errors="replace")
    stderr = stderr.decode(errors="replace").lower()

    if stderr:
        # Runtime or compilation error (distinction is made by the presence of "syntaxerror" in the error message)
        result_object.update(status=False, error_message=stderr, generated_output=None)
        return result_object

    # No compilation or runtime error
    try:
        generated_output = stdout.split(before_output)[1]
        generated_output, verdict = generated_output.split(after_output)
    except (IndexError, ValueError) as e:
        # IndexError: the first marker is missing; ValueError: the second
        # marker is missing or appears more than once.
        raise Exception(
            f"An unexpected error has occurred while parsing the following generated output: {stdout}"
        ) from e

    result_object.update(
        status=verdict.strip() == "True",
        error_message=None,
        generated_output=generated_output.strip(),
    )
    return result_object