Spaces:
Sleeping
Sleeping
File size: 5,457 Bytes
25db7e9 31f0a2c 9184913 25db7e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import logging
import os
from collections import Counter, defaultdict
import multiprocessing
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
import gc
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from api.code_execution import untrusted_check
# Type of the value stored under "base" in a check_correctness record (the
# return value of untrusted_check). It is unpacked as a (status, details)
# pair when /evaluate/ formats its response.
# NOTE(review): the List[bool] element type is not verified here; confirm it
# against untrusted_check.
Result = Tuple[str, List[bool]]
def create_app() -> FastAPI:
    """Build the FastAPI application that evaluates submitted code samples.

    Routes:

    * ``GET /`` -- redirect to the interactive ``/docs`` page.
    * ``GET /health`` -- empty ``204`` response, usable as a liveness probe.
    * ``POST /evaluate/`` -- run each sample's solution against its test in a
      process pool (via ``check_correctness``) and return the results grouped
      by ``task_id``.

    The root log level comes from the ``LOG_LEVEL`` environment variable and
    defaults to ``logging.INFO``.
    """
    level = os.environ.get("LOG_LEVEL", default=logging.INFO)
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    app = FastAPI()

    @app.get("/")
    def root():
        return RedirectResponse("/docs")

    @app.get("/health", status_code=204)
    def health():
        return

    # A plain ``def`` rather than ``async def``: the body blocks on a process
    # pool for the whole evaluation. FastAPI runs sync path operations in a
    # worker thread, so a long evaluation no longer freezes the event loop.
    # A frozen loop also stalls ``/health`` and every other request.
    @app.post("/evaluate/")
    def evaluate(
        samples: List[dict],
        calibrate: bool = True,
        parallel: int = -1,
        min_time_limit: float = 1,
        max_as_limit: int = 30 * 1024,
        max_data_limit: int = 30 * 1024,
        max_stack_limit: int = 10,
        no_gt: bool = True,
    ) -> dict:
        """
        Evaluate the correctness of the solutions in the given samples data.
        """
        # parallel < 1 means "choose automatically": half the CPUs, capped at
        # 2 and never fewer than 1.
        if parallel < 1:
            n_workers = max(1, min(2, multiprocessing.cpu_count() // 2))
        else:
            n_workers = parallel

        expected_time = {} if no_gt else get_groundtruth()

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        # Reject malformed input before any work starts. Raising from inside
        # the executor's ``with`` block would first wait (in shutdown) for
        # every check that had already been submitted.
        # TODO: investigate why HTTPException detail is not passed to client.
        for i, sample in enumerate(samples):
            _validate_sample(i, sample, calibrate)

        logger.info(
            "Evaluating %d sample(s) with %d worker(s)", len(samples), n_workers
        )

        eval_results = defaultdict(list)  # task_id -> list of check_correctness records
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()  # task_id -> samples submitted so far
            remainings = set()  # identifiers submitted but not yet finished
            for i, sample in enumerate(samples):
                task_id = sample["task_id"]
                # f-string rather than ``+`` so a non-string task_id does not
                # raise TypeError.
                identifier = f"{task_id} (line {i + 1} )"
                sample["_identifier"] = identifier
                solution = sample["solution"]
                if calibrate:
                    # Put the sample's code_prompt, closed off with a stub
                    # ``pass`` body, in front of the solution. The stub is
                    # indented four spaces to line up with a typical function
                    # body in the prompt.
                    # NOTE(review): presumably this makes the prompt's
                    # preamble (imports etc.) available; confirm.
                    solution = sample["code_prompt"] + "\n    pass\n" + solution
                remainings.add(identifier)
                futures.append(
                    executor.submit(
                        check_correctness,
                        completion_id[task_id],
                        sample["res_id"],
                        task_id,
                        solution,
                        sample["test"],
                        sample["entry_point"],
                        max_as_limit,
                        max_data_limit,
                        max_stack_limit,
                        identifier,
                        min_time_limit,
                        # A missing or falsy ground-truth time falls back to 20.
                        expected_time.get(task_id) or 20,
                    )
                )
                completion_id[task_id] += 1
            assert len(futures) == len(remainings), "Missing problems in unfinished"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                # Drop references and collect eagerly. This is presumably to
                # keep server memory bounded across large batches.
                del future, result
                gc.collect()

        # Futures finish in arbitrary order, so restore per-task submission
        # order before formatting the response.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda res: res["completion_id"])
            formatted = []
            for res in task_results:
                stat, details = res["base"]
                formatted.append(
                    {
                        "res_id": res["res_id"],
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
            results["eval"][task_id] = formatted
        return results

    return app


def _validate_sample(index: int, sample: dict, calibrate: bool) -> None:
    """Raise HTTP 400 if ``sample`` (item ``index`` of the request) is malformed.

    Required keys are ``task_id``, ``res_id``, ``test``, ``solution`` and
    ``entry_point``. ``solution`` must be a string. When ``calibrate`` is set,
    ``code_prompt`` is also required and must be a string, because it is
    concatenated in front of the solution.
    """
    for key in ("task_id", "res_id", "test", "solution", "entry_point"):
        if key not in sample:
            raise HTTPException(status_code=400, detail=f"'{key}' not in sample {index}!")
    if not isinstance(sample["solution"], str):
        raise HTTPException(status_code=400, detail="Solution must be a string!")
    if calibrate:
        if "code_prompt" not in sample:
            raise HTTPException(
                status_code=400, detail=f"'code_prompt' not in sample {index}!"
            )
        if not isinstance(sample["code_prompt"], str):
            raise HTTPException(status_code=400, detail="Code prompt must be a string!")
def check_correctness(
    completion_id: int,
    res_id: int,
    task_id: str,
    solution: str,
    test: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, object]:
    """Run one solution against its test and return a tagged result record.

    This is a module-level function so it can be pickled and submitted to a
    ``ProcessPoolExecutor``.

    Args:
        completion_id: Position of this sample among samples sharing ``task_id``.
            It is used by the caller to restore submission order.
        res_id: Caller-supplied response id, echoed back unchanged.
        task_id: Task the sample belongs to, echoed back unchanged.
        solution: Source code to execute.
        test: Test code, forwarded to ``untrusted_check``.
        entry_point: Entry-point name, forwarded to ``untrusted_check``.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits,
            forwarded unchanged to ``untrusted_check``.
        identifier: Opaque per-sample key, echoed back as ``"_identifier"``.
        min_time_limit, gt_time_limit: Time limits, forwarded unchanged to
            ``untrusted_check``.

    Returns:
        A dict with the inputs ``completion_id``, ``res_id``, ``task_id``,
        ``_identifier`` and ``solution``, plus ``"base"``, which holds the
        value returned by ``untrusted_check`` (a ``Result``). The values have
        mixed types, hence the ``object`` value annotation.
    """
    return {
        "completion_id": completion_id,
        "res_id": res_id,
        "task_id": task_id,
        "_identifier": identifier,
        "solution": solution,
        "base": untrusted_check(
            solution,
            test,
            entry_point,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit,
            gt_time_limit,
        ),
    }
def get_groundtruth():
    """Placeholder for ground-truth time measurement; always raises.

    NOTE(review): the /evaluate/ endpoint reads the intended return value with
    ``.get(task_id)``. It is presumably a mapping of task_id to a time limit;
    confirm when this is implemented.

    Raises:
        HTTPException: Always, with status 501 (Not Implemented). The request
            is valid, but the server does not support the feature yet. 405
            would instead mean the HTTP method is not allowed on the resource.
    """
    raise HTTPException(
        status_code=501, detail="Groundtruth execution is not implemented yet!"
    )
|