# HelpDesk / inference.py
# Author: Freakdivi — openenv space, revision 2bd71de
import json
import os
import sys
import textwrap
from pathlib import Path
from typing import List, Optional
from openai import OpenAI
# Make the ``helpdesk_env`` package importable when this file is run as a
# script: the directory that contains this file's folder is put at the front
# of sys.path (only once, if it is not already present).
ROOT = Path(__file__).resolve().parent
PACKAGE_PARENT = ROOT.parent
if os.fspath(PACKAGE_PARENT) not in sys.path:
    sys.path.insert(0, os.fspath(PACKAGE_PARENT))
# These imports rely on the sys.path adjustment directly above.
from helpdesk_env.environment import HelpdeskEnv  # noqa: E402
from helpdesk_env.models import Action  # noqa: E402
# --- Runtime configuration (every value can be overridden via env vars) ---

# Container image name; not referenced elsewhere in this script.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "helpdesk-openenv")

# OpenAI-compatible endpoint and model (defaults point at Groq).
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")

# First non-empty credential wins: GROQ_API_KEY, then HF_TOKEN, then API_KEY.
API_KEY = (
    os.getenv("GROQ_API_KEY")
    or os.getenv("HF_TOKEN")
    or os.getenv("API_KEY")
)

# Which task to run and the benchmark label printed in the [START] record.
TASK_NAME = os.getenv("TASK_NAME", "easy")
BENCHMARK = os.getenv("BENCHMARK", "helpdesk_env")

# Sampling parameters for the chat-completion call.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "180"))

# The run counts as successful when the final step's reward reaches this value.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.50"))

# Upper bound on environment steps per task id.
MAX_STEPS_BY_TASK = {"easy": 1, "medium": 3, "hard": 8}

# Shared system prompt; task-specific guidance is appended per task.
SYSTEM_PROMPT_BASE = (
    "You are a banking customer support agent for a UPI payments app. "
    "Never ask for PIN, OTP, CVV, or full card details. "
    "You must return exactly one JSON object with keys from: "
    "action_type, category, faq_id, message. "
    "Valid action_type values are exactly: classify, lookup_faq, ask_clarification, "
    "reply, escalate, resolve_ticket."
)
def system_prompt_for_task(task_id: str) -> str:
    """Return the system prompt for ``task_id``.

    The shared ``SYSTEM_PROMPT_BASE`` is extended with task-specific guidance.
    ``"easy"`` and ``"medium"`` get their own instructions. Any other task id
    receives the hard-task (multi-turn conversation) instructions.
    """
    if task_id == "easy":
        guidance = (
            " For easy tasks, classify the issue into exactly one category from "
            "observation.available_categories."
        )
    elif task_id == "medium":
        guidance = (
            " For medium tasks, choose lookup_faq with the best faq_id from "
            "observation.knowledge_base, or use escalate when fraud or overdue review requires manual handling."
        )
    else:
        guidance = (
            " For hard tasks, ask for clarification first, then retrieve the right FAQ, "
            "then reply with safe guidance, and only resolve after the customer confirms the issue is fixed."
        )
    return SYSTEM_PROMPT_BASE + guidance
def build_user_prompt(task_id: str, observation_json: str, history: List[str]) -> str:
    """Render the per-turn user message sent to the model.

    Args:
        task_id: Task identifier echoed into the prompt.
        observation_json: Serialized observation to show the model.
        history: Previous action log lines. Only the four most recent are
            included, and ``"None"`` is shown when the history is empty.

    Returns:
        The prompt text: one field per line, with no leading indentation.

    The prompt is assembled line by line on purpose. Running
    ``textwrap.dedent`` over an indented f-string breaks as soon as an
    interpolated value spans several lines (the history block joins entries
    with newlines). Those continuation lines start at column 0, so ``dedent``
    finds no common indent and the template's indentation leaks into the
    prompt.
    """
    history_block = "\n".join(history[-4:]) if history else "None"
    lines = [
        f"Task: {task_id}",
        "Observation JSON:",
        observation_json,
        "Recent action history:",
        history_block,
        "Return the next action as one JSON object only.",
    ]
    return "\n".join(lines)
