# HyperBrickCaseOps — examples/rl/train_q_agent.py
"""Train a simple tabular Q-learning agent against the local SupportDesk env.
This is an extra playground script for local experimentation. It is not part of
the hackathon submission baseline and intentionally uses a compact, hand-built
discrete action library so that plain Python Q-learning can train quickly.
"""
from __future__ import annotations
import argparse
import random
import sys
from dataclasses import dataclass
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from supportdesk_env import (
SupportDeskAction,
get_task,
grade_case,
list_task_ids,
)
from supportdesk_env.policies import default_note, default_reply
from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
@dataclass
class EvalResult:
    """Compact report for a greedy evaluation episode."""
    # Identifier of the task the episode ran against.
    task_id: str
    # Final graded score from grade_case for the finished case.
    score: float
    # Cumulative environment reward at the end of the episode.
    reward: float
    # Number of environment steps taken before the episode ended.
    steps: int
    # Human-readable labels (via action_label) of each action taken, in order.
    actions: list[str]
def build_action_library(task_id: str) -> list[SupportDeskAction]:
    """Return a small discrete action set for a task.

    The library mixes gold (correct) actions with deliberately wrong or
    partial alternatives so that Q-learning has meaningful choices to rank.
    """
    task = get_task(task_id)

    # Candidate pools used to pick one deliberately wrong label per category.
    queues = ("general_support", "billing_ops", "trust_and_safety", "platform_engineering")
    priorities = ("low", "normal", "high", "urgent")
    issues = ("general_question", "duplicate_charge", "account_compromise", "production_incident")
    wrong_queue = next(q for q in queues if q != task.gold_queue)
    wrong_priority = next(p for p in priorities if p != task.gold_priority)
    wrong_issue = next(i for i in issues if i != task.gold_issue_type)

    # A deliberately incomplete field request: at most one required field,
    # with a generic fallback when the task requires none.
    partial_fields = list(task.required_requested_fields[:1]) or ["billing_email"]

    # The "good" info request mirrors the task's gold data when available.
    if task.required_requested_fields:
        request_kwargs = dict(
            requested_fields=list(task.required_requested_fields),
            status=task.gold_status,
            reply=default_reply(task_id),
        )
    else:
        request_kwargs = dict(
            requested_fields=["billing_email"],
            status="waiting_on_customer",
            reply="Please confirm the billing email on the account so we can continue.",
        )
    good_request = SupportDeskAction(operation="request_info", **request_kwargs)

    partial_request = SupportDeskAction(
        operation="request_info",
        requested_fields=partial_fields,
        status="waiting_on_customer",
        reply="Please share more details so we can investigate.",
    )

    gold_classify = SupportDeskAction(
        operation="classify",
        queue=task.gold_queue,
        priority=task.gold_priority,
        issue_type=task.gold_issue_type,
    )
    bad_classify = SupportDeskAction(
        operation="classify",
        queue=wrong_queue,
        priority=wrong_priority,
        issue_type=wrong_issue,
    )
    gold_submit = SupportDeskAction(
        operation="submit",
        status=task.gold_status,
        resolution_code=task.gold_resolution_code,
    )
    generic_submit = SupportDeskAction(
        operation="submit",
        status="resolved",
        resolution_code="closed_generic",
    )

    return [
        gold_classify,
        bad_classify,
        good_request,
        partial_request,
        SupportDeskAction(operation="draft_reply", reply=default_reply(task_id)),
        SupportDeskAction(operation="draft_reply", reply="Thanks for reaching out. We are checking this now."),
        SupportDeskAction(operation="add_internal_note", internal_note=default_note(task_id)),
        SupportDeskAction(operation="add_internal_note", internal_note="Customer contacted support with a problem."),
        gold_submit,
        generic_submit,
    ]
def state_key(task_id: str, observation) -> tuple:
    """Compress the observation into a tabular Q-learning state.

    Empty or missing text fields collapse to "_" so every state component
    is hashable and comparable across episodes.
    """
    case = observation.case
    classification = (case.queue or "_", case.priority or "_", case.issue_type or "_")
    progress = (case.status, case.resolution_code or "_", tuple(case.requested_fields))
    drafted = (bool(case.reply), bool(case.internal_note))
    return (task_id,) + classification + progress + drafted + (observation.remaining_steps,)
def action_label(action: SupportDeskAction) -> str:
    """Human-readable action label for debug output."""
    parts = [action.operation]
    # Scalar attributes are appended verbatim, in a fixed order, when set.
    for attribute in ("queue", "priority", "issue_type", "status", "resolution_code"):
        value = getattr(action, attribute)
        if value:
            parts.append(value)
    if action.requested_fields:
        parts.append(",".join(action.requested_fields))
    # Reply/note bodies are long; flag their presence instead of printing them.
    if action.reply:
        parts.append("reply")
    if action.internal_note:
        parts.append("note")
    return " | ".join(parts)
