# ToolUseEnv / inference.py
import asyncio
import json
import os
import textwrap
from typing import Any, List, Optional
from openai import OpenAI
from tool_use_env.client import ToolUseEnv
from tool_use_env.models import ToolUseAction
from tool_use_env.tasks import TASK_SEQUENCE
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:8000")
BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "support_ops_env")
MAX_STEPS = 6
TEMPERATURE = 0.0
MAX_TOKENS = 220
SYSTEM_PROMPT = textwrap.dedent(
"""
You are operating a customer-support workflow environment.
Your job is to gather the minimum necessary evidence, draft a short customer reply,
and submit the correct final resolution code.
Reply with JSON only using this schema:
{
"action_type": "review_ticket|inspect_artifact|search_policy|draft_reply|submit_resolution",
"artifact_id": "optional string",
"query": "optional string",
"message": "optional string",
"resolution_code": "optional string"
}
Use concise messages. Prefer exact artifact ids and exact resolution codes shown in the observation.
"""
).strip()
def log_start(task: str, env: str, model: str) -> None:
    """Print the '[START]' marker announcing a task run."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one '[STEP]' progress line.

    A falsy error (None or empty string) is rendered as the literal 'null'.
    """
    shown_error = error or "null"
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={shown_error}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the '[END]' summary line.

    Per-step rewards are joined as comma-separated values with two decimals;
    an empty rewards list yields an empty 'rewards=' field.
    """
    joined = ",".join(format(value, ".2f") for value in rewards)
    line = f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={joined}"
    print(line, flush=True)
def _serialize_action(action: ToolUseAction) -> str:
payload = {"action_type": action.action_type}
if action.artifact_id:
payload["artifact_id"] = action.artifact_id
if action.query:
payload["query"] = action.query
if action.message:
payload["message"] = action.message.replace("\n", " ").strip()
if action.resolution_code:
payload["resolution_code"] = action.resolution_code
return json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
def _fallback_action(observation: Any) -> ToolUseAction:
evidence = set(observation.collected_evidence)
task_id = observation.task_id
if "ticket" not in evidence:
return ToolUseAction(action_type="review_ticket")
task_plans = {
"damaged-mug-replacement": [
ToolUseAction(action_type="inspect_artifact", artifact_id="order"),
ToolUseAction(action_type="search_policy", query="damaged_items"),
ToolUseAction(
action_type="draft_reply",
message=(
"We are sending a replacement within 48 hours. "
"There is no need to return the broken mug."
),
),
ToolUseAction(action_type="submit_resolution", resolution_code="send_replacement"),
],
"duplicate-charge-refund": [
ToolUseAction(action_type="inspect_artifact", artifact_id="order"),
ToolUseAction(action_type="inspect_artifact", artifact_id="payment"),
ToolUseAction(action_type="search_policy", query="duplicate_charge"),
ToolUseAction(
action_type="draft_reply",
message=(
"We confirmed the duplicate charge and issued a refund. "
"You should see the refund in 3-5 business days."
),
),
ToolUseAction(
action_type="submit_resolution",
resolution_code="refund_duplicate_charge",
),
],
"account-takeover-fraud": [
ToolUseAction(action_type="inspect_artifact", artifact_id="account"),
ToolUseAction(action_type="inspect_artifact", artifact_id="risk_log"),
ToolUseAction(action_type="search_policy", query="account_takeover"),
ToolUseAction(
action_type="draft_reply",
message=(
"We locked your account immediately and escalated this to our fraud team. "
"You will receive an update within 24 hours."
),
),
ToolUseAction(
action_type="submit_resolution",
resolution_code="lock_account_and_escalate_fraud",
),
],
}
plan = task_plans[task_id]
for candidate in plan:
if candidate.action_type == "inspect_artifact":
if f"artifact:{candidate.artifact_id}" not in evidence:
return candidate
elif candidate.action_type == "search_policy":
if f"policy:{candidate.query}" not in evidence:
return candidate
elif candidate.action_type == "draft_reply" and not observation.last_tool_result.startswith("Draft saved"):
return candidate
elif candidate.action_type == "submit_resolution":
return candidate
return ToolUseAction(action_type="submit_resolution", resolution_code=observation.available_resolution_codes[0])
def _prompt_for_observation(step: int, observation: Any) -> str:
    """Render the current observation as the user-turn prompt for the model.

    One labelled line per observation field, ending with the instruction to
    return a single JSON action; dedent/strip normalize the template's
    indentation.
    """
    return textwrap.dedent(
        f"""
        Step: {step}
        Task ID: {observation.task_id}
        Difficulty: {observation.difficulty}
        Objective: {observation.objective}
        Customer message: {observation.customer_message}
        Workspace summary: {observation.workspace_summary}
        Collected evidence: {observation.collected_evidence}
        Available resolution codes: {observation.available_resolution_codes}
        Last tool result: {observation.last_tool_result}
        Last action error: {observation.last_action_error}
        Remaining steps: {observation.remaining_steps}
        Return the single best next action as JSON.
        """
    ).strip()
def _model_action(client: OpenAI, step: int, observation: Any) -> ToolUseAction:
    """Ask the chat model for the next action, falling back to the scripted policy.

    The scripted fallback is returned when no API key is configured, and on
    any API, JSON-parse, or action-construction failure.
    """
    scripted = _fallback_action(observation)
    if not API_KEY:
        return scripted
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _prompt_for_observation(step, observation)},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            response_format={"type": "json_object"},
        )
        content = (response.choices[0].message.content or "").strip()
        parsed = json.loads(content)
        return ToolUseAction(
            action_type=parsed.get("action_type", scripted.action_type),
            artifact_id=parsed.get("artifact_id"),
            query=parsed.get("query"),
            message=parsed.get("message"),
            resolution_code=parsed.get("resolution_code"),
        )
    except Exception:
        # Best effort: any failure degrades to the deterministic plan.
        return scripted
async def _connect_env() -> ToolUseEnv:
    """Create a connected environment client.

    Prefers launching from a local Docker image when LOCAL_IMAGE_NAME is set;
    otherwise connects to the HTTP endpoint at ENV_BASE_URL.
    """
    if LOCAL_IMAGE_NAME:
        return await ToolUseEnv.from_docker_image(LOCAL_IMAGE_NAME)
    remote = ToolUseEnv(base_url=ENV_BASE_URL)
    await remote.connect()
    return remote
async def run_task(client: OpenAI, env: ToolUseEnv, task_id: str) -> float:
    """Run one task episode to completion and return its final score.

    Emits [START]/[STEP]/[END] log lines; the [END] summary is printed even
    when reset/step/state raise, via the finally block (in which case the
    exception still propagates and score stays at its last assigned value).
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Fixed seed keeps episodes reproducible across runs.
        result = await env.reset(task_id=task_id, seed=7)
        observation = result.observation
        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break
            action = _model_action(client, step, observation)
            action_str = _serialize_action(action)
            result = await env.step(action)
            observation = result.observation
            # Missing reward/done fields default to 0.0 / False.
            reward = float(result.reward or 0.0)
            done = bool(result.done)
            error = observation.last_action_error
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)
            if done:
                break
        state = await env.state()
        score = float(state.final_score)
        # NOTE(review): 0.8 success threshold looks arbitrary here — confirm
        # against the benchmark's definition of success.
        success = score >= 0.8
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
async def main() -> None:
    """Run every task in TASK_SEQUENCE against one shared env connection.

    The OpenAI client is created even without a real API key (placeholder
    "missing"); _model_action falls back to the scripted policy in that case.
    The previous version accumulated per-task scores in a list that was never
    read; that dead accumulation is removed — scores are already reported via
    each task's [END] log line.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "missing")
    env = await _connect_env()
    try:
        for task_id in TASK_SEQUENCE:
            await run_task(client, env, task_id)
    finally:
        # Always release the env connection, even if a task run raises.
        await env.close()
# Script entry point: run the full benchmark sequence.
if __name__ == "__main__":
    asyncio.run(main())