jjyang77
slow but
31f0a2c
import gc
import logging
import multiprocessing
import os
import threading
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check
# Shape of one untrusted_check outcome: a status string plus a list of
# booleans (presumably one pass/fail flag per test — confirm against
# api.code_execution.untrusted_check).
Result = Tuple[str, List[bool]]
def _validate_sample(index: int, sample: dict, calibrate: bool) -> None:
    """Reject a request sample that lacks a field the evaluation needs.

    Args:
        index: 0-based position of the sample in the request body; it is
            echoed back in the error message.
        sample: One element of the request's ``samples`` list.
        calibrate: Whether the solution will be prefixed with
            ``sample["code_prompt"]``; if so, that key is required as well.

    Raises:
        HTTPException: status 400 when a required key is missing or
            ``solution`` is not a string.
    """
    # TODO: investigate why HTTPException detail is not passed to client.
    for key in ["task_id", "res_id", "test", "solution", "entry_point"]:
        if key not in sample:
            raise HTTPException(status_code=400, detail=f"'{key}' not in sample {index}!")
    if not isinstance(sample["solution"], str):
        raise HTTPException(status_code=400, detail="Solution must be a string!")
    if calibrate and "code_prompt" not in sample:
        raise HTTPException(status_code=400, detail=f"'code_prompt' not in sample {index}!")


def create_app() -> FastAPI:
    """Build the FastAPI application exposing the code-evaluation endpoints.

    The log level comes from the ``LOG_LEVEL`` environment variable. It may be
    a level name in any case (``"debug"``, ``"INFO"``) or a numeric level
    (``"10"``); it defaults to INFO.

    Routes:
        ``GET /``           -- redirect to the interactive ``/docs`` page.
        ``GET /health``     -- returns 204 with no body.
        ``POST /evaluate/`` -- run each submitted solution against its test.
    """
    level = os.environ.get("LOG_LEVEL", logging.INFO)
    if isinstance(level, str):
        # basicConfig only accepts upper-case names or ints, so normalize
        # "debug" -> "DEBUG" and "10" -> 10.
        level = int(level) if level.isdigit() else level.upper()
    logging.basicConfig(level=level)

    app = FastAPI()
    # At most one worker pool runs at a time. `evaluate` is a sync handler,
    # so FastAPI runs it in a threadpool and the event loop stays free for
    # /health. The lock keeps concurrent requests from each spawning a pool.
    evaluation_lock = threading.Lock()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.

        Each sample must provide ``task_id``, ``res_id``, ``test``,
        ``solution`` (a string) and ``entry_point``, plus ``code_prompt``
        when ``calibrate`` is true. With ``calibrate``, the checked program is
        ``code_prompt + "\\n pass\\n" + solution``.

        ``parallel < 1`` picks half the CPU count, clamped to the range 1..2.
        The limit arguments are forwarded unchanged to ``check_correctness``.
        Their units are defined by ``api.code_execution.untrusted_check``.
        With ``no_gt`` false, per-task time limits come from
        ``get_groundtruth()``. Otherwise every task gets 20.

        Returns ``{"date": "YYYY-MM-DD HH:MM", "eval": {task_id: [...]}}``.
        Each per-task entry has ``res_id``, ``task_id``, ``solution``,
        ``status`` and ``details``, in submission order.

        Raises HTTPException 400 for a malformed sample. The whole request is
        validated before any solution is executed.
        """
        if parallel < 1:
            n_workers = max(1, min(2, multiprocessing.cpu_count() // 2))
        else:
            n_workers = parallel

        expected_time = {} if no_gt else get_groundtruth()

        # Validate up front: raising inside the executor's `with` block would
        # first wait for every already-submitted job to finish.
        for i, sample in enumerate(samples):
            _validate_sample(i, sample, calibrate)

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }
        eval_results = defaultdict(list)  # task_id -> check_correctness results

        with evaluation_lock, ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()  # per-task submission counter
            for i, sample in enumerate(samples):
                task_id = sample["task_id"]
                # The line number keeps the identifier unique when a task_id
                # repeats; the f-string also tolerates non-str task ids.
                sample["_identifier"] = f"{task_id} (line {i+1} )"
                solution = sample["solution"]
                if calibrate:
                    solution = sample["code_prompt"] + "\n pass\n" + solution
                futures.append(
                    executor.submit(
                        check_correctness,
                        completion_id[task_id],
                        sample["res_id"],
                        task_id,
                        solution,
                        sample["test"],
                        sample["entry_point"],
                        max_as_limit,
                        max_data_limit,
                        max_stack_limit,
                        sample["_identifier"],
                        min_time_limit,
                        expected_time.get(task_id) or 20,
                    )
                )
                completion_id[task_id] += 1

            for future in as_completed(futures):
                result = future.result()
                eval_results[result["task_id"]].append(result)
                # Explicit collection after every result, presumably to keep
                # memory bounded on large batches (at the cost of speed).
                gc.collect()

        # Results arrive in completion order; restore submission order per task.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda r: r["completion_id"])
            entries = []
            for res in task_results:
                stat, details = res["base"]
                entries.append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
            results["eval"][task_id] = entries
        return results

    return app
def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier: Optional[str] = None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
    """Check one solution with ``untrusted_check`` and package the outcome.

    This is submitted to a ``ProcessPoolExecutor`` by the evaluate endpoint.
    Its arguments and return value therefore cross a process boundary and
    must be picklable.

    Args:
        completion_id: Per-task submission index, used to restore order.
        res_id: Caller-supplied response id, echoed back unchanged.
        task_id: Task the solution belongs to, echoed back unchanged.
        solution: Program text to execute.
        test: Test code passed to ``untrusted_check``.
        entry_point: Entry-point name passed to ``untrusted_check``.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits
            forwarded unchanged to ``untrusted_check`` (units defined there).
        identifier: Opaque tag echoed back as ``"_identifier"``.
        min_time_limit, gt_time_limit: Time limits forwarded unchanged to
            ``untrusted_check``.

    Returns:
        A dict with ``completion_id``, ``res_id``, ``task_id``,
        ``_identifier`` and ``solution`` copied from the arguments. It also
        holds ``"base"``, the value returned by ``untrusted_check`` (the
        ``Result`` alias, a ``(status, details)`` pair).
    """
    return {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
        "base": untrusted_check(
            solution,
            test,
            entry_point,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit,
            gt_time_limit,
        ),
    }
def get_groundtruth():
    """Look up ground-truth per-task time limits (not implemented yet).

    Raises:
        HTTPException: always, with status code 405.
    """
    message = "Groundtruth execution is not implemented yet!"
    raise HTTPException(status_code=405, detail=message)