Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Dict, List, Optional, Tuple | |
| from .graders import grade_task | |
| from .models import Action, Observation, RewardModel, StateModel, StepInfo, TaskSpec, TicketObservation | |
| from .reward import STEP_PENALTY, build_reward | |
| from .state import discovered_for_ticket, initial_tracking, update_mapping | |
| from .tasks import get_all_tasks, get_task | |
class SupportOpsEnv:
    """OpenEnv-shaped benchmark for support operations workflows.

    An episode works through the tickets of a single task. The agent calls
    :meth:`step` with one action at a time and gets back an observation, a
    shaped reward, a done flag and an info dictionary. The episode ends when
    the agent sends ``finalize`` or when the task's ``max_steps`` budget is
    used up.
    """

    def __init__(self, task_id: Optional[str] = None):
        """Load the task catalogue and start an episode on ``task_id``.

        Args:
            task_id: Task to start with. When omitted (or empty) the first
                task id in sorted order is used.

        Raises:
            KeyError: If ``task_id`` is not in the catalogue.
            IndexError: If no ``task_id`` is given and the catalogue is empty.
        """
        self._tasks = {task.task_id: task for task in get_all_tasks()}
        self._task_order = sorted(self._tasks)
        self._task_id = task_id or self._task_order[0]
        self._task: TaskSpec = self._tasks[self._task_id]
        self._state: StateModel = initial_tracking(self._task)
        # Ticket ids inspected during the current episode. Kept on the
        # environment (not in ``latest_score``) because ``latest_score`` is
        # replaced after every step.
        self._inspected = set()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a fresh episode and return its first observation.

        Args:
            task_id: Task to switch to. When ``None`` the current task is
                restarted.
        """
        if task_id is not None:
            self._task = get_task(task_id)
            self._task_id = task_id
        self._state = initial_tracking(self._task)
        self._inspected = set()
        return self._build_observation()

    def state(self) -> StateModel:
        """Return a deep copy of the tracking state so callers cannot mutate it."""
        return self._state.model_copy(deep=True)

    def step(self, action: Action) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """Apply one agent action and advance the episode.

        Args:
            action: The action to apply; ``action.action_type`` selects the
                handler, ``action.target`` / ``action.value`` carry its
                arguments.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``info`` is
            the dumped ``StepInfo`` for this step.
        """
        if self._state.done:
            # Acting after the episode ended is penalised but does not touch
            # the step counter or the cumulative reward.
            reward = build_reward({"invalid_after_done": -0.1}, "Episode already finished.")
            info = StepInfo(
                task_id=self._task.task_id,
                step_count=self._state.step_count,
                task_score=self._state.latest_score.get("task_score", 0.0),
                done_reason="already_done",
                event="invalid_after_done",
                event_score=reward.components,
            )
            return self._build_observation(), reward, True, info.model_dump()

        self._state.step_count += 1
        event_scores: Dict[str, float] = {"step_penalty": STEP_PENALTY}
        event_name = action.action_type
        done_reason = None

        # Plain equality chain (rather than a dict lookup) so that any type
        # that compares equal to these strings keeps working.
        if action.action_type == "inspect_ticket":
            event_scores.update(self._handle_inspect(action))
        elif action.action_type == "request_context":
            event_scores.update(self._handle_request_context(action))
        elif action.action_type == "set_priority":
            event_scores.update(self._handle_priority(action))
        elif action.action_type == "set_route":
            event_scores.update(self._handle_route(action))
        elif action.action_type == "set_resolution":
            event_scores.update(self._handle_resolution(action))
        elif action.action_type == "escalate":
            event_scores.update(self._handle_escalation(action))
        elif action.action_type == "rank_queue":
            event_scores.update(self._handle_rank_queue(action))
        elif action.action_type == "finalize":
            return self._finalize(event_scores, event_name)
        else:
            event_scores["invalid_action"] = -0.1
            event_name = "invalid_action"

        grade = self._regrade()
        if self._state.step_count >= self._task.max_steps and not self._state.done:
            # Step budget exhausted: end the episode and pay out the grade.
            self._state.done = True
            done_reason = "max_steps"
            event_scores["timeout_grade"] = grade.score

        reward = self._settle(event_scores, f"Processed {event_name}.")
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=grade.score,
            done_reason=done_reason,
            grade=grade if self._state.done else None,
            event=event_name,
            event_score=reward.components,
        )
        return self._build_observation(), reward, self._state.done, info.model_dump()

    def _regrade(self):
        """Grade the current state and cache the scores in ``latest_score``."""
        grade = grade_task(self._task, self._state)
        self._state.latest_score = {"task_score": grade.score, **grade.component_scores}
        return grade

    def _settle(self, event_scores: Dict[str, float], message: str) -> RewardModel:
        """Build the step reward and add its value to the cumulative reward."""
        reward = build_reward(event_scores, message)
        self._state.cumulative_reward = round(self._state.cumulative_reward + reward.value, 4)
        return reward

    def _finalize(
        self, event_scores: Dict[str, float], event_name: str
    ) -> Tuple[Observation, RewardModel, bool, Dict[str, object]]:
        """End the episode at the agent's request and apply the final grade."""
        # Mark the episode done before grading, as the grader receives the state.
        self._state.done = True
        grade = self._regrade()
        event_scores["terminal_grade"] = grade.score
        reward = self._settle(event_scores, "Final task grade applied.")
        info = StepInfo(
            task_id=self._task.task_id,
            step_count=self._state.step_count,
            task_score=grade.score,
            done_reason="agent_finalized",
            grade=grade,
            event=event_name,
            event_score=reward.components,
        )
        return self._build_observation(), reward, True, info.model_dump()

    def _build_observation(self) -> Observation:
        """Assemble the agent-facing view of the task and the current state.

        Only hidden-context entries whose keys have been discovered for a
        ticket are exposed in that ticket's ``discovered_context``.
        """
        tickets: List[TicketObservation] = []
        for ticket in self._task.tickets:
            keys = self._state.discovered_keys.get(ticket.ticket_id, [])
            discovered_context = {key: ticket.hidden_context[key] for key in keys}
            tickets.append(
                TicketObservation(
                    ticket_id=ticket.ticket_id,
                    summary=ticket.summary,
                    visible_context=ticket.visible_context,
                    discovered_context=discovered_context,
                    selected_priority=self._state.priorities.get(ticket.ticket_id),
                    selected_route=self._state.routes.get(ticket.ticket_id),
                    selected_resolution=self._state.resolutions.get(ticket.ticket_id),
                    escalation_team=self._state.escalations.get(ticket.ticket_id),
                )
            )
        return Observation(
            task_id=self._task.task_id,
            difficulty=self._task.difficulty,
            title=self._task.title,
            instruction=self._task.instruction,
            queue_mode=self._task.queue_mode,
            tickets=tickets,
            remaining_steps=max(self._task.max_steps - self._state.step_count, 0),
            available_actions=[
                "inspect_ticket",
                "request_context",
                "set_priority",
                "set_route",
                "set_resolution",
                "escalate",
                "rank_queue",
                "finalize",
            ],
            current_queue_order=self._state.queue_order,
            score_hint=self._state.latest_score,
        )

    def _find_ticket(self, ticket_id: str):
        """Return the current task's ticket with ``ticket_id``, or ``None``."""
        for ticket in self._task.tickets:
            if ticket.ticket_id == ticket_id:
                return ticket
        return None

    def _handle_inspect(self, action: Action) -> Dict[str, float]:
        """Score an ``inspect_ticket`` action.

        The first inspection of a ticket in an episode earns a small bonus;
        inspecting the same ticket again is penalised as redundant.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_ticket": -0.1}
        if ticket.ticket_id in self._inspected:
            return {"redundant_inspect": -0.03}
        self._inspected.add(ticket.ticket_id)
        return {"inspect": 0.03}

    def _handle_request_context(self, action: Action) -> Dict[str, float]:
        """Score a ``request_context`` action and record the discovered key.

        ``action.value`` names the hidden-context key being requested.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_context_request": -0.1}
        if action.value not in ticket.hidden_context:
            return {"unknown_context_key": -0.08}
        # NOTE(review): relies on discovered_for_ticket returning the list
        # stored in the state, so the append below persists -- confirm.
        discovered = discovered_for_ticket(self._state.discovered_keys, ticket.ticket_id)
        if action.value in discovered:
            return {"redundant_context_request": -0.05}
        discovered.append(action.value)
        if action.value in ticket.required_context:
            return {"required_context_found": 0.12}
        return {"optional_context_found": 0.04}

    def _handle_priority(self, action: Action) -> Dict[str, float]:
        """Record the chosen priority and score it against the gold priority."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_priority": -0.1}
        current = self._state.priorities.get(ticket.ticket_id)
        update_mapping(self._state.priorities, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_priority": -0.03}
        return {"priority_match": 0.08 if action.value == ticket.gold_priority else -0.04}

    def _handle_route(self, action: Action) -> Dict[str, float]:
        """Record the chosen route and score it against the gold route."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_route": -0.1}
        current = self._state.routes.get(ticket.ticket_id)
        update_mapping(self._state.routes, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_route": -0.03}
        return {"route_match": 0.1 if action.value == ticket.gold_route else -0.06}

    def _handle_resolution(self, action: Action) -> Dict[str, float]:
        """Record the chosen resolution and score it against the gold resolution."""
        ticket = self._find_ticket(action.target)
        if ticket is None or not action.value:
            return {"invalid_resolution": -0.1}
        current = self._state.resolutions.get(ticket.ticket_id)
        update_mapping(self._state.resolutions, ticket.ticket_id, action.value)
        if action.value == current:
            return {"redundant_resolution": -0.03}
        return {"resolution_match": 0.12 if action.value == ticket.gold_resolution else -0.08}

    def _handle_escalation(self, action: Action) -> Dict[str, float]:
        """Record the escalation team and score it against the gold team.

        ``action.value`` may be ``None`` to withdraw an escalation. Repeating
        the value already recorded (including ``None`` on a ticket that was
        never escalated) counts as redundant.
        """
        ticket = self._find_ticket(action.target)
        if ticket is None:
            return {"invalid_escalation": -0.1}
        team = action.value
        current = self._state.escalations.get(ticket.ticket_id)
        update_mapping(self._state.escalations, ticket.ticket_id, team)
        if team == current:
            return {"redundant_escalation": -0.03}
        # Must be tested before the generic equality check below, which would
        # otherwise also match ``None == None`` and hide this outcome.
        if team is None and ticket.gold_escalation_team is None:
            return {"correct_no_escalation": 0.03}
        if team == ticket.gold_escalation_team:
            return {"correct_escalation": 0.1}
        return {"incorrect_escalation": -0.1}

    def _handle_rank_queue(self, action: Action) -> Dict[str, float]:
        """Record a queue ranking and score its positional agreement.

        ``action.value`` is a comma-separated list of ticket ids. It must name
        every ticket of the task exactly once; otherwise the ranking is
        rejected and the stored queue order is left unchanged.
        """
        if not self._task.queue_mode or not action.value:
            return {"invalid_queue_ranking": -0.1}
        ranked = [item.strip() for item in action.value.split(",") if item.strip()]
        valid_ticket_ids = {ticket.ticket_id for ticket in self._task.tickets}
        # Equal length plus equal id sets means ``ranked`` is a permutation:
        # this rejects duplicates and guards the division below.
        is_permutation = (
            bool(ranked)
            and len(ranked) == len(valid_ticket_ids)
            and set(ranked) == valid_ticket_ids
        )
        if not is_permutation:
            return {"malformed_queue_ranking": -0.08}
        self._state.queue_order = ranked
        correct_positions = sum(
            1 for observed, expected in zip(ranked, self._task.gold_queue_order) if observed == expected
        )
        return {"queue_progress": round((correct_positions / len(ranked)) * 0.12, 4)}