eeshwar143 committed on
Commit
edf1c38
·
1 Parent(s): 9be17b9

Use OpenEnv websocket client sessions for inference

Browse files
Files changed (2) hide show
  1. inference.py +18 -3
  2. support_queue_env/client.py +82 -149
inference.py CHANGED
@@ -20,6 +20,7 @@ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
20
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
21
  HF_TOKEN = os.getenv("HF_TOKEN")
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 
23
 
24
  BENCHMARK = "support_queue_env"
25
  SUCCESS_SCORE_THRESHOLD = 0.80
@@ -46,11 +47,14 @@ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> No
46
 
47
 
48
  def create_openai_client() -> Any:
 
 
 
49
  if OpenAI is not None:
50
- return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "placeholder")
51
 
52
  openai_module.api_base = API_BASE_URL
53
- openai_module.api_key = HF_TOKEN or "placeholder"
54
  return openai_module
55
 
56
 
@@ -61,6 +65,9 @@ def get_model_message(
61
  last_reward: float,
62
  history: List[str],
63
  ) -> str:
 
 
 
64
  prompt = (
65
  "Return a short support-triage recommendation as JSON with fields priority, queue, disposition, summary, response. "
66
  f"Step: {step}. Last reward: {last_reward:.4f}. History: {history[-4:]}. Observation: {observation.model_dump_json()}"
@@ -215,6 +222,14 @@ def heuristic_action(observation: SupportQueueObservation) -> SupportQueueAction
215
  )
216
 
217
 
 
 
 
 
 
 
 
 
218
  async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]:
219
  history: List[str] = []
220
  rewards: List[float] = []
@@ -285,7 +300,7 @@ async def main() -> None:
285
  env: SupportQueueEnv | None = None
286
 
287
  try:
288
- env = await SupportQueueEnv.from_docker_image(LOCAL_IMAGE_NAME)
289
  for task in tasks:
290
  results.append(await run_task(client, env, task))
291
  except Exception as exc:
 
20
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
21
  HF_TOKEN = os.getenv("HF_TOKEN")
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
23
+ ENV_BASE_URL = os.getenv("ENV_BASE_URL")
24
 
25
  BENCHMARK = "support_queue_env"
26
  SUCCESS_SCORE_THRESHOLD = 0.80
 
47
 
48
 
49
def create_openai_client() -> Any:
    """Build an OpenAI-compatible client, or return None when no HF token is set.

    Prefers the modern ``OpenAI`` class when available; otherwise configures
    the legacy module-level ``openai`` API and returns the module itself.
    """
    if not HF_TOKEN:
        # No credentials configured — signal "no client" to the caller.
        return None

    if OpenAI is not None:
        # Modern SDK: a per-instance client bound to the configured endpoint.
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    # Legacy SDK: credentials and base URL live on the module globals.
    openai_module.api_base = API_BASE_URL
    openai_module.api_key = HF_TOKEN
    return openai_module
59
 
60
 
 
65
  last_reward: float,
66
  history: List[str],
67
  ) -> str:
68
+ if client is None:
69
+ return "hello"
70
+
71
  prompt = (
72
  "Return a short support-triage recommendation as JSON with fields priority, queue, disposition, summary, response. "
73
  f"Step: {step}. Last reward: {last_reward:.4f}. History: {history[-4:]}. Observation: {observation.model_dump_json()}"
 
222
  )
223
 
224
 
225
async def build_env() -> SupportQueueEnv:
    """Create a SupportQueueEnv, remote when ENV_BASE_URL is set.

    With ENV_BASE_URL configured, attach to that endpoint and open the
    client session; otherwise spin an environment up from a Docker image.
    """
    if ENV_BASE_URL:
        remote_env = SupportQueueEnv(base_url=ENV_BASE_URL)
        await remote_env.connect()
        return remote_env

    image = LOCAL_IMAGE_NAME or "support-queue-openenv"
    return await SupportQueueEnv.from_docker_image(image)
231
+
232
+
233
  async def run_task(client: Any, env: SupportQueueEnv, task: TaskCard) -> dict[str, Any]:
234
  history: List[str] = []
235
  rewards: List[float] = []
 
300
  env: SupportQueueEnv | None = None
301
 
302
  try:
303
+ env = await build_env()
304
  for task in tasks:
305
  results.append(await run_task(client, env, task))
306
  except Exception as exc:
support_queue_env/client.py CHANGED
@@ -1,157 +1,90 @@
1
- """HTTP client for interacting with the support queue environment."""
2
 
3
  from __future__ import annotations
4
 
5
- import asyncio
6
- import os
7
- import socket
8
- import subprocess
9
- import time
10
- from typing import Any
11
 
12
  import requests
13
 
14
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation, SupportQueueState
15
 
