Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- README.md +1 -0
- inference.py +7 -4
- pytest.ini +1 -0
- supportdesk_env/client.py +1 -1
- supportdesk_env/graders.py +10 -2
- supportdesk_env/server/app.py +37 -0
- supportdesk_env/server/supportdesk_environment.py +97 -70
- tests/test_supportdesk.py +36 -2
README.md
CHANGED
|
@@ -380,3 +380,4 @@ Current deterministic fallback baseline:
|
|
| 380 |
- average: `1.00`
|
| 381 |
|
| 382 |
These scores are intentionally reproducible. The fallback policy exists to show that the environment, reward shaping, and graders all work end to end. Model-backed runs can be lower, which is useful for evaluation.
|
|
|
|
|
|
| 380 |
- average: `1.00`
|
| 381 |
|
| 382 |
These scores are intentionally reproducible. The fallback policy exists to show that the environment, reward shaping, and graders all work end to end. Model-backed runs can be lower, which is useful for evaluation.
|
| 383 |
+
|
inference.py
CHANGED
|
@@ -24,7 +24,7 @@ from supportdesk_env.tasks import get_task, list_task_ids
|
|
| 24 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 25 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 26 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 27 |
-
API_KEY = HF_TOKEN
|
| 28 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 29 |
TASK_NAME = os.getenv("SUPPORTDESK_TASK_ID", "billing_refund_easy")
|
| 30 |
BENCHMARK = os.getenv("SUPPORTDESK_BENCHMARK", "supportdesk_env")
|
|
@@ -172,10 +172,11 @@ def _log_step(step: int, action_str: str, reward: float, done: bool, error: str
|
|
| 172 |
)
|
| 173 |
|
| 174 |
|
| 175 |
-
def _log_end(success: bool, steps: int, rewards: list[float]) -> None:
|
| 176 |
reward_text = ",".join(f"{reward:.2f}" for reward in rewards)
|
| 177 |
print(
|
| 178 |
-
f"[END] success={str(success).lower()} steps={steps}
|
|
|
|
| 179 |
flush=True,
|
| 180 |
)
|
| 181 |
|
|
@@ -268,6 +269,7 @@ async def _run_docker_episode(task_id: str, client: OpenAI | None) -> EpisodeRes
|
|
| 268 |
async def main() -> None:
|
| 269 |
client = _build_client()
|
| 270 |
success = False
|
|
|
|
| 271 |
steps_taken = 0
|
| 272 |
rewards: list[float] = []
|
| 273 |
|
|
@@ -278,11 +280,12 @@ async def main() -> None:
|
|
| 278 |
episode = await _run_docker_episode(TASK_NAME, client)
|
| 279 |
else:
|
| 280 |
episode = _run_local_episode(TASK_NAME, client)
|
|
|
|
| 281 |
success = episode.final_score >= SUCCESS_SCORE_THRESHOLD
|
| 282 |
steps_taken = episode.steps_taken
|
| 283 |
rewards = episode.rewards
|
| 284 |
finally:
|
| 285 |
-
_log_end(success=success, steps=steps_taken, rewards=rewards)
|
| 286 |
|
| 287 |
|
| 288 |
if __name__ == "__main__":
|
|
|
|
| 24 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 25 |
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
|
| 26 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 27 |
+
API_KEY = HF_TOKEN or os.getenv("API_KEY")
|
| 28 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 29 |
TASK_NAME = os.getenv("SUPPORTDESK_TASK_ID", "billing_refund_easy")
|
| 30 |
BENCHMARK = os.getenv("SUPPORTDESK_BENCHMARK", "supportdesk_env")
|
|
|
|
| 172 |
)
|
| 173 |
|
| 174 |
|
| 175 |
+
def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
|
| 176 |
reward_text = ",".join(f"{reward:.2f}" for reward in rewards)
|
| 177 |
print(
|
| 178 |
+
f"[END] success={str(success).lower()} steps={steps} "
|
| 179 |
+
f"score={score:.3f} rewards={reward_text}",
|
| 180 |
flush=True,
|
| 181 |
)
|
| 182 |
|
|
|
|
| 269 |
async def main() -> None:
|
| 270 |
client = _build_client()
|
| 271 |
success = False
|
| 272 |
+
final_score = 0.0
|
| 273 |
steps_taken = 0
|
| 274 |
rewards: list[float] = []
|
| 275 |
|
|
|
|
| 280 |
episode = await _run_docker_episode(TASK_NAME, client)
|
| 281 |
else:
|
| 282 |
episode = _run_local_episode(TASK_NAME, client)
|
| 283 |
+
final_score = episode.final_score
|
| 284 |
success = episode.final_score >= SUCCESS_SCORE_THRESHOLD
|
| 285 |
steps_taken = episode.steps_taken
|
| 286 |
rewards = episode.rewards
|
| 287 |
finally:
|
| 288 |
+
_log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
|
| 289 |
|
| 290 |
|
| 291 |
if __name__ == "__main__":
|
pytest.ini
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
[pytest]
|
| 2 |
addopts = -p no:cacheprovider
|
| 3 |
testpaths = tests
|
|
|
|
|
|
| 1 |
[pytest]
|
| 2 |
addopts = -p no:cacheprovider
|
| 3 |
testpaths = tests
|
| 4 |
+
pythonpath = .
|
supportdesk_env/client.py
CHANGED
|
@@ -30,9 +30,9 @@ class SupportDeskEnv(EnvClient[SupportDeskAction, SupportDeskObservation, Suppor
|
|
| 30 |
|
| 31 |
def _parse_result(self, payload) -> StepResult[SupportDeskObservation]:
|
| 32 |
observation = _validate(SupportDeskObservation, payload["observation"])
|
|
|
|
| 33 |
return StepResult(
|
| 34 |
observation=observation,
|
| 35 |
reward=payload["reward"],
|
| 36 |
done=payload["done"],
|
| 37 |
-
info=payload.get("info", {}),
|
| 38 |
)
|
|
|
|
| 30 |
|
| 31 |
def _parse_result(self, payload) -> StepResult[SupportDeskObservation]:
|
| 32 |
observation = _validate(SupportDeskObservation, payload["observation"])
|
| 33 |
+
# OpenEnv StepResult only accepts observation/reward/done in this runtime.
|
| 34 |
return StepResult(
|
| 35 |
observation=observation,
|
| 36 |
reward=payload["reward"],
|
| 37 |
done=payload["done"],
|
|
|
|
| 38 |
)
|
supportdesk_env/graders.py
CHANGED
|
@@ -8,6 +8,8 @@ from dataclasses import dataclass
|
|
| 8 |
from supportdesk_env.models import SupportCaseProgress
|
| 9 |
from supportdesk_env.tasks import SupportTaskSpec, get_task
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
@dataclass(frozen=True)
|
| 13 |
class GradeBreakdown:
|
|
@@ -70,8 +72,14 @@ def _reply_penalty(case: SupportCaseProgress, task: SupportTaskSpec) -> float:
|
|
| 70 |
return 0.0 if not any(_normalize(marker) in text for marker in task.forbidden_reply_markers) else 0.5
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
def grade_case(task: SupportTaskSpec, case: SupportCaseProgress) -> GradeBreakdown:
|
| 74 |
-
"""Score a case
|
| 75 |
|
| 76 |
queue_score = 1.0 if case.queue == task.gold_queue else 0.0
|
| 77 |
priority_score = 1.0 if case.priority == task.gold_priority else 0.0
|
|
@@ -112,7 +120,7 @@ def grade_case(task: SupportTaskSpec, case: SupportCaseProgress) -> GradeBreakdo
|
|
| 112 |
milestones.append("resolution_code")
|
| 113 |
|
| 114 |
return GradeBreakdown(
|
| 115 |
-
total_score=round(weighted_total, 4),
|
| 116 |
queue_score=queue_score,
|
| 117 |
priority_score=priority_score,
|
| 118 |
issue_type_score=issue_type_score,
|
|
|
|
| 8 |
from supportdesk_env.models import SupportCaseProgress
|
| 9 |
from supportdesk_env.tasks import SupportTaskSpec, get_task
|
| 10 |
|
| 11 |
+
STRICT_SCORE_EPSILON = 0.001
|
| 12 |
+
|
| 13 |
|
| 14 |
@dataclass(frozen=True)
|
| 15 |
class GradeBreakdown:
|
|
|
|
| 72 |
return 0.0 if not any(_normalize(marker) in text for marker in task.forbidden_reply_markers) else 0.5
|
| 73 |
|
| 74 |
|
| 75 |
+
def _strict_open_unit_interval(score: float) -> float:
|
| 76 |
+
"""Keep final task scores strictly within (0, 1) for evaluator compatibility."""
|
| 77 |
+
|
| 78 |
+
return min(1.0 - STRICT_SCORE_EPSILON, max(STRICT_SCORE_EPSILON, score))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
def grade_case(task: SupportTaskSpec, case: SupportCaseProgress) -> GradeBreakdown:
|
| 82 |
+
"""Score a case deterministically with total_score strictly inside (0, 1)."""
|
| 83 |
|
| 84 |
queue_score = 1.0 if case.queue == task.gold_queue else 0.0
|
| 85 |
priority_score = 1.0 if case.priority == task.gold_priority else 0.0
|
|
|
|
| 120 |
milestones.append("resolution_code")
|
| 121 |
|
| 122 |
return GradeBreakdown(
|
| 123 |
+
total_score=round(_strict_open_unit_interval(weighted_total), 4),
|
| 124 |
queue_score=queue_score,
|
| 125 |
priority_score=priority_score,
|
| 126 |
issue_type_score=issue_type_score,
|
supportdesk_env/server/app.py
CHANGED
|
@@ -35,6 +35,7 @@ import os
|
|
| 35 |
from typing import Any
|
| 36 |
|
| 37 |
import uvicorn
|
|
|
|
| 38 |
|
| 39 |
try:
|
| 40 |
from openenv.core.env_server import http_server as openenv_http_server
|
|
@@ -103,6 +104,42 @@ def list_tasks() -> dict[str, Any]:
|
|
| 103 |
}
|
| 104 |
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
|
| 107 |
"""
|
| 108 |
Entry point for direct execution via uv run or python -m.
|
|
|
|
| 35 |
from typing import Any
|
| 36 |
|
| 37 |
import uvicorn
|
| 38 |
+
from fastapi import Body, HTTPException
|
| 39 |
|
| 40 |
try:
|
| 41 |
from openenv.core.env_server import http_server as openenv_http_server
|
|
|
|
| 104 |
}
|
| 105 |
|
| 106 |
|
| 107 |
+
@app.get("/episodes/{episode_id}/state", response_model=SupportDeskState)
|
| 108 |
+
def get_episode_state(episode_id: str) -> SupportDeskState:
|
| 109 |
+
"""Optional explicit state helper for robust episode-addressable inspection."""
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
return SupportDeskEnvironment.state_for_episode(episode_id)
|
| 113 |
+
except ValueError as exc:
|
| 114 |
+
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@app.post("/episodes/{episode_id}/step")
|
| 118 |
+
def step_episode(
|
| 119 |
+
episode_id: str,
|
| 120 |
+
payload: dict[str, Any] = Body(...),
|
| 121 |
+
) -> dict[str, Any]:
|
| 122 |
+
"""Optional explicit step helper that does not require sticky request context."""
|
| 123 |
+
|
| 124 |
+
action_payload = payload.get("action")
|
| 125 |
+
if not isinstance(action_payload, dict):
|
| 126 |
+
raise HTTPException(status_code=422, detail="Request body must include an 'action' object.")
|
| 127 |
+
|
| 128 |
+
timeout_s = payload.get("timeout_s")
|
| 129 |
+
try:
|
| 130 |
+
action = SupportDeskAction.model_validate(action_payload)
|
| 131 |
+
env = SupportDeskEnvironment()
|
| 132 |
+
observation = env.step(action, timeout_s=timeout_s, episode_id=episode_id)
|
| 133 |
+
except ValueError as exc:
|
| 134 |
+
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
| 135 |
+
|
| 136 |
+
return {
|
| 137 |
+
"observation": observation.model_dump(),
|
| 138 |
+
"reward": observation.reward,
|
| 139 |
+
"done": observation.done,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
|
| 144 |
"""
|
| 145 |
Entry point for direct execution via uv run or python -m.
|
supportdesk_env/server/supportdesk_environment.py
CHANGED
|
@@ -6,6 +6,7 @@ import os
|
|
| 6 |
import threading
|
| 7 |
import uuid
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
from supportdesk_env.graders import grade_case
|
| 11 |
from supportdesk_env.models import (
|
|
@@ -33,19 +34,11 @@ class SupportDeskEnvironment(
|
|
| 33 |
):
|
| 34 |
"""A realistic customer support triage environment with dense rewards."""
|
| 35 |
|
| 36 |
-
_state_lock = threading.RLock()
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
_shared_last_feedback = ""
|
| 42 |
-
_shared_history: list[ActionHistoryEntry] = []
|
| 43 |
-
_shared_case = SupportCaseProgress()
|
| 44 |
-
_shared_episode_id: str | None = None
|
| 45 |
-
_shared_score = 0.0
|
| 46 |
-
_shared_completed_milestones: list[str] = []
|
| 47 |
-
_shared_current_sla_minutes_remaining: int | None = None
|
| 48 |
-
_shared_reset_counter = 0
|
| 49 |
|
| 50 |
def __init__(self, task_id: str | None = None):
|
| 51 |
super().__init__()
|
|
@@ -65,69 +58,93 @@ class SupportDeskEnvironment(
|
|
| 65 |
initial_grade = grade_case(self.task, self._case)
|
| 66 |
self._score = initial_grade.total_score
|
| 67 |
self._completed_milestones = list(initial_grade.completed_milestones)
|
| 68 |
-
self._ensure_shared_state(self.task)
|
| 69 |
|
| 70 |
@classmethod
|
| 71 |
-
def
|
| 72 |
-
cls,
|
| 73 |
-
task: SupportTaskSpec,
|
| 74 |
-
*,
|
| 75 |
-
episode_id: str | None = None,
|
| 76 |
-
) -> None:
|
| 77 |
initial_case = SupportCaseProgress()
|
| 78 |
initial_grade = grade_case(task, initial_case)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
)
|
| 86 |
-
cls._shared_history = []
|
| 87 |
-
cls._shared_case = initial_case
|
| 88 |
-
cls._shared_episode_id = episode_id
|
| 89 |
-
cls._shared_score = initial_grade.total_score
|
| 90 |
-
cls._shared_completed_milestones = list(initial_grade.completed_milestones)
|
| 91 |
-
cls._shared_current_sla_minutes_remaining = task.ticket.sla_minutes_remaining
|
| 92 |
|
| 93 |
@classmethod
|
| 94 |
-
def
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
|
| 100 |
-
task = get_task(self.__class__._shared_task_id or self.task.task_id)
|
| 101 |
self.task = task
|
| 102 |
-
self._max_steps =
|
| 103 |
-
self._step_count =
|
| 104 |
-
self._reward_total =
|
| 105 |
-
self._done =
|
| 106 |
-
self._last_feedback =
|
| 107 |
-
self._history = [entry.model_copy(deep=True) for entry in
|
| 108 |
-
self._case =
|
| 109 |
-
self._episode_id =
|
| 110 |
-
self._score =
|
| 111 |
-
self._completed_milestones = list(
|
| 112 |
-
self._current_sla_minutes_remaining =
|
| 113 |
-
|
| 114 |
-
def
|
| 115 |
-
self.
|
| 116 |
-
|
| 117 |
-
self.__class__.
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
@property
|
| 128 |
def state(self) -> SupportDeskState:
|
| 129 |
with self.__class__._state_lock:
|
| 130 |
-
self.
|
| 131 |
return SupportDeskState(
|
| 132 |
episode_id=self._episode_id,
|
| 133 |
task_id=self.task.task_id,
|
|
@@ -160,21 +177,23 @@ class SupportDeskEnvironment(
|
|
| 160 |
self.__class__._shared_reset_counter += 1
|
| 161 |
self.task = get_task(next_task_id)
|
| 162 |
self._max_steps = self.task.max_steps
|
| 163 |
-
self.
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
self.
|
|
|
|
| 168 |
return self._build_observation(reward=0.0, done=False)
|
| 169 |
|
| 170 |
def step(
|
| 171 |
self,
|
| 172 |
action: SupportDeskAction,
|
| 173 |
timeout_s: float | None = None,
|
|
|
|
| 174 |
**kwargs,
|
| 175 |
) -> SupportDeskObservation:
|
| 176 |
with self.__class__._state_lock:
|
| 177 |
-
self.
|
| 178 |
|
| 179 |
if self._done:
|
| 180 |
return self._build_observation(
|
|
@@ -227,10 +246,18 @@ class SupportDeskEnvironment(
|
|
| 227 |
reward_delta=reward,
|
| 228 |
)
|
| 229 |
)
|
| 230 |
-
self.
|
| 231 |
|
| 232 |
return self._build_observation(reward=reward, done=self._done)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
def close(self) -> None:
|
| 235 |
"""No-op close hook for compatibility with local scripts."""
|
| 236 |
|
|
|
|
| 6 |
import threading
|
| 7 |
import uuid
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import ClassVar
|
| 10 |
|
| 11 |
from supportdesk_env.graders import grade_case
|
| 12 |
from supportdesk_env.models import (
|
|
|
|
| 34 |
):
|
| 35 |
"""A realistic customer support triage environment with dense rewards."""
|
| 36 |
|
| 37 |
+
_state_lock: ClassVar[threading.RLock] = threading.RLock()
|
| 38 |
+
_episode_store: ClassVar[dict[str, SupportDeskState]] = {}
|
| 39 |
+
_episode_task_ids: ClassVar[dict[str, str]] = {}
|
| 40 |
+
_latest_episode_id: ClassVar[str | None] = None
|
| 41 |
+
_shared_reset_counter: ClassVar[int] = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def __init__(self, task_id: str | None = None):
|
| 44 |
super().__init__()
|
|
|
|
| 58 |
initial_grade = grade_case(self.task, self._case)
|
| 59 |
self._score = initial_grade.total_score
|
| 60 |
self._completed_milestones = list(initial_grade.completed_milestones)
|
|
|
|
| 61 |
|
| 62 |
@classmethod
|
| 63 |
+
def _build_initial_state(cls, task: SupportTaskSpec, episode_id: str) -> SupportDeskState:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
initial_case = SupportCaseProgress()
|
| 65 |
initial_grade = grade_case(task, initial_case)
|
| 66 |
+
return SupportDeskState(
|
| 67 |
+
episode_id=episode_id,
|
| 68 |
+
task_id=task.task_id,
|
| 69 |
+
difficulty=task.difficulty,
|
| 70 |
+
step_count=0,
|
| 71 |
+
reward=0.0,
|
| 72 |
+
done=False,
|
| 73 |
+
current_score=initial_grade.total_score,
|
| 74 |
+
max_steps=task.max_steps,
|
| 75 |
+
case=initial_case,
|
| 76 |
+
current_sla_minutes_remaining=task.ticket.sla_minutes_remaining,
|
| 77 |
+
workflow_stage="intake",
|
| 78 |
+
required_next_actions=["classify"],
|
| 79 |
+
risk_flags=[],
|
| 80 |
+
action_history=[],
|
| 81 |
+
completed_milestones=list(initial_grade.completed_milestones),
|
| 82 |
+
last_feedback="New case loaded. Review the ticket and policy snippets before acting.",
|
| 83 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
@classmethod
|
| 86 |
+
def _extract_episode_id(cls, episode_id: str | None = None, **kwargs) -> str | None:
|
| 87 |
+
if episode_id:
|
| 88 |
+
return episode_id
|
| 89 |
+
for key in ("episode_id", "request_id"):
|
| 90 |
+
value = kwargs.get(key)
|
| 91 |
+
if isinstance(value, str) and value:
|
| 92 |
+
return value
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
def _load_episode(self, episode_id: str | None = None, **kwargs) -> None:
|
| 96 |
+
resolved_episode_id = self._extract_episode_id(episode_id, **kwargs) or self.__class__._latest_episode_id
|
| 97 |
+
if not resolved_episode_id:
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
episode_state = self.__class__._episode_store.get(resolved_episode_id)
|
| 101 |
+
if episode_state is None:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
f"Unknown episode_id '{resolved_episode_id}'. Call reset() first or provide a valid episode_id."
|
| 104 |
+
)
|
| 105 |
|
| 106 |
+
task = get_task(self.__class__._episode_task_ids.get(resolved_episode_id, episode_state.task_id))
|
|
|
|
| 107 |
self.task = task
|
| 108 |
+
self._max_steps = episode_state.max_steps
|
| 109 |
+
self._step_count = episode_state.step_count
|
| 110 |
+
self._reward_total = episode_state.reward
|
| 111 |
+
self._done = episode_state.done
|
| 112 |
+
self._last_feedback = episode_state.last_feedback
|
| 113 |
+
self._history = [entry.model_copy(deep=True) for entry in episode_state.action_history]
|
| 114 |
+
self._case = episode_state.case.model_copy(deep=True)
|
| 115 |
+
self._episode_id = resolved_episode_id
|
| 116 |
+
self._score = episode_state.current_score
|
| 117 |
+
self._completed_milestones = list(episode_state.completed_milestones)
|
| 118 |
+
self._current_sla_minutes_remaining = episode_state.current_sla_minutes_remaining
|
| 119 |
+
|
| 120 |
+
def _persist_episode(self) -> None:
|
| 121 |
+
if self._episode_id is None:
|
| 122 |
+
return
|
| 123 |
+
self.__class__._episode_store[self._episode_id] = SupportDeskState(
|
| 124 |
+
episode_id=self._episode_id,
|
| 125 |
+
task_id=self.task.task_id,
|
| 126 |
+
difficulty=self.task.difficulty,
|
| 127 |
+
step_count=self._step_count,
|
| 128 |
+
reward=round(self._reward_total, 4),
|
| 129 |
+
done=self._done,
|
| 130 |
+
current_score=round(self._score, 4),
|
| 131 |
+
max_steps=self._max_steps,
|
| 132 |
+
case=self._case.model_copy(deep=True),
|
| 133 |
+
current_sla_minutes_remaining=self._current_sla_minutes_remaining,
|
| 134 |
+
workflow_stage=self._workflow_stage(),
|
| 135 |
+
required_next_actions=self._required_next_actions(),
|
| 136 |
+
risk_flags=self._risk_flags(),
|
| 137 |
+
action_history=[entry.model_copy(deep=True) for entry in self._history],
|
| 138 |
+
completed_milestones=list(self._completed_milestones),
|
| 139 |
+
last_feedback=self._last_feedback,
|
| 140 |
+
)
|
| 141 |
+
self.__class__._episode_task_ids[self._episode_id] = self.task.task_id
|
| 142 |
+
self.__class__._latest_episode_id = self._episode_id
|
| 143 |
|
| 144 |
@property
|
| 145 |
def state(self) -> SupportDeskState:
|
| 146 |
with self.__class__._state_lock:
|
| 147 |
+
self._load_episode()
|
| 148 |
return SupportDeskState(
|
| 149 |
episode_id=self._episode_id,
|
| 150 |
task_id=self.task.task_id,
|
|
|
|
| 177 |
self.__class__._shared_reset_counter += 1
|
| 178 |
self.task = get_task(next_task_id)
|
| 179 |
self._max_steps = self.task.max_steps
|
| 180 |
+
self._episode_id = episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}"
|
| 181 |
+
initial_state = self.__class__._build_initial_state(self.task, self._episode_id)
|
| 182 |
+
self.__class__._episode_store[self._episode_id] = initial_state
|
| 183 |
+
self.__class__._episode_task_ids[self._episode_id] = self.task.task_id
|
| 184 |
+
self.__class__._latest_episode_id = self._episode_id
|
| 185 |
+
self._load_episode(self._episode_id)
|
| 186 |
return self._build_observation(reward=0.0, done=False)
|
| 187 |
|
| 188 |
def step(
|
| 189 |
self,
|
| 190 |
action: SupportDeskAction,
|
| 191 |
timeout_s: float | None = None,
|
| 192 |
+
episode_id: str | None = None,
|
| 193 |
**kwargs,
|
| 194 |
) -> SupportDeskObservation:
|
| 195 |
with self.__class__._state_lock:
|
| 196 |
+
self._load_episode(episode_id, **kwargs)
|
| 197 |
|
| 198 |
if self._done:
|
| 199 |
return self._build_observation(
|
|
|
|
| 246 |
reward_delta=reward,
|
| 247 |
)
|
| 248 |
)
|
| 249 |
+
self._persist_episode()
|
| 250 |
|
| 251 |
return self._build_observation(reward=reward, done=self._done)
|
| 252 |
|
| 253 |
+
@classmethod
|
| 254 |
+
def state_for_episode(cls, episode_id: str) -> SupportDeskState:
|
| 255 |
+
with cls._state_lock:
|
| 256 |
+
state = cls._episode_store.get(episode_id)
|
| 257 |
+
if state is None:
|
| 258 |
+
raise ValueError(f"Unknown episode_id '{episode_id}'. Call reset() first.")
|
| 259 |
+
return state.model_copy(deep=True)
|
| 260 |
+
|
| 261 |
def close(self) -> None:
|
| 262 |
"""No-op close hook for compatibility with local scripts."""
|
| 263 |
|
tests/test_supportdesk.py
CHANGED
|
@@ -66,7 +66,7 @@ def test_perfect_solution_grades_full_score():
|
|
| 66 |
)
|
| 67 |
|
| 68 |
breakdown = grade_case(task, env.state.case)
|
| 69 |
-
assert breakdown.total_score ==
|
| 70 |
|
| 71 |
|
| 72 |
def test_max_steps_ends_episode():
|
|
@@ -83,7 +83,7 @@ def test_grade_is_bounded_between_zero_and_one():
|
|
| 83 |
env = SupportDeskEnvironment(task_id=task.task_id)
|
| 84 |
env.reset()
|
| 85 |
breakdown = grade_case(task, env.state.case)
|
| 86 |
-
assert 0.0 <
|
| 87 |
|
| 88 |
|
| 89 |
def test_state_includes_episode_id_after_reset():
|
|
@@ -167,3 +167,37 @@ def test_http_reset_step_state_are_session_consistent():
|
|
| 167 |
assert state_payload["case"]["queue"] == "billing_ops"
|
| 168 |
assert state_payload["case"]["priority"] == "high"
|
| 169 |
assert state_payload["case"]["issue_type"] == "duplicate_charge"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
breakdown = grade_case(task, env.state.case)
|
| 69 |
+
assert breakdown.total_score == 0.999
|
| 70 |
|
| 71 |
|
| 72 |
def test_max_steps_ends_episode():
|
|
|
|
| 83 |
env = SupportDeskEnvironment(task_id=task.task_id)
|
| 84 |
env.reset()
|
| 85 |
breakdown = grade_case(task, env.state.case)
|
| 86 |
+
assert 0.0 < breakdown.total_score < 1.0
|
| 87 |
|
| 88 |
|
| 89 |
def test_state_includes_episode_id_after_reset():
|
|
|
|
| 167 |
assert state_payload["case"]["queue"] == "billing_ops"
|
| 168 |
assert state_payload["case"]["priority"] == "high"
|
| 169 |
assert state_payload["case"]["issue_type"] == "duplicate_charge"
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@pytest.mark.skipif(TestClient is None, reason="httpx is not installed for FastAPI TestClient")
|
| 173 |
+
def test_http_explicit_episode_helpers_work():
|
| 174 |
+
from supportdesk_env.server.app import app
|
| 175 |
+
|
| 176 |
+
client = TestClient(app)
|
| 177 |
+
|
| 178 |
+
episode_id = "explicit-http-episode"
|
| 179 |
+
reset_response = client.post("/reset", json={"episode_id": episode_id})
|
| 180 |
+
assert reset_response.status_code == 200
|
| 181 |
+
|
| 182 |
+
step_response = client.post(
|
| 183 |
+
f"/episodes/{episode_id}/step",
|
| 184 |
+
json={
|
| 185 |
+
"action": {
|
| 186 |
+
"operation": "classify",
|
| 187 |
+
"queue": "billing_ops",
|
| 188 |
+
"priority": "high",
|
| 189 |
+
"issue_type": "duplicate_charge",
|
| 190 |
+
}
|
| 191 |
+
},
|
| 192 |
+
)
|
| 193 |
+
assert step_response.status_code == 200
|
| 194 |
+
|
| 195 |
+
state_response = client.get(f"/episodes/{episode_id}/state")
|
| 196 |
+
assert state_response.status_code == 200
|
| 197 |
+
state_payload = state_response.json()
|
| 198 |
+
|
| 199 |
+
assert state_payload["episode_id"] == episode_id
|
| 200 |
+
assert state_payload["step_count"] == 1
|
| 201 |
+
assert state_payload["case"]["queue"] == "billing_ops"
|
| 202 |
+
assert state_payload["case"]["priority"] == "high"
|
| 203 |
+
assert state_payload["case"]["issue_type"] == "duplicate_charge"
|