| |
from __future__ import annotations

import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

from openai import OpenAI

from support_triage_openenv import Action, SupportTriageEnv
|
|
|
|
# System instructions sent with every LLM request in ``llm_action``.
# The model is asked for exactly one bare JSON object whose keys mirror the
# ``Action`` payload. The reply is parsed via ``_extract_json`` followed by
# ``Action.model_validate``.
SYSTEM_PROMPT = """You are an agent solving a customer-support triage environment.
Return exactly one JSON object for the next action with keys:
- action_type: read_ticket | classify_ticket | draft_reply | resolve_ticket
- ticket_id (required for read/classify/resolve)
- priority, category, needs_escalation (for classify)
- message (for draft_reply)
No markdown, no extra text."""
|
|
|
|
@dataclass
class EpisodeResult:
    """Summary of one finished episode; one entry per task in the score report."""

    # Identifier of the task that was played.
    task_id: str
    # Environment step count at the moment the episode ended.
    steps: int
    # Grader score taken from the final step's ``info`` dict.
    grader_score: float
    # Reward value of the last step (not a cumulative sum).
    reward: float
    # Termination reason taken from the final step's ``info`` dict.
    done_reason: str
|
|
|
|
# Scripted action plan per task id, consumed by ``heuristic_action``.
# It drives ``--mode heuristic``, and it is also the per-step fallback when an
# LLM action fails in ``run_episode``.
# Each entry is a raw payload passed to ``Action.model_validate``.
# Once a plan is exhausted, ``heuristic_action`` keeps returning its last entry.
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
    "easy_password_reset": [
        {"action_type": "read_ticket", "ticket_id": "T-1001"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-1001",
            "priority": "medium",
            "category": "account",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We will send a reset link to your email. For security, confirm the request "
                "from your registered email before using the reset link."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-1001"},
    ],
    # Reads both tickets before classifying/resolving only T-2001.
    "medium_billing_dispute": [
        {"action_type": "read_ticket", "ticket_id": "T-2001"},
        {"action_type": "read_ticket", "ticket_id": "T-2002"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-2001",
            "priority": "high",
            "category": "billing",
            "needs_escalation": False,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
                "Refund processing typically takes 3-5 business days."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-2001"},
    ],
    # Reads all three tickets; only T-3001 is classified (with escalation) and resolved.
    "hard_outage_incident": [
        {"action_type": "read_ticket", "ticket_id": "T-3001"},
        {"action_type": "read_ticket", "ticket_id": "T-3002"},
        {"action_type": "read_ticket", "ticket_id": "T-3003"},
        {
            "action_type": "classify_ticket",
            "ticket_id": "T-3001",
            "priority": "urgent",
            "category": "technical",
            "needs_escalation": True,
        },
        {
            "action_type": "draft_reply",
            "message": (
                "We have escalated this incident and are investigating now. "
                "The status page will carry updates while we continue incident response."
            ),
        },
        {"action_type": "resolve_ticket", "ticket_id": "T-3001"},
    ],
}
|
|
|
|
| def _extract_json(text: str) -> str: |
| text = text.strip() |
| start = text.find("{") |
| end = text.rfind("}") |
| if start == -1 or end == -1 or end <= start: |
| raise ValueError("No JSON object found in model response") |
| return text[start : end + 1] |
|
|
|
|
def llm_action(client: OpenAI, model: str, observation: dict[str, Any], state: dict[str, Any]) -> Action:
    """Ask the model for the next action and parse its reply into an ``Action``.

    Args:
        client: OpenAI client used for the Responses API call.
        model: Model name to query.
        observation: Current observation, already converted to a plain dict.
        state: Current environment state dict.

    Returns:
        The validated ``Action`` built from the JSON object in the reply.

    Raises:
        ValueError: if the reply contains no JSON object, or the JSON cannot be
            decoded (``json.JSONDecodeError`` subclasses ``ValueError``).
        Errors raised by ``Action.model_validate`` on an invalid payload, and by
        the OpenAI client on API failures, propagate unchanged.
    """
    user_prompt = json.dumps(
        {
            "observation": observation,
            "state": state,
            "instruction": "Pick the best next single action to maximize final score.",
        },
        ensure_ascii=True,
    )

    response = client.responses.create(
        model=model,
        temperature=0,
        top_p=1,
        input=[
            # Responses API input messages take "input_text" content parts.
            # "text" is the Chat Completions part type and is rejected by this
            # endpoint, which previously made every call fail.
            {"role": "system", "content": [{"type": "input_text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [{"type": "input_text", "text": user_prompt}]},
        ],
    )

    raw = response.output_text or ""
    payload = json.loads(_extract_json(raw))
    return Action.model_validate(payload)
|
|
|
|
def heuristic_action(task_id: str, step_idx: int) -> Action:
    """Return the scripted ``RULE_POLICY`` action for ``task_id`` at ``step_idx``.

    Indices past the end of the plan reuse the plan's final entry. Callers
    that keep stepping therefore get the last scripted action again.

    Raises:
        KeyError: if ``task_id`` has no entry in ``RULE_POLICY``.
    """
    steps = RULE_POLICY[task_id]
    final = len(steps) - 1
    payload = steps[step_idx] if step_idx <= final else steps[final]
    return Action.model_validate(payload)
|
|
|
|
def run_episode(env: SupportTriageEnv, task_id: str, mode: str, model: str, client: OpenAI | None) -> EpisodeResult:
    """Play one episode of ``task_id`` until the environment reports ``done``.

    In ``"heuristic"`` mode every action comes from ``heuristic_action``.
    In any other mode the LLM picks each action. If that call (or parsing its
    reply) fails, the step falls back to the scripted action and a warning is
    written to stderr, so a broken LLM path cannot silently pass itself off
    as heuristic results.

    Args:
        env: Environment to play; it is reset to ``task_id`` first.
        task_id: Task to run.
        mode: ``"heuristic"`` or an LLM mode (``main`` uses ``"openai"``).
        model: Model name forwarded to ``llm_action``.
        client: OpenAI client; required unless ``mode == "heuristic"``.

    Returns:
        An ``EpisodeResult`` holding:
        - the final step count;
        - ``grader_score`` and ``done_reason`` from the last step's ``info``;
        - the last step's reward value.

    Raises:
        ValueError: if an LLM mode is requested without a client.
    """
    # Validate explicitly rather than with ``assert``, which ``python -O`` strips.
    if mode != "heuristic" and client is None:
        raise ValueError(f"mode {mode!r} requires an OpenAI client")

    obs = env.reset(task_id)
    done = False
    info: dict[str, Any] = {}
    reward_value = 0.0

    while not done:
        state = env.state()
        step_idx = state["step_count"]
        if mode == "heuristic":
            action = heuristic_action(task_id, step_idx)
        else:
            try:
                action = llm_action(client, model, obs.model_dump(), state)
            except Exception as exc:  # best-effort: API, parsing or validation failure
                print(
                    f"[{task_id} step {step_idx}] LLM action failed "
                    f"({type(exc).__name__}: {exc}); falling back to heuristic",
                    file=sys.stderr,
                )
                action = heuristic_action(task_id, step_idx)

        obs, reward, done, info = env.step(action)
        # Only the most recent step's reward is kept (reported as the "final" reward).
        reward_value = reward.value

    return EpisodeResult(
        task_id=task_id,
        steps=env.state()["step_count"],
        grader_score=float(info["grader_score"]),
        reward=reward_value,
        done_reason=str(info["done_reason"]),
    )
|
|
|
|
def main() -> None:
    """Run every environment task once, then write and print a JSON score summary.

    Command-line options:
        --mode    "openai" (default; requires OPENAI_API_KEY) or "heuristic".
        --model   Model name used in openai mode and recorded in the summary.
        --output  Destination for the JSON summary; parent dirs are created.

    Raises:
        RuntimeError: if ``--mode openai`` is used without OPENAI_API_KEY set.
    """
    parser = argparse.ArgumentParser(description="Run baseline on support-triage-openenv tasks.")
    parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
    parser.add_argument("--model", default="gpt-4.1-mini")
    parser.add_argument("--output", default="scores/baseline_scores.json")
    args = parser.parse_args()

    client = None
    if args.mode == "openai":
        # Fail fast with a clear message instead of an error deep in the client.
        if not os.getenv("OPENAI_API_KEY"):
            raise RuntimeError("OPENAI_API_KEY is required for --mode openai")
        client = OpenAI()

    env = SupportTriageEnv()
    results = [run_episode(env, t, args.mode, args.model, client) for t in env.task_ids]

    def _mean(values: list[float]) -> float:
        # An environment with no tasks would otherwise raise ZeroDivisionError.
        return round(sum(values) / len(values), 4) if values else 0.0

    summary = {
        "mode": args.mode,
        "model": args.model,
        "avg_grader_score": _mean([r.grader_score for r in results]),
        "avg_final_reward": _mean([r.reward for r in results]),
        "episodes": [asdict(r) for r in results],
    }

    # Serialize once so the file and stdout are guaranteed to match.
    rendered = json.dumps(summary, indent=2)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(rendered + "\n", encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()
|
|