| import json |
| import os |
| import sys |
| import textwrap |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| from openai import OpenAI |
|
|
|
|
# ROOT is the directory holding this script; PACKAGE_PARENT is one level up.
ROOT = Path(__file__).resolve().parent
PACKAGE_PARENT = ROOT.parent

# Prepend the parent directory to the import path unless it is already listed.
# NOTE(review): presumably this is what makes the `helpdesk_env` imports that
# follow resolvable when the script is run directly -- confirm the layout.
if not (str(PACKAGE_PARENT) in sys.path):
    sys.path.insert(0, str(PACKAGE_PARENT))
|
|
| from helpdesk_env.environment import HelpdeskEnv |
| from helpdesk_env.models import Action |
|
|
|
|
# Runtime configuration, read from environment variables once at import time.
# Each setting falls back to the default shown when its variable is unset.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "helpdesk-openenv")
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")

# Credential lookup order: GROQ_API_KEY, then HF_TOKEN, then API_KEY.
# Empty values are skipped in favour of the next variable.
API_KEY = (
    os.getenv("GROQ_API_KEY")
    or os.getenv("HF_TOKEN")
    or os.getenv("API_KEY")
)

TASK_NAME = os.getenv("TASK_NAME", "easy")
BENCHMARK = os.getenv("BENCHMARK", "helpdesk_env")

# Numeric settings are converted eagerly, so a malformed value raises
# ValueError as soon as the module is imported.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "180"))
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.50"))
|
|
# Upper bound on the number of agent steps per episode, keyed by task id.
MAX_STEPS_BY_TASK = {"easy": 1, "medium": 3, "hard": 8}
|
|
# Instructions shared by every task: persona, safety rule, and the JSON
# action schema the model has to follow.
SYSTEM_PROMPT_BASE = "".join(
    [
        "You are a banking customer support agent for a UPI payments app. ",
        "Never ask for PIN, OTP, CVV, or full card details. ",
        "You must return exactly one JSON object with keys from: ",
        "action_type, category, faq_id, message. ",
        "Valid action_type values are exactly: classify, lookup_faq, ask_clarification, ",
        "reply, escalate, resolve_ticket.",
    ]
)


def system_prompt_for_task(task_id: str) -> str:
    """Return the system prompt for a task.

    The result is ``SYSTEM_PROMPT_BASE`` followed by one task-specific
    sentence.

    Args:
        task_id: ``"easy"`` or ``"medium"`` select their own guidance; any
            other value receives the guidance written for hard tasks.

    Returns:
        The complete system prompt text.
    """
    if task_id == "easy":
        guidance = (
            " For easy tasks, classify the issue into exactly one category from "
            "observation.available_categories."
        )
    elif task_id == "medium":
        guidance = (
            " For medium tasks, choose lookup_faq with the best faq_id from "
            "observation.knowledge_base, or use escalate when fraud or overdue review requires manual handling."
        )
    else:
        guidance = (
            " For hard tasks, ask for clarification first, then retrieve the right FAQ, "
            "then reply with safe guidance, and only resolve after the customer confirms the issue is fixed."
        )
    return SYSTEM_PROMPT_BASE + guidance
|
|
|
|
def build_user_prompt(task_id: str, observation_json: str, history: List[str]) -> str:
    """Build the user message that asks the model for its next action.

    The prompt is assembled line by line instead of dedenting an interpolated
    template. Dedenting after interpolation stops working as soon as an
    inserted value spans several lines (a history with two or more entries
    does), which left the template's indentation in the prompt and made its
    shape depend on the history length.

    Args:
        task_id: Identifier of the task being run.
        observation_json: Current observation, already serialized to JSON.
        history: Log lines describing earlier steps; only the last four are
            included.

    Returns:
        The prompt text, without leading or trailing blank lines.
    """
    history_block = "\n".join(history[-4:]) if history else "None"
    return "\n".join(
        (
            f"Task: {task_id}",
            "Observation JSON:",
            observation_json,
            "",
            "Recent action history:",
            history_block,
            "",
            "Return the next action as one JSON object only.",
        )
    )
