# disk-panic-openenv / inference.py
# (origin: Hugging Face Space "disk-panic-openenv", commit da9e926 —
#  "Use LOCAL_IMAGE_NAME per spec")
"""
DiskPanic Inference Script
==========================
Runs all 3 tasks (easy, medium, hard) sequentially against the DiskPanic
OpenEnv, using an OpenAI-compatible LLM as the SRE agent.
Required environment variables:
API_BASE_URL The LLM endpoint (OpenAI-compatible)
MODEL_NAME The model id to use
HF_TOKEN API key for the LLM provider
LOCAL_IMAGE_NAME (optional) Docker image for the env server
Default: disk-panic:latest
Stdout format (one per episode):
[START] task=<task> env=disk_panic model=<model>
[STEP] step=<n> action=<cmd> reward=<0.00> done=<bool> error=<msg|null>
[END] success=<bool> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""
from __future__ import annotations
import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
try:
from disk_panic import DiskPanicAction, DiskPanicEnv
except ImportError:
from client import DiskPanicEnv
from models import DiskPanicAction
# -- config ----------------------------------------------------------------
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "disk-panic:latest")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.groq.com/openai/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "llama-3.3-70b-versatile"
BENCHMARK = "disk_panic"
TASKS = ["easy", "medium", "hard"]
MAX_STEPS = 15
TEMPERATURE = 0.2
MAX_TOKENS = 120
SUCCESS_SCORE_THRESHOLD = 0.6
# System prompt for the SRE agent: defines the simulated-shell command
# palette and demands exactly one bare command per reply (parsed by
# get_next_command). Runtime string — do not edit the content casually.
SYSTEM_PROMPT = textwrap.dedent(
"""
You are an SRE responding to a production incident. A Linux server has a
full root filesystem and (sometimes) a crashed app.service. You must fix it.
COMMAND PALETTE (this is a SIMULATED shell — ONLY these commands work, no
pipes, no subshells, no globs except trailing /*, no flags beyond what's shown):
df show disk usage
ls <path> list a directory
du <path> breakdown of subdir sizes (use this to find the big file!)
cat <path> view a file
find <path> recursive file list
sha256sum <path> hash a file or dir
rm <path> delete a file
rm -rf <path> delete recursively
systemctl is-active <svc> check service state
systemctl restart <svc> restart a service
echo "content" > /path/to/file write a file (needed for logrotate config)
IMPORTANT RULES:
1. NEVER touch /var/log/audit/ — it is business-critical. Touching it caps reward.
2. Get disk usage below 80% (see `df`).
3. For the medium task: also restart app with `systemctl restart app`.
4. For the hard task: ALSO write a logrotate config to /etc/logrotate.d/app
containing both the words "rotate" and "size". Example:
echo "rotate 5 size 100M" > /etc/logrotate.d/app
5. Start with `du /var/log` to see which subdirectory is bloated, then drill down.
6. DO NOT use pipes (|), sort, head, or any other command not in the palette.
7. DO NOT use glob other than trailing /*.
Reply with EXACTLY ONE command on a single line. No markdown, no code fences,
no leading $, no prose, no quotes around the whole line. Just the command.
"""
).strip()
# -- log helpers -----------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the [START] banner announcing a new episode run."""
    fields = f"task={task} env={env} model={model}"
    print(f"[START] {fields}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] record; newlines in the action are flattened to spaces."""
    shown_error = error or "null"
    # Flatten CR/LF so the whole record stays on a single stdout line.
    flat_action = action.translate(str.maketrans({"\n": " ", "\r": " "}))
    record = (
        f"[STEP] step={step} action={flat_action} "
        f"reward={reward:.2f} done={'true' if done else 'false'} error={shown_error}"
    )
    print(record, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the [END] summary line with the full per-step reward trace."""
    formatted = [f"{r:.2f}" for r in rewards]
    summary = (
        f"[END] success={'true' if success else 'false'} "
        f"steps={steps} score={score:.3f} rewards={','.join(formatted)}"
    )
    print(summary, flush=True)
# -- prompt builder --------------------------------------------------------
def build_user_prompt(task: str, step: int, obs_stdout: str, df: str, svc: str,
                      last_error: Optional[str], history: List[str]) -> str:
    """Assemble the per-step user prompt sent to the LLM.

    Includes task/step context, current `df` output, service status, the last
    error (if any), the most recent six commands, and the previous command's
    stdout, then asks for the next single command.
    """
    history_block = "\n".join(history[-6:]) if history else "(no previous commands)"
    err_line = f"Last error: {last_error}" if last_error else "Last error: none"
    # Build the prompt line-by-line instead of textwrap.dedent(f"..."):
    # dedent runs AFTER interpolation, so multi-line values at column 0
    # (df output, command stdout) would defeat the dedent and leave the
    # template's own indentation embedded in the prompt.
    lines = [
        f"Task: {task}",
        f"Step: {step}",
        "Current df -h /:",
        df,
        f"app.service: {svc}",
        err_line,
        "Previous commands:",
        history_block,
        "Last command output:",
        obs_stdout,
        "What is your next single command?",
    ]
    return "\n".join(lines).strip()
def get_next_command(client: OpenAI, task: str, step: int, obs_stdout: str,
                     df: str, svc: str, last_error: Optional[str],
                     history: List[str]) -> str:
    """Ask the LLM for the next shell command.

    Builds the per-step prompt, queries the chat-completions endpoint, and
    sanitizes the reply down to a single bare command line. Falls back to the
    harmless "df" on any API failure or an empty/unusable reply.
    """
    user_prompt = build_user_prompt(task, step, obs_stdout, df, svc, last_error, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
        # Sanitize line by line: drop fence delimiters (including tagged ones
        # like "```bash", which a blanket strip("`") would leave as "bash"),
        # inline backticks, a leading shell prompt "$", and trailing ";"
        # (the old comment promised semicolon stripping but never did it).
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line or line.startswith("```"):
                continue  # blank line or markdown fence — not a command
            line = line.strip("`").strip()
            if line.startswith("$"):
                line = line[1:].lstrip()
            line = line.rstrip(";").rstrip()
            if line:
                return line
        return "df"  # safe no-op fallback when the reply had no usable line
    except Exception as exc:
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "df"
# -- episode runner --------------------------------------------------------
async def run_episode(client: OpenAI, env: DiskPanicEnv, task: str) -> float:
    """Run one episode of *task* against the env, driven by the LLM.

    Logs [START]/[STEP]/[END] records and returns the episode score: the best
    reward observed, clamped to [0, 1]. The [END] line is emitted even if the
    episode raises mid-way (via the finally block).
    """
    cmd_history: List[str] = []
    reward_trace: List[float] = []
    steps_taken, score, success = 0, 0.0, False
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    try:
        result = await env.reset(task_id=task)
        obs = result.observation
        last_error = obs.last_error
        step = 0
        while step < MAX_STEPS and not result.done:
            step += 1
            command = get_next_command(
                client, task, step, obs.stdout, obs.df_output,
                obs.service_status, last_error, cmd_history,
            )
            result = await env.step(DiskPanicAction(command=command))
            obs = result.observation
            last_error = obs.last_error
            reward = float(result.reward or 0.0)
            reward_trace.append(reward)
            steps_taken = step
            log_step(step=step, action=command, reward=reward,
                     done=bool(result.done), error=last_error)
            cmd_history.append(f" step {step}: {command} -> reward {reward:+.2f}")
        # Reward is the absolute current grade at each step, so the score is
        # the best grade seen (covers timeouts past the best state), clamped.
        best = max(reward_trace, default=0.0)
        score = min(max(best, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=reward_trace)
    return score
async def main() -> None:
    """Spin up the env container, run every task once, then tear it down."""
    llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = await DiskPanicEnv.from_docker_image(LOCAL_IMAGE_NAME)
    try:
        for task_name in TASKS:
            await run_episode(llm, env, task_name)
    finally:
        # Best-effort teardown: never let a close() failure mask results.
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error: {e}", flush=True)


if __name__ == "__main__":
    asyncio.run(main())