# HyperBrickCaseOps / server/supportdesk_environment.py
# Uploaded by modelbuilderhq using huggingface_hub (commit 4f129c9, verified).
"""SupportDesk environment implementation."""
from __future__ import annotations
import os
import threading
import uuid
from pathlib import Path
from typing import ClassVar
from graders import grade_case
from models import (
ActionHistoryEntry,
CustomerFollowUp,
SupportCaseProgress,
SupportDeskAction,
SupportDeskObservation,
SupportDeskState,
)
from openenv_compat import Environment, EnvironmentMetadata
from tasks import (
ALL_ISSUE_TYPES,
ALL_PRIORITIES,
ALL_QUEUES,
ALL_STATUSES,
SupportTaskSpec,
get_task,
list_task_ids,
)
class SupportDeskEnvironment(
    Environment[SupportDeskAction, SupportDeskObservation, SupportDeskState]
):
    """A realistic customer support triage environment with dense rewards.

    Episode state lives in class-level stores keyed by ``episode_id``, so any
    instance of this class can load, advance, and inspect an episode that was
    started by another instance. The public entry points (``reset``, ``step``,
    ``state`` and ``state_for_episode``) access those stores under
    ``_state_lock``.

    The per-step reward is the change in the deterministic ``grade_case``
    score plus small process bonuses and action penalties.
    """

    # Guards the shared class-level stores below.
    _state_lock: ClassVar[threading.RLock] = threading.RLock()
    # episode_id -> most recently persisted state snapshot.
    _episode_store: ClassVar[dict[str, SupportDeskState]] = {}
    # episode_id -> task_id, used to restore ``self.task`` when loading.
    _episode_task_ids: ClassVar[dict[str, str]] = {}
    # Last episode reset or stepped by any instance; final fallback for loads.
    _latest_episode_id: ClassVar[str | None] = None
    # Round-robin cursor for instances that are not pinned to one task.
    _shared_reset_counter: ClassVar[int] = 0
    # Ordering of workflow stages; moving to a higher rank earns a bonus.
    _STAGE_RANK: ClassVar[dict[str, int]] = {
        "intake": 0,
        "verification": 1,
        "awaiting_customer": 2,
        "follow_up_review": 3,
        "customer_communication": 4,
        "internal_handoff": 5,
        "final_resolution": 6,
        "ready_to_submit": 7,
        "closed": 8,
    }
    # Action fields each operation is expected to populate; any other
    # populated field is penalised by ``_mixed_action_penalty``.
    _ALLOWED_FIELDS: ClassVar[dict[str, frozenset[str]]] = {
        "classify": frozenset({"queue", "priority", "issue_type"}),
        "request_info": frozenset({"requested_fields"}),
        "draft_reply": frozenset({"reply"}),
        "add_internal_note": frozenset({"internal_note"}),
        "submit": frozenset({"status", "resolution_code"}),
        "wait": frozenset(),
    }
    # Follow-up statuses reached once a pending customer reply has arrived.
    _RESOLVED_FOLLOW_UP_STATUSES: ClassVar[frozenset[str]] = frozenset(
        {"partial", "complete", "incorrect"}
    )

    def __init__(self, task_id: str | None = None):
        """Create an environment, optionally pinned to a single task.

        Args:
            task_id: Task to use for every episode. When omitted, the
                ``SUPPORTDESK_TASK_ID`` environment variable is consulted.
                When neither is set (empty strings count as unset),
                ``reset()`` rotates through ``list_task_ids()`` using a
                counter shared by all instances.
        """
        super().__init__()
        # Empty strings are treated as "not set" so that an empty env var
        # does not silently pin the first task and disable rotation.
        self._fixed_task_id: str | None = (
            task_id or os.getenv("SUPPORTDESK_TASK_ID") or None
        )
        self._explicit_task_id = self._fixed_task_id is not None
        self.task: SupportTaskSpec = get_task(self._fixed_task_id or list_task_ids()[0])
        self._max_steps = self.task.max_steps
        self._step_count = 0
        self._reward_total = 0.0
        self._done = False
        self._last_feedback = ""
        self._history: list[ActionHistoryEntry] = []
        self._case = SupportCaseProgress()
        self._episode_id: str | None = None
        self._current_sla_minutes_remaining = self.task.ticket.sla_minutes_remaining
        initial_grade = grade_case(self.task, self._case)
        self._score = initial_grade.total_score
        self._completed_milestones = list(initial_grade.completed_milestones)

    @classmethod
    def _build_initial_state(cls, task: SupportTaskSpec, episode_id: str) -> SupportDeskState:
        """Return the fresh state for a new episode of ``task``.

        The case starts empty, graded once to seed ``current_score`` and
        ``completed_milestones``, in the ``intake`` stage with ``classify``
        as the only required next action.
        """
        initial_case = SupportCaseProgress()
        initial_grade = grade_case(task, initial_case)
        return SupportDeskState(
            episode_id=episode_id,
            task_id=task.task_id,
            difficulty=task.difficulty,
            step_count=0,
            reward=0.0,
            done=False,
            current_score=initial_grade.total_score,
            max_steps=task.max_steps,
            case=initial_case,
            current_sla_minutes_remaining=task.ticket.sla_minutes_remaining,
            workflow_stage="intake",
            required_next_actions=["classify"],
            risk_flags=[],
            action_history=[],
            completed_milestones=list(initial_grade.completed_milestones),
            last_feedback="New case loaded. Review the ticket and policy snippets before acting.",
        )

    @classmethod
    def _extract_episode_id(cls, episode_id: str | None = None, **kwargs) -> str | None:
        """Return the first non-empty episode id supplied by the caller.

        Checks ``episode_id`` first, then non-empty string values of the
        ``episode_id`` / ``request_id`` keyword arguments. Returns ``None``
        when nothing usable was provided.
        """
        if episode_id:
            return episode_id
        for key in ("episode_id", "request_id"):
            value = kwargs.get(key)
            if isinstance(value, str) and value:
                return value
        return None

    def _load_episode(self, episode_id: str | None = None, **kwargs) -> None:
        """Restore this instance's working fields from the shared store.

        Resolution order: an id supplied by the caller (``episode_id`` or
        ``request_id``), then the episode this instance is already working
        on, then the most recently active episode of any instance. Preferring
        the instance's own episode stops one instance from silently advancing
        another instance's episode. Does nothing when no id can be resolved.

        Raises:
            ValueError: If the resolved id is not present in the store.
        """
        resolved_episode_id = (
            self._extract_episode_id(episode_id, **kwargs)
            or self._episode_id
            or self.__class__._latest_episode_id
        )
        if not resolved_episode_id:
            return
        episode_state = self.__class__._episode_store.get(resolved_episode_id)
        if episode_state is None:
            raise ValueError(
                f"Unknown episode_id '{resolved_episode_id}'. Call reset() first or provide a valid episode_id."
            )
        task = get_task(self.__class__._episode_task_ids.get(resolved_episode_id, episode_state.task_id))
        self.task = task
        self._max_steps = episode_state.max_steps
        self._step_count = episode_state.step_count
        self._reward_total = episode_state.reward
        self._done = episode_state.done
        self._last_feedback = episode_state.last_feedback
        # Deep copies so local mutation never leaks into the stored snapshot.
        self._history = [entry.model_copy(deep=True) for entry in episode_state.action_history]
        self._case = episode_state.case.model_copy(deep=True)
        self._episode_id = resolved_episode_id
        self._score = episode_state.current_score
        self._completed_milestones = list(episode_state.completed_milestones)
        self._current_sla_minutes_remaining = episode_state.current_sla_minutes_remaining

    def _snapshot_state(self) -> SupportDeskState:
        """Build a detached ``SupportDeskState`` from this instance's fields.

        Reward and score are rounded to 4 decimals; case and history are deep
        copied so the snapshot is independent of later mutation.
        """
        return SupportDeskState(
            episode_id=self._episode_id,
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            step_count=self._step_count,
            reward=round(self._reward_total, 4),
            done=self._done,
            current_score=round(self._score, 4),
            max_steps=self._max_steps,
            case=self._case.model_copy(deep=True),
            current_sla_minutes_remaining=self._current_sla_minutes_remaining,
            workflow_stage=self._workflow_stage(),
            required_next_actions=self._required_next_actions(),
            risk_flags=self._risk_flags(),
            action_history=[entry.model_copy(deep=True) for entry in self._history],
            completed_milestones=list(self._completed_milestones),
            last_feedback=self._last_feedback,
        )

    def _persist_episode(self) -> None:
        """Write this instance's state back to the shared store.

        Also records the episode's task id and marks it as the latest episode.
        No-op when no episode has been started or loaded.
        """
        if self._episode_id is None:
            return
        self.__class__._episode_store[self._episode_id] = self._snapshot_state()
        self.__class__._episode_task_ids[self._episode_id] = self.task.task_id
        self.__class__._latest_episode_id = self._episode_id

    @property
    def state(self) -> SupportDeskState:
        """Return a snapshot of the current episode.

        Side effect: reloads this instance's fields from the shared store
        (see ``_load_episode`` for how the episode is chosen).
        """
        with self.__class__._state_lock:
            self._load_episode()
            return self._snapshot_state()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs,
    ) -> SupportDeskObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Accepted for interface compatibility; unused.
            episode_id: Id for the new episode. A ``<task_id>-<8 hex>`` id is
                generated when omitted; an existing id is overwritten.

        Returns:
            The initial observation, with ``reward=0.0`` and ``done=False``.
        """
        with self.__class__._state_lock:
            if self._fixed_task_id is not None:
                # Re-resolve the pinned task: ``self.task`` may have been
                # replaced by an earlier ``_load_episode`` of another task.
                self.task = get_task(self._fixed_task_id)
            else:
                task_ids = list_task_ids()
                next_task_id = task_ids[self.__class__._shared_reset_counter % len(task_ids)]
                self.__class__._shared_reset_counter += 1
                self.task = get_task(next_task_id)
            self._max_steps = self.task.max_steps
            self._episode_id = episode_id or f"{self.task.task_id}-{uuid.uuid4().hex[:8]}"
            initial_state = self.__class__._build_initial_state(self.task, self._episode_id)
            self.__class__._episode_store[self._episode_id] = initial_state
            self.__class__._episode_task_ids[self._episode_id] = self.task.task_id
            self.__class__._latest_episode_id = self._episode_id
            self._load_episode(self._episode_id)
            return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: SupportDeskAction,
        timeout_s: float | None = None,
        episode_id: str | None = None,
        **kwargs,
    ) -> SupportDeskObservation:
        """Apply one agent action to the active episode.

        Args:
            action: The operation to perform and its payload fields.
            timeout_s: Accepted for interface compatibility; unused.
            episode_id: Episode to advance; ``request_id`` in ``kwargs`` is
                also accepted (see ``_load_episode`` for fallbacks).

        Returns:
            An observation carrying this step's shaped reward: the grade delta
            plus ``_process_bonus`` and ``_action_penalty``, rounded to 4
            decimals. Once the episode is done, further calls return
            ``reward=-0.05`` without changing or persisting state.

        Raises:
            ValueError: If an unknown episode id is resolved.
        """
        with self.__class__._state_lock:
            self._load_episode(episode_id, **kwargs)
            if self._done:
                return self._build_observation(
                    reward=-0.05,
                    done=True,
                    feedback="Episode already finished. Call reset() before taking more actions.",
                )
            previous_grade = grade_case(self.task, self._case)
            previous_stage = self._workflow_stage()
            # Captured before the action: ``_advance_external_events`` resolves
            # a pending follow-up on "wait", so afterwards we could no longer
            # tell a useful wait from a pointless one.
            follow_up_was_pending = self._case.customer_follow_up.status == "pending"
            self._apply_action(action)
            self._step_count += 1
            self._advance_external_events(action)
            self._degrade_sla()
            current_grade = grade_case(self.task, self._case)
            reward = current_grade.total_score - previous_grade.total_score
            reward += self._process_bonus(
                action,
                previous_stage,
                current_grade.total_score,
                follow_up_was_pending=follow_up_was_pending,
            )
            reward += self._action_penalty(
                action,
                current_grade.total_score,
                previous_grade.total_score,
                follow_up_was_pending=follow_up_was_pending,
            )
            reward = round(reward, 4)
            self._score = current_grade.total_score
            self._completed_milestones = list(current_grade.completed_milestones)
            if action.operation == "submit":
                self._done = True
                self._last_feedback = (
                    "Case submitted. Final deterministic grade is "
                    f"{current_grade.total_score:.2f}."
                )
            elif self._step_count >= self._max_steps:
                self._done = True
                self._last_feedback = (
                    f"Reached max steps ({self._max_steps}). Final deterministic grade is "
                    f"{current_grade.total_score:.2f}."
                )
            else:
                self._last_feedback = self._build_feedback(current_grade, reward)
            self._reward_total = round(self._reward_total + reward, 4)
            self._history.append(
                ActionHistoryEntry(
                    step=self._step_count,
                    operation=action.operation,
                    summary=self._summarize_action(action),
                    reward_delta=reward,
                )
            )
            self._persist_episode()
            return self._build_observation(reward=reward, done=self._done)

    @classmethod
    def state_for_episode(cls, episode_id: str) -> SupportDeskState:
        """Return a deep copy of the stored state for ``episode_id``.

        Raises:
            ValueError: If the episode id is not in the store.
        """
        with cls._state_lock:
            state = cls._episode_store.get(episode_id)
            if state is None:
                raise ValueError(f"Unknown episode_id '{episode_id}'. Call reset() first.")
            return state.model_copy(deep=True)

    def close(self) -> None:
        """No-op close hook for compatibility with local scripts."""

    def get_metadata(self) -> EnvironmentMetadata:
        """Return richer metadata for docs, validators, and HF Space UI.

        Embeds ``README.md`` from the parent of this file's directory when it
        can be read; otherwise ``readme_content`` is ``None``.
        """
        readme_path = Path(__file__).resolve().parents[1] / "README.md"
        try:
            readme_content: str | None = readme_path.read_text(encoding="utf-8")
        except OSError:
            # A missing or unreadable README should not break metadata.
            readme_content = None
        return EnvironmentMetadata(
            name="supportdesk_env",
            description=(
                "A policy-heavy enterprise operations desk with deterministic grading, delayed "
                "customer follow-ups, SLA pressure, escalation tradeoffs, and sharper cross-functional triage."
            ),
            readme_content=readme_content,
            version="0.1.0",
            author="HyperBrick",
        )

    def _apply_action(self, action: SupportDeskAction) -> None:
        """Copy the fields relevant to ``action.operation`` onto the case.

        Only non-``None`` / non-empty fields are applied. ``request_info``
        merges into the existing requested fields (sorted, de-duplicated) and
        opens a pending customer follow-up when the task defines one. ``wait``
        and unknown operations change nothing here.
        """
        if action.operation == "classify":
            if action.queue is not None:
                self._case.queue = action.queue
            if action.priority is not None:
                self._case.priority = action.priority
            if action.issue_type is not None:
                self._case.issue_type = action.issue_type
            return
        if action.operation == "request_info":
            if action.requested_fields:
                merged = set(self._case.requested_fields)
                merged.update(action.requested_fields)
                self._case.requested_fields = sorted(merged)
            if self.task.follow_up_outcome != "none" and self._case.customer_follow_up.status == "none":
                self._case.customer_follow_up = CustomerFollowUp(status="pending")
            return
        if action.operation == "draft_reply":
            if action.reply is not None:
                self._case.reply = action.reply
            return
        if action.operation == "add_internal_note":
            if action.internal_note is not None:
                self._case.internal_note = action.internal_note
            return
        if action.operation == "submit":
            if action.status is not None:
                self._case.status = action.status
            if action.resolution_code is not None:
                self._case.resolution_code = action.resolution_code

    def _advance_external_events(self, action: SupportDeskAction) -> None:
        """Deliver the task's scripted customer follow-up on a ``wait``.

        Only fires when a follow-up is pending; the reply's status, message,
        and provided/wrong fields come from the task spec.
        """
        if self._case.customer_follow_up.status == "pending" and action.operation == "wait":
            self._case.customer_follow_up = CustomerFollowUp(
                status=self.task.follow_up_outcome,
                message=self.task.follow_up_message or None,
                provided_fields=list(self.task.follow_up_provided_fields),
                wrong_fields=list(self.task.follow_up_wrong_fields),
            )

    def _degrade_sla(self) -> None:
        """Subtract the task's per-step SLA cost, clamped at zero.

        Does nothing for tickets without an SLA (``None``).
        """
        if self._current_sla_minutes_remaining is None:
            return
        self._current_sla_minutes_remaining = max(
            0,
            self._current_sla_minutes_remaining - self.task.sla_step_cost,
        )

    def _action_penalty(
        self,
        action: SupportDeskAction,
        current_score: float,
        previous_score: float,
        follow_up_was_pending: bool | None = None,
    ) -> float:
        """Return the (non-positive) penalty for this step, rounded to 4 dp.

        Args:
            action: The action just applied.
            current_score: Grade after the action.
            previous_score: Grade before the action.
            follow_up_was_pending: Whether a customer follow-up was pending
                *before* the action. When ``None``, the current follow-up
                status is used instead (less accurate for ``wait``).
        """
        if follow_up_was_pending is None:
            follow_up_was_pending = self._case.customer_follow_up.status == "pending"
        penalty = 0.0
        if current_score <= previous_score:
            penalty -= 0.03
        penalty -= self._mixed_action_penalty(action)
        penalty -= self._escalation_tradeoff_penalty()
        if action.operation == "draft_reply" and not action.reply:
            penalty -= 0.03
        if action.operation == "request_info" and not action.requested_fields:
            penalty -= 0.03
        if action.operation == "add_internal_note" and not action.internal_note:
            penalty -= 0.03
        # Only the fields classify actually applies count; status/resolution
        # on a classify are ignored by _apply_action (and penalised as mixed).
        if action.operation == "classify" and not any(
            [action.queue, action.priority, action.issue_type]
        ):
            penalty -= 0.03
        # Waiting is only useful when a customer reply was outstanding.
        if action.operation == "wait" and not follow_up_was_pending:
            penalty -= 0.02
        if action.operation == "submit" and self._required_next_actions():
            penalty -= 0.08
        if (
            self.task.under_escalation_deadline_step is not None
            and self._step_count >= self.task.under_escalation_deadline_step
            and (self._case.queue != self.task.gold_queue or self._case.priority != self.task.gold_priority)
        ):
            penalty -= 0.04
        if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining <= 15:
            penalty -= 0.02
        return round(penalty, 4)

    def _build_feedback(self, grade, reward: float) -> str:
        """Return the human-readable progress summary for a non-final step."""
        return (
            f"Reward delta {reward:+.2f}. Current score {grade.total_score:.2f}. "
            f"SLA remaining: {self._current_sla_minutes_remaining if self._current_sla_minutes_remaining is not None else 'n/a'} minutes. "
            f"Stage: {self._workflow_stage()}. "
            f"Customer follow-up: {self._case.customer_follow_up.status}. "
            f"Next actions: {', '.join(self._required_next_actions()) or 'none'}. "
            f"Completed milestones: {', '.join(grade.completed_milestones) or 'none yet'}."
        )

    def _summarize_action(self, action: SupportDeskAction) -> str:
        """Return a compact ``op | key=value | ...`` summary for the history.

        Free-text fields (reply, internal note) are reduced to ``=yes``.
        """
        parts = [action.operation]
        if action.queue:
            parts.append(f"queue={action.queue}")
        if action.priority:
            parts.append(f"priority={action.priority}")
        if action.issue_type:
            parts.append(f"issue_type={action.issue_type}")
        if action.status:
            parts.append(f"status={action.status}")
        if action.resolution_code:
            parts.append(f"resolution={action.resolution_code}")
        if action.requested_fields:
            parts.append(f"requested={','.join(action.requested_fields)}")
        if action.reply:
            parts.append("reply=yes")
        if action.internal_note:
            parts.append("note=yes")
        return " | ".join(parts)

    def _build_observation(
        self,
        reward: float,
        done: bool,
        feedback: str | None = None,
    ) -> SupportDeskObservation:
        """Assemble the agent-facing observation from the current fields.

        ``feedback`` overrides the stored last feedback when given.
        """
        return SupportDeskObservation(
            task_id=self.task.task_id,
            difficulty=self.task.difficulty,
            objective=self.task.objective,
            ticket=self.task.ticket,
            knowledge_base=list(self.task.knowledge_base),
            available_queues=list(ALL_QUEUES),
            available_priorities=list(ALL_PRIORITIES),
            available_statuses=list(ALL_STATUSES),
            available_issue_types=list(ALL_ISSUE_TYPES),
            case=self._case.model_copy(deep=True),
            current_sla_minutes_remaining=self._current_sla_minutes_remaining,
            workflow_stage=self._workflow_stage(),
            required_next_actions=self._required_next_actions(),
            risk_flags=self._risk_flags(),
            action_history=[entry.model_copy(deep=True) for entry in self._history],
            feedback=feedback or self._last_feedback,
            remaining_steps=max(self._max_steps - self._step_count, 0),
            reward=reward,
            done=done,
        )

    def _verification_pending(self) -> bool:
        """Return True while some required field has not been requested yet.

        ``requested_fields`` only ever grows (``request_info`` merges into
        it), so a subset test is used: requesting extra fields must not leave
        the case stuck in verification forever.
        """
        required = self.task.required_requested_fields
        if not required:
            return False
        return not set(required).issubset(self._case.requested_fields)

    def _workflow_stage(self) -> str:
        """Return the name of the first incomplete workflow stage."""
        if self._done:
            return "closed"
        if self._case.queue is None or self._case.priority is None or self._case.issue_type is None:
            return "intake"
        if self._verification_pending():
            return "verification"
        if self._case.customer_follow_up.status == "pending":
            return "awaiting_customer"
        if self._case.customer_follow_up.status in {"partial", "incorrect"}:
            return "follow_up_review"
        if not self._case.reply:
            return "customer_communication"
        if not self._case.internal_note:
            return "internal_handoff"
        if self._case.status != self.task.gold_status or self._case.resolution_code != self.task.gold_resolution_code:
            return "final_resolution"
        return "ready_to_submit"

    def _required_next_actions(self) -> list[str]:
        """Return the operations still needed before a clean submission.

        An empty list means nothing further is required.
        """
        if self._case.queue is None or self._case.priority is None or self._case.issue_type is None:
            return ["classify"]
        if self._verification_pending():
            return ["request_info"]
        if self._case.customer_follow_up.status == "pending":
            return ["wait"]
        needed: list[str] = []
        if not self._case.reply:
            needed.append("draft_reply")
        if not self._case.internal_note:
            needed.append("add_internal_note")
        if self._case.status != self.task.gold_status or self._case.resolution_code != self.task.gold_resolution_code:
            needed.append("submit")
        return needed

    def _risk_flags(self) -> list[str]:
        """Return the task's static risk flags plus dynamic ones, sorted and unique."""
        flags = list(self.task.risk_flags)
        if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining <= 30:
            flags.append("sla_breach_risk")
        if self.task.ticket.affected_users and self.task.ticket.affected_users >= 1000:
            flags.append("high_customer_impact")
        if self.task.ticket.secondary_concerns:
            flags.append("secondary_issue_present")
        if self._case.customer_follow_up.status == "partial":
            flags.append("customer_reply_incomplete")
        if self._case.customer_follow_up.status == "incorrect":
            flags.append("customer_reply_irrelevant")
        return sorted(set(flags))

    def _process_bonus(
        self,
        action: SupportDeskAction,
        previous_stage: str,
        current_score: float,
        follow_up_was_pending: bool | None = None,
    ) -> float:
        """Return the (non-negative) process bonus for this step, rounded to 4 dp.

        Args:
            action: The action just applied.
            previous_stage: Workflow stage before the action.
            current_score: Grade after the action.
            follow_up_was_pending: Whether a customer follow-up was pending
                *before* the action. When ``None``, the ``wait`` bonus is
                granted whenever a follow-up reply is present.
        """
        bonus = 0.0
        current_stage = self._workflow_stage()
        if self._STAGE_RANK.get(current_stage, 0) > self._STAGE_RANK.get(previous_stage, 0):
            bonus += 0.02
        if action.operation == "classify" and self._step_count == 1:
            if self._case.queue == self.task.gold_queue and self._case.priority == self.task.gold_priority:
                bonus += 0.03
        if action.operation == "request_info" and current_score > 0 and self.task.required_requested_fields:
            bonus += 0.02
        # Reward only the wait that actually brought in the customer's reply,
        # not repeated waits after it has already arrived.
        if (
            action.operation == "wait"
            and (follow_up_was_pending is None or follow_up_was_pending)
            and self._case.customer_follow_up.status in self._RESOLVED_FOLLOW_UP_STATUSES
        ):
            bonus += 0.02
        if action.operation == "submit" and not self._required_next_actions():
            bonus += 0.03
        if self._current_sla_minutes_remaining is not None and self._current_sla_minutes_remaining > 0:
            if self.task.gold_priority == "urgent" and self._step_count <= 2 and self._case.queue == self.task.gold_queue:
                bonus += 0.02
        return round(bonus, 4)

    def _mixed_action_penalty(self, action: SupportDeskAction) -> float:
        """Return 0.02 per populated field foreign to the operation, capped at 0.06.

        ``None``, empty strings, and empty lists count as unpopulated. An
        operation missing from ``_ALLOWED_FIELDS`` treats every populated
        field as foreign.
        """
        allowed = self._ALLOWED_FIELDS.get(action.operation, frozenset())
        populated_fields = {
            "queue": action.queue,
            "priority": action.priority,
            "issue_type": action.issue_type,
            "status": action.status,
            "resolution_code": action.resolution_code,
            "requested_fields": action.requested_fields,
            "reply": action.reply,
            "internal_note": action.internal_note,
        }
        extras = 0
        for field_name, value in populated_fields.items():
            if field_name in allowed:
                continue
            if value is None:
                continue
            if isinstance(value, list) and not value:
                continue
            if isinstance(value, str) and not value:
                continue
            extras += 1
        return min(0.06, extras * 0.02)

    def _escalation_tradeoff_penalty(self) -> float:
        """Return 0.06 while the case sits in a non-gold over-escalation queue, else 0."""
        penalty = 0.0
        if self._case.queue in self.task.over_escalation_queues and self._case.queue != self.task.gold_queue:
            penalty += 0.06
        return round(penalty, 4)