moksh24's picture
Restructure to compliant multi-mode layout with server/ and uv.lock
e13f5d6
Raw
History Blame Contribute Delete
4.64 kB
import json
import os
import subprocess
import sys
import tempfile
from typing import Dict, Any, Tuple

from models import Observation, Action, Reward
from tasks import TASKS
class PythonDebuggerEnv:
    """Episodic environment in which an agent repairs a broken Python snippet.

    Each ``step`` submits a complete replacement for the task's code. The code
    is spliced into the task's ``test_code_template``, executed in a separate
    Python process, and scored by the fraction of tests the template reports
    as passing.

    NOTE(review): submitted code runs unsandboxed with this process's
    privileges and working directory; only a wall-clock timeout is enforced.
    """

    # Wall-clock limit, in seconds, for a single test run.
    timeout_seconds: float = 2.0
    # Marker in the task's test template that is replaced by the submitted code.
    _CODE_PLACEHOLDER = "{code}"

    def __init__(self, task_id: str = "task_easy"):
        """Create an environment for ``task_id``.

        Unknown task ids silently fall back to ``"task_easy"``.
        """
        if task_id not in TASKS:
            task_id = "task_easy"
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.step_count = 0
        self.best_score = 0.0
        self.current_code = self.task.initial_code

    def reset(self) -> Observation:
        """Restart the episode from the task's initial broken code."""
        self.step_count = 0
        self.best_score = 0.0
        self.current_code = self.task.initial_code
        return Observation(
            task_id=self.task_id,
            current_code=self.current_code,
            test_results={"success": False, "tests_passed": 0, "total_tests": 1, "error": "", "stdout": ""},
            feedback="Initial broken code provided. Fix the code to pass the tests.",
            step=self.step_count,
            max_steps=self.task.max_steps,
            score=self.best_score,
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Run ``action.code`` against the task's tests and advance the episode.

        Returns ``(observation, reward, done, info)``. The reward value is the
        improvement of this step's score over the best score so far in the
        episode (never negative). The episode ends when every test passes or
        ``task.max_steps`` steps have been taken.
        """
        self.step_count += 1
        self.current_code = action.code

        success, tests_passed, total_tests, error_msg, stdout = self._run_tests(self.current_code)

        # Partial credit: the fraction of tests that passed.
        score = tests_passed / total_tests if total_tests > 0 else 0.0
        reward_value = max(0.0, score - self.best_score)
        self.best_score = max(self.best_score, score)
        done = success or (self.step_count >= self.task.max_steps)

        if success:
            feedback = "All tests passed! Great job."
        else:
            # Only the tail of the error text is echoed to keep feedback short.
            feedback = f"Passed {tests_passed}/{total_tests} tests.\nErrors:\n{error_msg[-500:]}"

        obs = Observation(
            task_id=self.task_id,
            current_code=self.current_code,
            test_results={
                "success": success,
                "tests_passed": tests_passed,
                "total_tests": total_tests,
                "error": error_msg,
                "stdout": stdout,
            },
            feedback=feedback,
            step=self.step_count,
            max_steps=self.task.max_steps,
            score=self.best_score,
        )
        reward = Reward(
            value=reward_value,
            correctness=score,
            efficiency=1.0 if success else 0.0,
        )
        return obs, reward, done, {"info": "executed"}

    def state(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of the episode's progress."""
        return {
            "task_id": self.task_id,
            "step": self.step_count,
            "best_score": self.best_score,
            "current_code": self.current_code,
        }

    @classmethod
    def _placeholder_indent(cls, template: str) -> str:
        """Return the leading whitespace of the template line holding the placeholder.

        Returns ``""`` when the template has no placeholder.
        """
        pos = template.find(cls._CODE_PLACEHOLDER)
        if pos == -1:
            return ""
        line_start = template.rfind("\n", 0, pos) + 1
        prefix = template[line_start:pos]
        return prefix[: len(prefix) - len(prefix.lstrip())]

    def _build_script(self, code: str) -> str:
        """Splice ``code`` into the task's test template at the ``{code}`` placeholder.

        Every line after the first is prefixed with the placeholder's own
        indentation, so multi-line code lands at the nesting level the
        template expects (the first line inherits it from the template).
        """
        template = self.task.test_code_template
        indent = self._placeholder_indent(template)
        indented_code = ("\n" + indent).join(code.split("\n"))
        return template.replace(self._CODE_PLACEHOLDER, indented_code)

    @staticmethod
    def _parse_report(stdout: str) -> Tuple[int, int, list]:
        """Extract ``(passed, total, errors)`` from the last line of the test output.

        The test wrapper is expected to print a JSON object as its final line,
        with optional ``passed`` (default 0), ``total`` (default 1) and
        ``errors`` (default ``[]``) keys.

        Raises:
            ValueError: the last line is missing, not JSON, or not an object,
                or a count is not numeric.
            TypeError: a field has an unusable type (e.g. ``null`` count).
        """
        lines = stdout.strip().splitlines()
        data = json.loads(lines[-1] if lines else "")
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        passed = int(data.get("passed", 0))
        total = int(data.get("total", 1))
        errors = [str(err) for err in data.get("errors", [])]
        return passed, total, errors

    def _run_tests(self, code: str) -> Tuple[bool, int, int, str, str]:
        """Execute ``code`` inside the test template in a child Python process.

        Returns ``(success, tests_passed, total_tests, error_msg, stdout)``.
        Timeouts, crashes and unparsable output are reported through
        ``error_msg`` instead of being raised, so the agent sees them as
        feedback.
        """
        script = self._build_script(code)
        success = False
        error_msg = ""
        stdout = ""
        tests_passed = 0
        total_tests = 1
        script_path = None
        try:
            # Unique file per run so environments sharing a working directory
            # cannot overwrite or delete each other's script. It stays in the
            # cwd so the child's sys.path[0] matches a plain "temp_test.py" run.
            fd, script_path = tempfile.mkstemp(prefix="temp_test_", suffix=".py", dir=os.getcwd())
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(script)
            # Use the running interpreter rather than whatever "python" is on PATH.
            result = subprocess.run(
                [sys.executable or "python", script_path],
                capture_output=True,
                text=True,
                timeout=self.timeout_seconds,
            )
            stdout = result.stdout
            if result.returncode == 0:
                try:
                    tests_passed, total_tests, errors = self._parse_report(stdout)
                except (ValueError, TypeError) as parse_e:
                    error_msg = "Parse error in test output: " + str(parse_e) + "\nRaw output: " + stdout
                else:
                    # A run that reports zero tests is not a pass.
                    success = total_tests > 0 and tests_passed == total_tests
                    if not success:
                        error_msg = "\n".join(errors)
            else:
                error_msg = result.stderr or result.stdout
        except subprocess.TimeoutExpired:
            error_msg = f"TimeoutError: Execution exceeded {self.timeout_seconds}s limit. Your algorithm is too slow or has an infinite loop."
        except Exception as e:
            # Boundary: any other failure (e.g. OSError creating the file or
            # launching the interpreter) is surfaced to the agent as feedback.
            error_msg = str(e)
        finally:
            if script_path is not None and os.path.exists(script_path):
                os.remove(script_path)
        return success, tests_passed, total_tests, error_msg, stdout