def log_start(task: str, env: str, model: str) -> None:
    """Print the opening ``[START]`` record (task, env, model) to stdout."""
    fields = {"task": task, "env": env, "model": model}
    rendered = " ".join(f"{key}={value}" for key, value in fields.items())
    print("[START] " + rendered, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record to stdout.

    Args:
        step: Step index, printed as-is.
        action: Serialized action for this step, printed as-is.
        reward: Step reward, printed with two decimals.
        done: Whether the episode ended; printed as ``true``/``false``.
        error: Error message for the step, or ``None``. Prints ``null`` when
            missing or empty.

    Line breaks inside ``error`` are replaced by single spaces so every
    record stays on one line. Exception messages can span several lines and
    would otherwise split the record.
    """
    error_val = " ".join(error.splitlines()) if error else ""
    if not error_val:
        error_val = "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record.

    The record holds the success flag (lower-cased), the number of steps
    taken, and all step rewards as a comma-separated list with two decimals.
    """
    formatted_rewards = ",".join(map("{:.2f}".format, rewards))
    summary = " ".join(
        (
            f"success={str(success).lower()}",
            f"steps={steps}",
            f"rewards={formatted_rewards}",
        )
    )
    print(f"[END] {summary}", flush=True)
def _extract_json_object(text: str) -> str:
    """Strip whitespace and an enclosing Markdown code fence from a reply.

    If the trimmed text starts with a fence, the first line (the opening
    fence, e.g. "```json") is dropped when there is more than one line. A
    trailing line that is exactly the closing fence is dropped too. Text
    without a leading fence is only whitespace-trimmed.
    """
    fence = "```"
    cleaned = text.strip()
    if cleaned.startswith(fence):
        parts = cleaned.split("\n")
        if len(parts) > 1:
            # parts[0] necessarily begins with the fence here.
            parts.pop(0)
        if parts and parts[-1].strip() == fence:
            parts.pop()
        cleaned = "\n".join(parts).strip()
    return cleaned
# Every action_type value the model may return; anything else is treated as
# invalid by the normalizer. Kept in sync with the list in the system prompt.
_VALID_ACTIONS = frozenset(
    (
        "ask_clarification",
        "classify",
        "escalate",
        "lookup_faq",
        "reply",
        "resolve_ticket",
    )
)
def _normalize_action_type(raw: object) -> str:
    """Canonicalize a model-supplied action_type.

    The value is stringified, trimmed and lower-cased, and hyphens are turned
    into underscores. Returns the result if it is in ``_VALID_ACTIONS``,
    otherwise (and for ``None``) returns ``""``.
    """
    if raw is None:
        return ""
    candidate = "_".join(str(raw).strip().lower().split("-"))
    if candidate in _VALID_ACTIONS:
        return candidate
    return ""
def _fallback_action(task_id: str, turn_number: int) -> Action:
    """Return a scripted action for when the model output cannot be used.

    - ``"easy"``: classify as ``payment_failure``.
    - ``"medium"``: escalate for manual review.
    - Anything else follows a fixed script keyed on ``turn_number``:
      0 asks for clarification, 1 looks up ``faq_001``, 2-3 reply with safe
      guidance, and any other turn resolves the ticket.
    """
    if task_id == "easy":
        return Action(action_type="classify", category="payment_failure")
    if task_id == "medium":
        return Action(action_type="escalate", message="Escalating for manual review.")

    if turn_number == 0:
        scripted = Action(
            action_type="ask_clarification",
            message="Please share the UTR, amount, and exact issue.",
        )
    elif turn_number == 1:
        scripted = Action(action_type="lookup_faq", faq_id="faq_001")
    elif turn_number in (2, 3):
        scripted = Action(
            action_type="reply",
            message="Please follow the safe steps in the app and confirm the result.",
        )
    else:
        scripted = Action(action_type="resolve_ticket")
    return scripted
def _load_json_payload(text: str) -> dict:
    """Decode ``text`` into a JSON object, or return ``{}`` if none is found.

    The whole text is tried first. If it is not valid JSON, or is valid JSON
    but not an object (a list, string, number or null), the span from the
    first ``{`` to the last ``}`` is tried instead.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        payload = None
    if isinstance(payload, dict):
        return payload
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return {}
    try:
        payload = json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return {}
    return payload if isinstance(payload, dict) else {}


def parse_action(response_text: str, task_id: str, turn_number: int) -> Action:
    """Convert a raw model reply into an ``Action``.

    Code fences are stripped and the reply is decoded as a JSON object. Its
    ``action_type`` is normalized, and ``category``, ``faq_id`` and
    ``message`` are passed through. If no valid action_type is found, or
    ``Action`` cannot be built from the fields, the scripted
    ``_fallback_action(task_id, turn_number)`` is returned. This function
    never raises for malformed model output, including valid JSON that is
    not an object.
    """
    payload = _load_json_payload(_extract_json_object(response_text))
    action_type = _normalize_action_type(payload.get("action_type"))
    if not action_type:
        return _fallback_action(task_id, turn_number)
    try:
        return Action(
            action_type=action_type,
            category=payload.get("category"),
            faq_id=payload.get("faq_id"),
            message=payload.get("message"),
        )
    except Exception:
        # Deliberate best-effort: Action rejected the field values (presumably
        # model validation, e.g. a non-string category); use the script.
        return _fallback_action(task_id, turn_number)
def get_model_action(
    client: OpenAI,
    task_id: str,
    observation_json: str,
    history: List[str],
    turn_number: int,
) -> Action:
    """Ask the chat model for the next action and parse its reply.

    Sends the task's system prompt plus the rendered user prompt, requesting
    a JSON-object response. A ``None`` message content is treated as an empty
    reply. Errors raised by the client call are not caught here.
    """
    conversation = [
        {"role": "system", "content": system_prompt_for_task(task_id)},
        {"role": "user", "content": build_user_prompt(task_id, observation_json, history)},
    ]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=conversation,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        response_format={"type": "json_object"},
    )
    reply = response.choices[0].message.content
    return parse_action(reply or "", task_id, turn_number)
def main() -> None:
    """Run one episode of ``TASK_NAME`` against ``HelpdeskEnv`` and log it.

    Prints one ``[START]`` line, one ``[STEP]`` line per attempted step, and
    an ``[END]`` line that is always printed (via ``finally``). The run is a
    success when the last step's reward reaches ``SUCCESS_SCORE_THRESHOLD``.
    Any exception during a step ends the episode with reward 0 for that step
    and its message in the ``error`` field.

    Raises:
        RuntimeError: if no API key is configured.
    """
    if not API_KEY:
        raise RuntimeError(
            "Set GROQ_API_KEY, HF_TOKEN, or API_KEY before running inference.py"
        )
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = HelpdeskEnv()
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    success = False
    max_steps = MAX_STEPS_BY_TASK.get(TASK_NAME, 3)
    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
    try:
        observation = env.reset(TASK_NAME)
        done = False
        for step in range(1, max_steps + 1):
            if done:
                break
            error: Optional[str] = None
            # Reset each step so the error path can tell whether the model
            # produced an action this turn (not a stale one from a prior step).
            action: Optional[Action] = None
            try:
                action = get_model_action(
                    client=client,
                    task_id=TASK_NAME,
                    observation_json=observation.model_dump_json(),
                    history=history,
                    turn_number=observation.turn_number,
                )
                observation, reward, done, _info = env.step(action)
                reward_value = reward.value
            except Exception as exc:
                if action is None:
                    # The model call failed, so nothing was sent to the
                    # environment; log the scripted fallback as a placeholder.
                    action = _fallback_action(TASK_NAME, observation.turn_number)
                # Otherwise keep the model's action: it is what was actually
                # attempted, so the log must not replace it with the fallback.
                reward_value = 0.0
                done = True
                error = str(exc)
            action_str = json.dumps(action.model_dump(exclude_none=True), separators=(",", ":"))
            log_step(
                step=step,
                action=action_str,
                reward=reward_value,
                done=done,
                error=error,
            )
            rewards.append(reward_value)
            steps_taken = step
            history.append(f"step={step} action={action_str} reward={reward_value:.2f}")
        final_score = rewards[-1] if rewards else 0.0
        success = final_score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, rewards=rewards)
if "__main__" == __name__:
    # Execute the episode only when run as a script, never on import.
    main()