Spaces:
Sleeping
Sleeping
| from .models import ExecutionResult, TaskInfo | |
| from .llm_judge import llm_judge | |
def force_valid_reward(value) -> float:
    """Coerce *value* to a float strictly inside the open interval (0, 1).

    Values at or below 0.001 become 0.001 and values at or above 0.999
    become 0.999, so the result is never exactly 0 or 1. Anything that
    ``float()`` cannot convert -- or that converts to NaN -- maps to the
    neutral value 0.5.

    Args:
        value: Any object; typically an int or float reward.

    Returns:
        A float in the closed range [0.001, 0.999].
    """
    import math

    try:
        r = float(value)
    except Exception:
        # Deliberately broad: this is the last line of defence for a reward
        # value, so any conversion failure falls back to neutral.
        return 0.5
    if math.isnan(r):
        # NaN fails every ordered comparison, so without this check it would
        # slip through both clamps below and break the (0, 1) guarantee.
        return 0.5
    if r <= 0.001:
        return 0.001
    if r >= 0.999:
        return 0.999
    return r
def safe_reward(reward) -> float:
    """Return *reward* clamped into (0, 1) by ``force_valid_reward``.

    A ``None`` reward is treated as the neutral value 0.5 before clamping.
    """
    candidate = 0.5 if reward is None else reward
    return force_valid_reward(candidate)
def normalize_reward(passed: int, total: int) -> float:
    """Turn a passed/total test count into a reward clamped into (0, 1).

    When *total* is zero there is no signal, so the neutral 0.5 is returned.
    Otherwise the fraction ``passed / total`` is clamped by
    ``force_valid_reward``.
    """
    if total != 0:
        return force_valid_reward(passed / total)
    return 0.5
def get_llm_quality_score(proposed_fix: str) -> dict:
    """Ask the LLM judge to score *proposed_fix* on its own.

    No original (buggy) code is supplied and the task category is reported
    as "unknown"; the judge's result is returned unchanged.
    """
    no_original_code = ""
    unknown_category = "unknown"
    return llm_judge(no_original_code, proposed_fix, unknown_category)
def _coerce_unit_score(value, default: float = 0.5) -> float:
    """Convert one LLM-judge score to a float in [0, 1].

    ``None``, non-numeric values and NaN fall back to *default*; any other
    number (including +/-inf) is clamped into [0, 1].
    """
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if score != score:  # NaN is the only float that compares unequal to itself
        return default
    return min(1.0, max(0.0, score))


def calculate_codearena_reward(
    *,
    compile_ok: bool,
    passed: int,
    total: int,
    execution_time_seconds: float,
    optimal_time_seconds: float,
    buggy_code: str,
    proposed_fix: str,
    task_category: str,
    step_count: int,
    is_repeated_fix: bool,
) -> tuple[float, dict]:
    """Combine execution results and LLM-judge scores into one scalar reward.

    The raw reward is a weighted sum::

          0.20 * compile_score          (1.0 if the fix compiled, else 0.0)
        + 0.40 * test_pass_ratio        (passed / total; 0.0 when total is 0)
        + 0.10 * efficiency_score       (see below)
        + 0.30 * mean LLM-judge score   (correctness, security, code quality)
        - 0.02 * step_count
        - 0.10 if is_repeated_fix

    ``efficiency_score`` is only non-zero when every test passes. It is 1.0
    when the run took no longer than ``optimal_time_seconds``. Otherwise it
    decays linearly, reaching 0.0 at three times the optimal time.

    Each LLM-judge score that is missing, non-numeric or NaN counts as 0.5.
    Every score is clamped into [0, 1].

    The raw value is rounded to 4 decimal places and clamped into (0, 1) by
    ``force_valid_reward``.

    Returns:
        ``(reward, components)``. *components* holds the individual scores.
        ``step_penalty`` is the amount actually subtracted. ``novelty_penalty``
        is a 0/1 flag; the amount subtracted is 0.10 times that flag.
    """
    compile_score = 1.0 if compile_ok else 0.0
    test_pass_ratio = passed / total if total else 0.0

    efficiency_score = 0.0
    if test_pass_ratio == 1.0:
        if execution_time_seconds <= optimal_time_seconds:
            efficiency_score = 1.0
        else:
            # Guard the divisor so a zero/negative optimal time cannot divide by zero.
            ratio = execution_time_seconds / max(0.001, optimal_time_seconds)
            efficiency_score = max(0.0, 1.0 - (ratio - 1.0) / 2.0)

    # The judge's output is treated as untrusted. A falsy result, missing keys,
    # or non-numeric/NaN values fall back to the neutral 0.5. Without this,
    # they would raise or poison the reward with NaN.
    # NOTE(review): assumes llm_judge scores are meant to lie in [0, 1] (the
    # weights above rely on it); out-of-range values are clamped.
    llm_scores = llm_judge(buggy_code, proposed_fix, task_category) or {}
    llm_correctness = _coerce_unit_score(llm_scores.get("correctness", 0.5))
    llm_security = _coerce_unit_score(llm_scores.get("security", 0.5))
    llm_code_quality = _coerce_unit_score(llm_scores.get("code_quality", 0.5))
    llm_judge_score = (llm_correctness + llm_security + llm_code_quality) / 3

    novelty_penalty = 1.0 if is_repeated_fix else 0.0
    step_penalty = 0.02 * step_count

    final_reward = (
        0.20 * compile_score
        + 0.40 * test_pass_ratio
        + 0.10 * efficiency_score
        + 0.30 * llm_judge_score
        - step_penalty
        - 0.10 * novelty_penalty
    )
    final_reward = force_valid_reward(round(final_reward, 4))
    return final_reward, {
        "compile_score": compile_score,
        "test_pass_ratio": test_pass_ratio,
        "efficiency_score": efficiency_score,
        "llm_correctness": llm_correctness,
        "llm_security": llm_security,
        "llm_code_quality": llm_code_quality,
        "step_penalty": step_penalty,
        "novelty_penalty": novelty_penalty,
    }
