debugZero / server /graders.py
The-Fool-09's picture
Upload folder using huggingface_hub
8412998 verified
from __future__ import annotations
import ast
import statistics
from collections import deque
from thefuzz import fuzz
# Module-level rolling history of solver outcomes, keyed by seed id.
# Each value is a deque capped at the 20 most recent results
# (1.0 = solved, 0.0 = not solved); see record_solve_result.
solve_rate_history: dict[str, deque[float]] = {}
def reset_reward_history() -> None:
    """Drop every recorded solve outcome for every seed.

    The shared history mapping is emptied in place, so any other reference
    to the same dict object observes the reset as well.
    """
    solve_rate_history.clear()
def get_solve_rate(seed_id: str) -> float:
    """Return the mean recorded solve outcome for ``seed_id``.

    When the seed has no history entry, or its buffer is empty, a neutral
    0.5 is returned instead.
    """
    outcomes = solve_rate_history.get(seed_id)
    if outcomes:
        return statistics.mean(outcomes)
    # No data yet for this seed: report the midpoint.
    return 0.5
def record_solve_result(seed_id: str, solved: bool) -> None:
    """Append one solver outcome to the rolling history of ``seed_id``.

    A buffer holding at most 20 entries is created the first time a seed is
    seen. The outcome is stored as 1.0 when ``solved`` is truthy and 0.0
    otherwise.
    """
    outcome = 1.0 if solved else 0.0
    # setdefault only installs the fresh deque when the seed is new.
    history = solve_rate_history.setdefault(seed_id, deque(maxlen=20))
    history.append(outcome)
def is_effectively_unchanged(original_code: str, candidate_code: str) -> bool:
    """Report whether two code strings are the same program.

    Both strings are parsed and their ``ast.dump`` output compared, so
    differences limited to formatting or comments do not count as a change.

    If either string cannot be turned into an AST dump, the comparison falls
    back to the raw text with surrounding whitespace stripped.

    Args:
        original_code: The reference source text.
        candidate_code: The source text to compare against the reference.

    Returns:
        True when the two inputs are considered unchanged, False otherwise.
    """
    try:
        original_dump = ast.dump(ast.parse(original_code))
        candidate_dump = ast.dump(ast.parse(candidate_code))
    except (SyntaxError, ValueError, RecursionError):
        # SyntaxError: not valid Python. ValueError: NUL bytes in the source
        # on interpreters that report them that way. RecursionError: input
        # nested too deeply to parse or dump. None of these should crash the
        # grader; compare the text instead.
        return original_code.strip() == candidate_code.strip()
    return original_dump == candidate_dump
def compute_ast_distance(original_code: str, mutated_code: str) -> float:
    """Score the plausibility of a mutation from the similarity of AST dumps.

    The ``ast.dump`` strings of both inputs are compared with
    ``fuzz.ratio`` (a percentage from 0 to 100) and mapped to a score:

    * identical AST dumps (zero edits)  -> 0.0
    * ratio >= 85 (small, targeted edit) -> 1.0
    * 50 <= ratio < 85                   -> linear ramp ``(ratio - 50) / 35``,
      floored at 0.1
    * ratio < 50 (wide corruption)       -> 0.0

    Args:
        original_code: The unmodified source text.
        mutated_code: The source text after mutation.

    Returns:
        A score in ``[0.0, 1.0]``. 0.0 is also returned when either input
        cannot be parsed or dumped.
    """
    try:
        orig_ast = ast.dump(ast.parse(original_code))
        mut_ast = ast.dump(ast.parse(mutated_code))
    except (SyntaxError, ValueError, RecursionError):
        # Unparseable input (invalid syntax, NUL bytes, or nesting too deep)
        # earns no plausibility credit.
        return 0.0

    # Zero edits must score 0: without this guard an unchanged program would
    # produce a ratio of 100 and be rewarded as a perfect "small edit".
    if orig_ast == mut_ast:
        return 0.0

    ratio = fuzz.ratio(orig_ast, mut_ast)
    # Empirical calibration: simple AST mutations typically land at a
    # fuzz ratio of 85-98%.
    if ratio >= 85:
        return 1.0
    if ratio >= 50:
        return max(0.1, (ratio - 50) / 35.0)
    return 0.0
def compute_proposer_reward(meta: dict) -> float:
    """Compute the proposer's reward from an episode's ``meta`` flags.

    Checks are applied in order; the first match decides the reward:

    * ``syntax_error`` or ``unsafe_code``  -> -0.5
    * ``unchanged_code``                   ->  0.0
    * ``changed_but_passing``              -> -0.1
    * ``tests_passed`` (defaults to True)  ->  0.0
    * otherwise (tests fail)               ->  1.0 + ``plausibility_score``
      (default 0.0) + a learnability bonus of 1.0 when the seed's solve rate
      lies in ``[0.2, 0.8]``.

    Args:
        meta: Episode metadata. ``seed_id`` is required only on the final
            (tests failed) path.

    Returns:
        The proposer reward as a float.

    Raises:
        KeyError: If the final path is reached and ``meta`` has no
            ``seed_id``.
    """
    if meta.get("syntax_error", False) or meta.get("unsafe_code", False):
        return -0.5
    if meta.get("unchanged_code", False):
        return 0.0
    # This check has to precede the generic tests_passed early return:
    # a changed-but-passing submission also has passing tests, so in the
    # opposite order the -0.1 penalty could never be applied.
    if meta.get("changed_but_passing", False):
        return -0.1
    if meta.get("tests_passed", True):
        return 0.0

    plausibility_bonus = meta.get("plausibility_score", 0.0)
    solve_rate = get_solve_rate(meta["seed_id"])
    learnability_bonus = 1.0 if 0.2 <= solve_rate <= 0.8 else 0.0
    return 1.0 + plausibility_bonus + learnability_bonus
def compute_solver_reward(meta: dict) -> float:
    """Compute the solver's reward and record the outcome for the seed.

    The outcome is appended to the seed's solve history on every call; it
    counts as solved only when tests passed and the submission was neither
    a syntax error nor unsafe.

    Rewards:

    * ``syntax_error`` or ``unsafe_code``  -> -0.5
    * ``tests_passed``                     ->  1.0
    * otherwise                            ->  0.0

    Args:
        meta: Episode metadata. ``seed_id`` is required; the flags
            ``tests_passed``, ``syntax_error`` and ``unsafe_code`` all
            default to False when absent.

    Returns:
        The solver reward as a float.

    Raises:
        KeyError: If ``meta`` has no ``seed_id``.
    """
    solved = meta.get("tests_passed", False)
    # Default is False, matching the other flag defaults in this module; a
    # True default would score any meta lacking the key as a syntax error.
    syntax_error = meta.get("syntax_error", False)
    unsafe_code = meta.get("unsafe_code", False)
    rejected = bool(syntax_error or unsafe_code)

    record_solve_result(meta["seed_id"], bool(solved) and not rejected)

    if rejected:
        return -0.5
    return 1.0 if solved else 0.0