# codesensei-env / env / server / test_runner.py
# Author: vineetshukla.work@gmail.com
# feat: CodeSensei - GRPO-trained LLM code debugger on OpenEnv
# Commit: c47c81c
"""
CodeSensei — Test Runner.
Runs individual test cases against a function and returns structured
TestResult objects. Handles test isolation and error capture.
"""
from __future__ import annotations
import traceback
from typing import List, Tuple
from env.models import TestResult
from env.server.sandbox import run_function_with_tests, check_syntax
def run_tests(
    function_code: str,
    test_cases: List[dict],
    timeout: int = 5,
) -> Tuple[List[TestResult], int, int, str]:
    """Run all test cases against a function and return structured results.

    Each test case dict has:
        - "name": str — test description
        - "code": str — Python assert statement(s) calling the function

    Strategy: reject syntactically-invalid code up front, then try all
    tests in a single batch (fast path); only if the batch fails are
    tests re-run one by one to pinpoint the failures.

    Args:
        function_code: The Python function source code.
        test_cases: List of test case dicts.
        timeout: Max execution time per test batch.

    Returns:
        Tuple of (test_results, passed_count, total_count, raw_error_output).
    """
    # Guard: code that doesn't even parse fails every test with the
    # same syntax error — no point running anything.
    syntax_ok, syntax_msg = check_syntax(function_code)
    if not syntax_ok:
        failures = [
            TestResult(test_name=case["name"], passed=False, error_message=syntax_msg)
            for case in test_cases
        ]
        return failures, 0, len(test_cases), syntax_msg

    n_cases = len(test_cases)

    # Fast path: execute every test in one subprocess call.
    batch_code = "\n".join(
        f"# Test: {tc['name']}\n{tc['code']}" for tc in test_cases
    )
    batch_out, _batch_err, batch_ok = run_function_with_tests(
        function_code, batch_code, timeout
    )
    if batch_ok and "ALL_TESTS_PASSED" in batch_out:
        all_passed = [
            TestResult(test_name=tc["name"], passed=True) for tc in test_cases
        ]
        return all_passed, n_cases, n_cases, ""

    # Slow path: at least one test failed — rerun individually so each
    # failing test gets its own error message.
    outcomes: List[TestResult] = []
    passed_count = 0
    failure_lines: List[str] = []
    for tc in test_cases:
        out_i, err_i, ok_i = run_function_with_tests(
            function_code, tc["code"], timeout
        )
        if ok_i and "ALL_TESTS_PASSED" in out_i:
            outcomes.append(TestResult(test_name=tc["name"], passed=True))
            passed_count += 1
            continue
        # Distill the raw traceback down to its most informative line.
        message = _extract_error(err_i)
        outcomes.append(
            TestResult(test_name=tc["name"], passed=False, error_message=message)
        )
        failure_lines.append(f"[{tc['name']}] {message}")
    return outcomes, passed_count, n_cases, "\n".join(failure_lines)
def _extract_error(stderr: str) -> str:
"""Extract the most meaningful error line from stderr.
Args:
stderr: Raw stderr output from subprocess.
Returns:
Cleaned error message string (single line or short).
"""
if not stderr:
return "Unknown error (no output)"
lines = stderr.strip().split("\n")
# Look for the last line that starts with a known error type
for line in reversed(lines):
stripped = line.strip()
if any(
stripped.startswith(err)
for err in [
"AssertionError",
"AssertionError",
"TypeError",
"ValueError",
"NameError",
"IndexError",
"KeyError",
"AttributeError",
"ZeroDivisionError",
"RecursionError",
"RuntimeError",
"StopIteration",
"SyntaxError",
"IndentationError",
"AssertionError",
]
):
return stripped
# Fallback: last non-empty line, truncated
for line in reversed(lines):
if line.strip():
return line.strip()[:200]
return "Unknown error"