16
- DEFAULT_ENV_BASE_URL = os.getenv("ENV_BASE_URL")
17
- DEFAULT_IMAGE_CANDIDATES = [
18
- "support-queue-openenv:latest",
19
- "support-queue-openenv",
20
- "support_queue_env:latest",
21
- "support_queue_env",
22
- ]
23
-
24
-
25
- class _Result:
26
- def __init__(self, payload: dict[str, Any]) -> None:
27
- self.observation = SupportQueueObservation.model_validate(payload["observation"])
28
- self.reward = float(payload.get("reward") or 0.0)
29
- self.done = bool(payload.get("done"))
30
-
31
-
32
- class SupportQueueEnv:
33
- def __init__(self, base_url: str, container_id: str | None = None) -> None:
34
- self.base_url = base_url.rstrip("/")
35
- self.container_id = container_id
36
-
37
- @classmethod
38
- def from_base_url(cls, base_url: str) -> "SupportQueueEnv":
39
- return cls(base_url=base_url)
40
-
41
- @classmethod
42
- async def from_docker_image(cls, image_name: str | None = None) -> "SupportQueueEnv":
43
- if DEFAULT_ENV_BASE_URL:
44
- return cls(base_url=DEFAULT_ENV_BASE_URL)
45
- return await asyncio.to_thread(cls._spawn_local_container, image_name)
46
-
47
- @classmethod
48
- def _spawn_local_container(cls, image_name: str | None) -> "SupportQueueEnv":
49
- chosen_image = cls._resolve_image_name(image_name)
50
- port = cls._pick_free_port()
51
- container_id = cls._run(["docker", "run", "-d", "-p", f"{port}:8000", chosen_image]).strip()
52
- base_url = f"http://127.0.0.1:{port}"
53
-
54
- try:
55
- cls._wait_until_ready(base_url)
56
- except Exception:
57
- cls._safe_remove_container(container_id)
58
- raise
59
-
60
- return cls(base_url=base_url, container_id=container_id)
61
-
62
- @classmethod
63
- def _resolve_image_name(cls, image_name: str | None) -> str:
64
- candidates: list[str] = []
65
- if image_name:
66
- candidates.append(image_name)
67
- candidates.extend(DEFAULT_IMAGE_CANDIDATES)
68
-
69
- for candidate in candidates:
70
- if cls._image_exists(candidate):
71
- return candidate
72
-
73
- build_tag = image_name or "support-queue-openenv:local"
74
- cls._run(["docker", "build", "-t", build_tag, "."])
75
- return build_tag
76
-
77
- @staticmethod
78
- def _image_exists(image_name: str) -> bool:
79
- try:
80
- SupportQueueEnv._run(["docker", "image", "inspect", image_name])
81
- return True
82
- except RuntimeError:
83
- return False
84
-
85
- @staticmethod
86
- def _pick_free_port() -> int:
87
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
88
- sock.bind(("127.0.0.1", 0))
89
- return int(sock.getsockname()[1])
90
-
91
- @staticmethod
92
- def _wait_until_ready(base_url: str, timeout_seconds: int = 45) -> None:
93
- deadline = time.time() + timeout_seconds
94
- last_error = ""
95
-
96
- while time.time() < deadline:
97
- try:
98
- response = requests.get(f"{base_url}/health", timeout=3)
99
- if response.ok:
100
- return
101
- except Exception as exc:
102
- last_error = str(exc)
103
- time.sleep(1)
104
-
105
- raise RuntimeError(f"Environment did not become ready at {base_url}: {last_error}")
106
-
107
- @staticmethod
108
- def _run(command: list[str]) -> str:
109
- result = subprocess.run(command, check=False, capture_output=True, text=True)
110
- if result.returncode != 0:
111
- raise RuntimeError((result.stderr or result.stdout).strip() or f"Command failed: {' '.join(command)}")
112
- return result.stdout
113
-
114
- @staticmethod
115
- def _safe_remove_container(container_id: str) -> None:
116
- subprocess.run(["docker", "rm", "-f", container_id], check=False, capture_output=True, text=True)
117
-
118
- def list_tasks(self) -> list[TaskCard]:
119
- response = requests.get(f"{self.base_url}/tasks", timeout=30)
120
- response.raise_for_status()
121
- payload = response.json()
122
- return [TaskCard.model_validate(item) for item in payload["tasks"]]
123
-
124
- async def alist_tasks(self) -> list[TaskCard]:
125
- return await asyncio.to_thread(self.list_tasks)
126
-
127
- def reset_sync(self, **kwargs: Any) -> _Result:
128
- response = requests.post(f"{self.base_url}/reset", json=kwargs or {}, timeout=30)
129
- response.raise_for_status()
130
- return _Result(response.json())
131
-
132
- async def reset(self, **kwargs: Any) -> _Result:
133
- return await asyncio.to_thread(self.reset_sync, **kwargs)
134
-
135
- def step_sync(self, action: SupportQueueAction) -> _Result:
136
- response = requests.post(
137
- f"{self.base_url}/step",
138
- json={"action": action.model_dump()},
139
- timeout=30,
140
- )
141
- response.raise_for_status()
142
- return _Result(response.json())
143
-
144
- async def step(self, action: SupportQueueAction) -> _Result:
145
- return await asyncio.to_thread(self.step_sync, action)
146
-
147
- def state_sync(self) -> SupportQueueState:
148
- response = requests.get(f"{self.base_url}/state", timeout=30)
149
- response.raise_for_status()
150
- return SupportQueueState.model_validate(response.json())
151
-
152
- async def state(self) -> SupportQueueState:
153
- return await asyncio.to_thread(self.state_sync)
154
-
155
- async def close(self) -> None:
156
- if self.container_id:
157
- await asyncio.to_thread(self._safe_remove_container, self.container_id)
 