def choose_action(q_values: dict[tuple, list[float]], state: tuple, num_actions: int, epsilon: float) -> int:
    """Epsilon-greedy action selection.

    Unseen states are lazily initialised to all-zero Q rows; greedy ties are
    broken uniformly at random.
    """
    row = q_values.setdefault(state, [0.0] * num_actions)
    # Explore with probability epsilon.
    if random.random() < epsilon:
        return random.randrange(num_actions)
    # Exploit: pick uniformly among the argmax indices.
    top = max(row)
    candidates = [index for index, value in enumerate(row) if value == top]
    return random.choice(candidates)
def train_q_agent(
    episodes_per_task: int,
    alpha: float,
    gamma: float,
    epsilon: float,
    epsilon_decay: float,
    min_epsilon: float,
    seed: int,
) -> tuple[dict[tuple, list[float]], dict[str, list[SupportDeskAction]]]:
    """Train a small tabular Q-learning agent over all tasks.

    One episode per task is played per outer round; epsilon decays once per
    round (not per task), floored at ``min_epsilon``.
    """
    random.seed(seed)
    q_table: dict[tuple, list[float]] = {}
    libraries = {tid: build_action_library(tid) for tid in list_task_ids()}
    for _ in range(episodes_per_task):
        for tid in list_task_ids():
            actions = libraries[tid]
            env = SupportDeskEnvironment(task_id=tid)
            obs = env.reset()
            try:
                while not obs.done:
                    current = state_key(tid, obs)
                    chosen = choose_action(q_table, current, len(actions), epsilon)
                    successor_obs = env.step(actions[chosen])
                    successor = state_key(tid, successor_obs)
                    if successor not in q_table:
                        q_table[successor] = [0.0] * len(actions)
                    # Standard TD(0) target: terminal states bootstrap zero.
                    bootstrap = 0.0 if successor_obs.done else max(q_table[successor])
                    target = successor_obs.reward + gamma * bootstrap
                    q_table[current][chosen] += alpha * (target - q_table[current][chosen])
                    obs = successor_obs
            finally:
                # Always release environment resources, even on error.
                env.close()
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
    return q_table, libraries
def evaluate_policy(
    q_values: dict[tuple, list[float]],
    action_libraries: dict[str, list[SupportDeskAction]],
) -> list[EvalResult]:
    """Run a greedy evaluation episode for each task.

    Acts purely greedily (no exploration); states never visited in training
    get an all-zero row, which makes the first action the greedy pick.
    """
    reports: list[EvalResult] = []
    for tid in list_task_ids():
        library = action_libraries[tid]
        env = SupportDeskEnvironment(task_id=tid)
        obs = env.reset()
        trace: list[str] = []
        try:
            while not obs.done:
                state = state_key(tid, obs)
                q_values.setdefault(state, [0.0] * len(library))
                row = q_values[state]
                # max over indices keyed by Q-value: first argmax wins ties.
                best = max(range(len(library)), key=row.__getitem__)
                picked = library[best]
                trace.append(action_label(picked))
                obs = env.step(picked)
            reports.append(
                EvalResult(
                    task_id=tid,
                    score=grade_case(get_task(tid), env.state.case).total_score,
                    reward=env.state.reward,
                    steps=env.state.step_count,
                    actions=trace,
                )
            )
        finally:
            env.close()
    return reports
def main() -> None:
    """CLI entry point: train the tabular agent, then print a greedy report.

    Fix: the average-score computation previously divided by ``len(results)``
    unconditionally, raising ZeroDivisionError if evaluation produced no
    results (e.g. an empty task list); it now falls back to 0.0.
    """
    parser = argparse.ArgumentParser(description="Train a simple tabular Q-learning agent on SupportDesk.")
    parser.add_argument("--episodes-per-task", type=int, default=250)
    # Float hyperparameters share type/handling; declare them in one pass.
    for flag, default in (
        ("--alpha", 0.45),
        ("--gamma", 0.92),
        ("--epsilon", 0.35),
        ("--epsilon-decay", 0.99),
        ("--min-epsilon", 0.03),
    ):
        parser.add_argument(flag, type=float, default=default)
    parser.add_argument("--seed", type=int, default=7)
    args = parser.parse_args()
    # Argparse dest names match train_q_agent's keyword parameters exactly.
    q_values, action_libraries = train_q_agent(**vars(args))
    results = evaluate_policy(q_values, action_libraries)
    # Guard against an empty evaluation set (no tasks registered).
    average_score = sum(result.score for result in results) / len(results) if results else 0.0
    print("Tabular Q-learning evaluation")
    print("============================")
    for result in results:
        print(
            f"{result.task_id}: score={result.score:.2f} reward={result.reward:.2f} "
            f"steps={result.steps}"
        )
        print(f"  actions: {' -> '.join(result.actions)}")
    print(f"average_score={average_score:.3f}")
# Allow running this file directly as a script (no effect when imported).
if __name__ == "__main__":
    main()