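# NOTE(review): the first lines of this capture read "Spaces:", "Sleeping",
# "Sleeping" -- apparently page-header text picked up along with this file,
# not Python code.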
import logging | |
import os | |
from collections import Counter, defaultdict | |
import multiprocessing | |
from datetime import datetime | |
from concurrent.futures import ProcessPoolExecutor, as_completed | |
from typing import Dict, List, Tuple | |
import gc | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import RedirectResponse | |
from api.code_execution import untrusted_check | |
# Alias for a (status, details) pair; presumably the shape untrusted_check
# returns, since /evaluate unpacks its result as `stat, details`.
Result = Tuple[str, List[bool]]
# Keys every submitted sample must carry; "code_prompt" is additionally
# required when calibration is enabled (see _validate_sample).
_REQUIRED_SAMPLE_KEYS = ("task_id", "res_id", "test", "solution", "entry_point")


def _validate_sample(sample: dict, index: int, calibrate: bool) -> None:
    """Raise HTTP 400 if ``sample`` (the ``index``-th request item) is malformed.

    Checks that the mandatory keys are present and that the code fields that
    get concatenated are strings. ``code_prompt`` is only required when
    ``calibrate`` is true, because only then is it read.
    """
    # TODO: investigate why HTTPException detail is not passed to client.
    required = _REQUIRED_SAMPLE_KEYS + (("code_prompt",) if calibrate else ())
    for key in required:
        if key not in sample:
            raise HTTPException(
                status_code=400, detail=f"'{key}' not in sample {index}!"
            )
    if not isinstance(sample["solution"], str):
        raise HTTPException(status_code=400, detail="Solution must be a string!")
    if calibrate and not isinstance(sample["code_prompt"], str):
        raise HTTPException(status_code=400, detail="Code prompt must be a string!")


def create_app() -> FastAPI:
    """Build the FastAPI app that evaluates submitted code solutions.

    Logging is configured from the ``LOG_LEVEL`` environment variable, which
    may be a level name (any case, e.g. ``"debug"``) or a number (e.g.
    ``"10"``); it defaults to INFO.
    """
    raw_level = os.environ.get("LOG_LEVEL") or "INFO"
    # basicConfig only accepts upper-case level names or ints; normalise
    # instead of failing at startup on "debug" or "10".
    level = int(raw_level) if raw_level.isdigit() else raw_level.upper()
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)
    app = FastAPI()

    # NOTE(review): route paths are assumed (root redirects to /docs);
    # confirm they match what clients call.
    @app.get("/")
    def root():
        """Redirect to /docs."""
        return RedirectResponse("/docs")

    @app.get("/health")
    def health():
        """Liveness probe; the handler itself returns nothing."""
        return

    # Plain `def` (not `async def`) so FastAPI runs this handler in its
    # threadpool: the executor waits below block, and inside `async def` they
    # would stall the event loop -- including /health -- for the whole batch.
    @app.post("/evaluate")
    def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.

        Each sample needs "task_id", "res_id", "test", "solution" and
        "entry_point" (plus "code_prompt" when ``calibrate`` is true). With
        ``calibrate`` the solution is prefixed by its code prompt and a
        ``pass`` stub before being checked. ``parallel < 1`` picks a worker
        count automatically (at most 2). The resource limits are forwarded
        unchanged to ``check_correctness``; the per-task time limit is 20
        unless ground-truth timings provide one.

        Returns ``{"date": ..., "eval": {task_id: [result, ...]}}``, where each
        task's results are in the order that task's samples were submitted.

        Raises:
            HTTPException: 400 for a malformed sample; whatever
                ``get_groundtruth`` raises when ``no_gt`` is false.
        """
        if parallel < 1:
            n_workers = max(1, min(2, multiprocessing.cpu_count() // 2))
        else:
            # NOTE(review): client-controlled and unbounded; consider capping
            # it to protect the host.
            n_workers = parallel
        expected_time = {} if no_gt else get_groundtruth()

        # Validate everything before submitting work: raising inside the
        # executor's `with` block would first wait for all already-queued
        # checks to finish.
        for i, sample in enumerate(samples):
            _validate_sample(sample, i, calibrate)

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }
        logger.info("Evaluating %d samples with %d workers", len(samples), n_workers)

        eval_results = defaultdict(list)  # task_id -> list of check results
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()  # task_id -> next per-task submission index
            for i, sample in enumerate(samples):
                task_id = sample["task_id"]
                sample["_identifier"] = f"{task_id} (line {i + 1} )"
                solution = sample["solution"]
                if calibrate:
                    solution = sample["code_prompt"] + "\n pass\n" + solution
                futures.append(
                    executor.submit(
                        check_correctness,
                        completion_id[task_id],
                        sample["res_id"],
                        task_id,
                        solution,
                        sample["test"],
                        sample["entry_point"],
                        max_as_limit,
                        max_data_limit,
                        max_stack_limit,
                        sample["_identifier"],
                        min_time_limit,
                        expected_time.get(task_id) or 20,
                    )
                )
                completion_id[task_id] += 1

            for future in as_completed(futures):
                result = future.result()
                eval_results[result["task_id"]].append(result)
                # Explicit collection after every result; presumably meant to
                # keep memory bounded on large batches -- confirm still needed.
                gc.collect()

        # Results arrive in completion order; restore per-task submission order.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda res: res["completion_id"])
            formatted = []
            for res in task_results:
                stat, details = res["base"]
                formatted.append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
            results["eval"][task_id] = formatted
        return results

    return app
def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> dict:
    """Run one solution against its test code and package the outcome.

    The solution, test, entry point, resource limits and time limits are
    passed straight through to ``untrusted_check``; this function adds no
    checking of its own.

    Returns:
        A dict echoing the identifying inputs ("completion_id", "res_id",
        "task_id", "_identifier", "solution") plus "base", which holds
        whatever ``untrusted_check`` returns (callers unpack it as a
        ``(status, details)`` pair).
    """
    return {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
        "base": untrusted_check(
            solution,
            test,
            entry_point,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit,
            gt_time_limit,
        ),
    }
def get_groundtruth():
    """Return per-task ground-truth time limits (not implemented).

    Callers treat the result as a mapping from task_id to a time limit
    (looked up with ``.get(task_id)``). Until that exists, every call fails.

    Raises:
        HTTPException: always, with status 501 Not Implemented.
    """
    # 501 (not 405 "Method Not Allowed") signals functionality the server
    # does not support; the request's HTTP method itself is fine.
    raise HTTPException(
        status_code=501, detail="Groundtruth execution is not implemented yet!"
    )