1
+ """OpenEnv client for interacting with the support queue environment."""
2
 
3
  from __future__ import annotations
4
 
5
+ from typing import Any, Dict
 
 
 
 
 
6
 
7
  import requests
8
 
9
  from support_queue_env.models import TaskCard, SupportQueueAction, SupportQueueObservation, SupportQueueState
10
 
11
+ try:
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_client import EnvClient as OpenEnvClient
14
+ except Exception: # pragma: no cover - fallback for environments without openenv-core
15
+ OpenEnvClient = None
16
+ StepResult = None
17
+
18
+
19
if OpenEnvClient is not None:

    class SupportQueueEnv(OpenEnvClient[SupportQueueAction, SupportQueueObservation, SupportQueueState]):
        """Support-queue environment client speaking the OpenEnv protocol.

        Uses openenv-core's ``EnvClient`` for reset/step/state sessions;
        ``from_docker_image``/``connect`` are presumably inherited from
        ``OpenEnvClient`` — TODO confirm against the installed openenv-core.
        """

        def __init__(self, base_url: str, **kwargs: Any) -> None:
            super().__init__(base_url=base_url, **kwargs)
            # Keep a normalized copy for the plain-HTTP /tasks endpoint below.
            self.base_url = base_url.rstrip("/")

        def _step_payload(self, action: SupportQueueAction) -> Dict[str, Any]:
            # The action model serializes directly to the wire payload.
            return action.model_dump()

        def _parse_result(self, payload: Dict[str, Any]) -> StepResult[SupportQueueObservation]:
            # Tolerate a missing observation key; the model validates the rest.
            observation = SupportQueueObservation.model_validate(payload.get("observation", {}))
            return StepResult(
                observation=observation,
                reward=payload.get("reward"),
                done=payload.get("done", False),
            )

        def _parse_state(self, payload: Dict[str, Any]) -> SupportQueueState:
            return SupportQueueState.model_validate(payload)

        def list_tasks(self) -> list[TaskCard]:
            # Task listing is a plain HTTP endpoint outside the OpenEnv
            # session protocol. base_url is already normalized in __init__.
            response = requests.get(f"{self.base_url}/tasks", timeout=30)
            response.raise_for_status()
            payload = response.json()
            return [TaskCard.model_validate(item) for item in payload["tasks"]]

else:

    class _Result:
        """Minimal stand-in for OpenEnv's StepResult (observation/reward/done)."""

        def __init__(self, payload: dict[str, Any]) -> None:
            self.observation = SupportQueueObservation.model_validate(payload["observation"])
            self.reward = float(payload.get("reward") or 0.0)
            self.done = bool(payload.get("done"))

    class SupportQueueEnv:
        """Plain-HTTP fallback client used when openenv-core is unavailable."""

        def __init__(self, base_url: str, **_: Any) -> None:
            self.base_url = base_url.rstrip("/")

        async def connect(self) -> None:
            # BUGFIX: callers attaching via ENV_BASE_URL do
            # ``await env.connect()`` after construction; this fallback
            # previously lacked the method and raised AttributeError.
            # Plain HTTP needs no persistent session, so this is a no-op.
            return None

        @classmethod
        async def from_docker_image(cls, image_name: str | None = None) -> "SupportQueueEnv":
            # The fallback cannot launch containers; it assumes a server is
            # already listening on the default local port.
            _ = image_name
            return cls(base_url="http://127.0.0.1:8000")

        def list_tasks(self) -> list[TaskCard]:
            response = requests.get(f"{self.base_url}/tasks", timeout=30)
            response.raise_for_status()
            payload = response.json()
            return [TaskCard.model_validate(item) for item in payload["tasks"]]

        async def reset(self, **kwargs: Any) -> _Result:
            response = requests.post(f"{self.base_url}/reset", json=kwargs or {}, timeout=30)
            response.raise_for_status()
            return _Result(response.json())

        async def step(self, action: SupportQueueAction) -> _Result:
            response = requests.post(
                f"{self.base_url}/step",
                json={"action": action.model_dump()},
                timeout=30,
            )
            response.raise_for_status()
            return _Result(response.json())

        async def state(self) -> SupportQueueState:
            response = requests.get(f"{self.base_url}/state", timeout=30)
            response.raise_for_status()
            return SupportQueueState.model_validate(response.json())

        async def close(self) -> None:
            # Nothing to release: each call uses a one-shot requests call.
            return None