def calculate_reward_components(exec_result: ExecutionResult, task_info: TaskInfo, proposed_fix: str) -> dict:
    """Return only the per-component score breakdown for *proposed_fix*.

    This is the breakdown half of ``calculate_reward`` evaluated with
    ``step_count=0`` and ``is_repeated_fix=False``, so neither the step nor
    the novelty penalty applies. The scalar reward is discarded.

    Note that this still runs the full scoring path, including the LLM judge
    call, so each invocation costs one judge evaluation.
    """
    _reward, breakdown = calculate_reward(
        exec_result,
        task_info,
        proposed_fix,
        step_count=0,
        is_repeated_fix=False,
    )
    return breakdown
def calculate_reward(
    exec_result: ExecutionResult,
    task_info: TaskInfo,
    proposed_fix: str,
    step_count: int = 0,
    is_repeated_fix: bool = False,
) -> tuple[float, dict]:
    """Score *proposed_fix* from its execution result and task metadata.

    Pulls the relevant attributes off *exec_result* and *task_info* and
    delegates to ``calculate_codearena_reward``.

    Returns:
        ``(reward, components)`` exactly as produced by
        ``calculate_codearena_reward``.
    """
    execution_fields = dict(
        compile_ok=exec_result.compile_success,
        passed=exec_result.test_passed,
        total=exec_result.test_total,
        execution_time_seconds=exec_result.execution_time_seconds,
    )
    task_fields = dict(
        optimal_time_seconds=task_info.optimal_time_seconds,
        buggy_code=task_info.buggy_code,
        # NOTE(review): the task's difficulty is passed as the judge's task
        # category -- confirm TaskInfo has no dedicated category field.
        task_category=task_info.difficulty,
    )
    return calculate_codearena_reward(
        **execution_fields,
        **task_fields,
        proposed_fix=proposed_fix,
        step_count=step_count,
        is_repeated_fix=is_repeated_fix,
    )
def grade(*args, **kwargs) -> float:
    """Best-effort grading entry point: return the scalar reward, or 0.5.

    The three required inputs -- ``exec_result``, ``task_info`` and
    ``proposed_fix`` -- may be given positionally, by keyword, or mixed
    (positional ones first). With exactly three positional arguments, any
    keyword arguments are ignored, as before. Any keyword argument other
    than the three names above is always ignored.

    Returns:
        The first element of ``calculate_reward(exec_result, task_info,
        proposed_fix)``. The neutral 0.5 is returned instead when an input
        is missing, when more than three positional arguments are passed, or
        when scoring raises.
    """
    required = ("exec_result", "task_info", "proposed_fix")
    if len(args) > len(required):
        return 0.5
    inputs = dict(zip(required, args))
    # Previously keyword-style calls silently fell through to 0.5; fill any
    # inputs not given positionally from kwargs instead.
    for name in required[len(args):]:
        if name not in kwargs:
            return 0.5
        inputs[name] = kwargs[name]
    try:
        return calculate_reward(
            inputs["exec_result"], inputs["task_info"], inputs["proposed_fix"]
        )[0]
    except Exception:
        # Deliberate best-effort: a grader must always produce a reward, so
        # any scoring failure maps to the neutral value.
        return 0.5