import gc
import logging
import multiprocessing
import os
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, List, Tuple

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse

from api.code_execution import untrusted_check

Result = Tuple[str, List[bool]]


def create_app() -> FastAPI:
    level = os.environ.get("LOG_LEVEL", logging.INFO)
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    @app.post("/evaluate/")
    async def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.
        """
        if parallel < 1:
            n_workers = min(2, multiprocessing.cpu_count() // 2)
            if n_workers < 1:
                n_workers = 1
        else:
            n_workers = parallel

        if not no_gt:
            expected_time = get_groundtruth()
        else:
            expected_time = {}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check results
            remainings = set()

            for i, sample in enumerate(samples):
                # TODO: investigate why HTTPException detail is not passed to client.
                for key in ["task_id", "res_id", "test", "solution", "entry_point"]:
                    if key not in sample:
                        raise HTTPException(
                            status_code=400, detail=f"'{key}' not in sample {i}!"
                        )
                if not isinstance(sample["solution"], str):
                    raise HTTPException(
                        status_code=400, detail="Solution must be a string!"
                    )
                if calibrate and "code_prompt" not in sample:
                    raise HTTPException(
                        status_code=400,
                        detail=f"'code_prompt' is required for calibration but not in sample {i}!",
                    )

                sample["_identifier"] = sample["task_id"] + f" (line {i + 1})"

                task_id = sample["task_id"]
                solution = sample["solution"]

                if calibrate:
                    # Prepend the code prompt with a stub body so the solution is
                    # checked as part of the full function definition.
                    solution = sample["code_prompt"] + "\n    pass\n" + solution

                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    sample["res_id"],
                    task_id,
                    solution,
                    sample["test"],
                    sample["entry_point"],
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    # Fall back to a 20-second ground-truth time limit when none is known.
                    expected_time.get(task_id) or 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            # assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                del future, result
                gc.collect()

        # Sort the results for each problem by completion_id.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

        return results

    return app


def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    ret = {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        test,
        entry_point,
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def get_groundtruth():
    raise HTTPException(status_code=405, detail="Groundtruth execution is not implemented yet!")