|
|
|
|
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a run.

    Args:
        task: Task identifier to report.
        env: Environment / benchmark name to report.
        model: Model name to report.
    """
    record = f"[START] task={task} env={env} model={model}"
    # Flush immediately so the record is visible even if stdout is buffered.
    print(record, flush=True)
|
|
|
|
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record describing a single step.

    Args:
        step: Step number to report.
        action: Text representation of the action to report.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode is finished; printed as ``true``/``false``.
        error: Error text for the step. ``None`` or an empty string is
            printed as ``null``.
    """
    # The record is a single line of key=value fields. An error message that
    # spans several lines would break that, so its lines are joined with
    # spaces; single-line messages are printed unchanged.
    flattened = " ".join(error.splitlines()) if error else ""
    error_val = flattened or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
|
|
|
|
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a run.

    Args:
        success: Overall outcome; printed as ``true``/``false``.
        steps: Number of steps to report.
        rewards: Per-step rewards, printed comma-separated with two decimals
            (an empty list prints an empty value).
    """
    formatted = [f"{reward:.2f}" for reward in rewards]
    rewards_str = ",".join(formatted)
    record = f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}"
    print(record, flush=True)
|
|
|
|
def _extract_json_object(text: str) -> str:
    """Strip surrounding whitespace and a Markdown code fence from *text*.

    When the trimmed text starts with three backticks and has more than one
    line, the opening fence line (including any language tag on it) is
    dropped. A final line consisting only of three backticks is dropped as
    well. Text that does not start with a fence is returned trimmed.

    Args:
        text: Raw model output.

    Returns:
        The text without the fence lines, trimmed of surrounding whitespace.
    """
    trimmed = text.strip()
    if not trimmed.startswith("```"):
        return trimmed

    body = trimmed.split("\n")
    # A lone fence line is left for the closing-fence check below.
    if len(body) > 1:
        del body[0]
    if body and body[-1].strip() == "```":
        body.pop()
    return "\n".join(body).strip()
|
|
|
|
# Action types accepted from the model; anything else is rejected.
_VALID_ACTIONS = frozenset(
    (
        "classify",
        "lookup_faq",
        "ask_clarification",
        "reply",
        "escalate",
        "resolve_ticket",
    )
)


def _normalize_action_type(raw: object) -> str:
    """Map a raw ``action_type`` value onto one of the valid action names.

    The value is converted to text, trimmed, lower-cased, and has hyphens
    replaced by underscores before it is checked against ``_VALID_ACTIONS``.

    Args:
        raw: Value taken from the model's JSON payload; may be ``None``.

    Returns:
        The normalized action name, or ``""`` when *raw* is ``None`` or does
        not name a valid action.
    """
    if raw is None:
        return ""
    candidate = str(raw).strip().lower().replace("-", "_")
    if candidate not in _VALID_ACTIONS:
        return ""
    return candidate
|
|
|
|
def _fallback_action(task_id: str, turn_number: int) -> Action:
    """Return a canned action for when no usable model action is available.

    Easy tasks classify as ``payment_failure`` and medium tasks escalate.
    Every other task id follows a fixed script driven by the turn number:
    turn 0 asks for clarification, turn 1 looks up ``faq_001``, turns 2 and 3
    reply, and any other turn resolves the ticket.

    Args:
        task_id: Identifier of the task being run.
        turn_number: Turn number used to pick the scripted action.

    Returns:
        The fallback ``Action``.
    """
    if task_id == "easy":
        fields = {"action_type": "classify", "category": "payment_failure"}
    elif task_id == "medium":
        fields = {
            "action_type": "escalate",
            "message": "Escalating for manual review.",
        }
    elif turn_number == 0:
        fields = {
            "action_type": "ask_clarification",
            "message": "Please share the UTR, amount, and exact issue.",
        }
    elif turn_number == 1:
        fields = {"action_type": "lookup_faq", "faq_id": "faq_001"}
    elif turn_number == 2 or turn_number == 3:
        fields = {
            "action_type": "reply",
            "message": "Please follow the safe steps in the app and confirm the result.",
        }
    else:
        fields = {"action_type": "resolve_ticket"}
    return Action(**fields)
|
|
|
|
def parse_action(response_text: str, task_id: str, turn_number: int) -> Action:
    """Turn the model's raw reply into an ``Action``, falling back when needed.

    The reply is stripped of any code fence and decoded as JSON. If decoding
    fails, the span from the first ``{`` to the last ``}`` is tried instead.
    The fallback action is returned when no JSON object can be recovered,
    when ``action_type`` is missing or not a valid action, or when ``Action``
    rejects the fields.

    Args:
        response_text: Raw text returned by the model.
        task_id: Identifier of the task being run (selects the fallback).
        turn_number: Current turn number (selects the fallback).

    Returns:
        The parsed ``Action``, or the result of ``_fallback_action``.
    """
    text = _extract_json_object(response_text)
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Rescue attempt: the model may have wrapped the object in prose.
        start = text.find("{")
        end = text.rfind("}")
        payload = {}
        if start != -1 and end > start:
            try:
                payload = json.loads(text[start : end + 1])
            except json.JSONDecodeError:
                payload = {}

    # json.loads also succeeds on arrays, strings, numbers, and null. Those
    # have no .get(), so without this check a reply such as `[]` or `"reply"`
    # raised AttributeError instead of reaching the fallback below.
    if not isinstance(payload, dict):
        payload = {}

    action_type = _normalize_action_type(payload.get("action_type"))
    if not action_type:
        return _fallback_action(task_id, turn_number)

    try:
        return Action(
            action_type=action_type,
            category=payload.get("category"),
            faq_id=payload.get("faq_id"),
            message=payload.get("message"),
        )
    except Exception:
        # Deliberately broad: any rejection of the model-supplied fields by
        # Action is treated as an unusable reply rather than a crash.
        return _fallback_action(task_id, turn_number)
|
|
|
|
def get_model_action(
    client: OpenAI,
    task_id: str,
    observation_json: str,
    history: List[str],
    turn_number: int,
) -> Action:
    """Ask the chat model for the next action and parse its reply.

    Sends one chat-completion request with the task's system prompt and a
    user prompt built from the observation and recent history, requesting a
    JSON-object response. Exceptions raised by the request are not caught
    here.

    Args:
        client: Client used to issue the chat-completion request.
        task_id: Identifier of the task being run.
        observation_json: Current observation serialized to JSON.
        history: Log lines describing earlier steps.
        turn_number: Current turn number, forwarded to ``parse_action``.

    Returns:
        The ``Action`` produced by ``parse_action`` from the first choice's
        message content (an empty string is parsed when there is no content).
    """
    user_message = {
        "role": "user",
        "content": build_user_prompt(task_id, observation_json, history),
    }
    system_message = {"role": "system", "content": system_prompt_for_task(task_id)}

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[system_message, user_message],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        response_format={"type": "json_object"},
    )

    content = response.choices[0].message.content
    if not content:
        content = ""
    return parse_action(content, task_id, turn_number)
|
|
|
|
def main() -> None:
    """Run one episode of ``TASK_NAME`` and print START/STEP/END records.

    Each step asks the model for an action, applies it to the environment,
    and logs the outcome. The episode stops when the environment reports
    done, when the step budget from ``MAX_STEPS_BY_TASK`` (3 for unknown
    tasks) is used up, or when a step raises. The run counts as successful
    when the last recorded reward is at least ``SUCCESS_SCORE_THRESHOLD``.
    The ``[END]`` record is printed even if the episode aborts with an
    exception.

    Raises:
        RuntimeError: If none of the API key environment variables is set.
    """
    if not API_KEY:
        raise RuntimeError(
            "Set GROQ_API_KEY, HF_TOKEN, or API_KEY before running inference.py"
        )

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = HelpdeskEnv()

    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        observation = env.reset(TASK_NAME)
        done = False
        step_budget = MAX_STEPS_BY_TASK.get(TASK_NAME, 3)

        step = 0
        while not done and step < step_budget:
            step += 1
            error: Optional[str] = None

            try:
                action = get_model_action(
                    client=client,
                    task_id=TASK_NAME,
                    observation_json=observation.model_dump_json(),
                    history=history,
                    turn_number=observation.turn_number,
                )
                observation, reward, done, _info = env.step(action)
                reward_value = reward.value
            except Exception as exc:
                # A failed step is recorded with the fallback action, zero
                # reward, and the error text; it also ends the episode.
                action = _fallback_action(TASK_NAME, observation.turn_number)
                reward_value = 0.0
                done = True
                error = str(exc)

            # Compact JSON (no spaces, unset fields omitted) for the record.
            action_str = json.dumps(
                action.model_dump(exclude_none=True),
                separators=(",", ":"),
            )
            log_step(
                step=step,
                action=action_str,
                reward=reward_value,
                done=done,
                error=error,
            )

            rewards.append(reward_value)
            steps_taken = step
            history.append(f"step={step} action={action_str} reward={reward_value:.2f}")

        # The score of the run is the reward of the last recorded step.
        final_score = 0.0
        if rewards:
            final_score = rewards[-1]
        success = final_score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, rewards=rewards)
|
|
|
|
# Run the episode only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|