Sayed223 committed on
Commit
f1dacad
·
verified ·
1 Parent(s): 2fb2c0e

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +392 -90
server.py CHANGED
@@ -1,129 +1,431 @@
1
  """
2
- CustomerSupportEnv — FastAPI server.
 
 
 
3
 
4
  Endpoints:
5
- POST /reset → Observation
6
- POST /step → StepResult
7
- GET /state → Observation
8
- GET /tasks → list of task specs
9
- POST /grade → GraderResult
10
- GET /health → 200 OK
11
- GET /openenv.yaml → spec file
12
  """
13
  from __future__ import annotations
14
 
 
15
  import os
16
  import sys
17
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
18
-
19
- from typing import Optional
20
- from fastapi import FastAPI, HTTPException
21
- from fastapi.responses import FileResponse, JSONResponse
22
- from pydantic import BaseModel
23
-
24
- from env.environment import CustomerSupportEnv, TASKS
25
- from env.models import Action, Observation, StepResult, GraderResult
26
- from graders.graders import grade
 
 
 
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(
    title="CustomerSupportEnv",
    description="OpenEnv-compatible RL environment for customer support agent training.",
    version="1.0.0",
)
33
 
34
- # One env instance per task (keyed by task_id)
35
# Cache of environments, at most one per task_id.
_envs: dict[str, CustomerSupportEnv] = {}


def _get_env(task_id: str) -> CustomerSupportEnv:
    """Return the cached environment for ``task_id``, creating it on first use.

    Raises:
        HTTPException: 404 when ``task_id`` is not a key of ``TASKS``.
    """
    if task_id not in TASKS:
        raise HTTPException(status_code=404, detail=f"Unknown task_id: {task_id}")

    cached = _envs.get(task_id)
    if cached is None:
        cached = CustomerSupportEnv(task_id=task_id)
        _envs[task_id] = cached
    return cached
44
 
45
 
 
46
# Request body for POST /reset.
class ResetRequest(BaseModel):
    # Task whose environment should be reset; defaults to "task_1".
    task_id: str = "task_1"
 
48
 
49
 
50
# Request body for POST /step.
class StepRequest(BaseModel):
    # Task whose environment receives the action; defaults to "task_1".
    task_id: str = "task_1"
    # Passed through as Action.action_type.
    action_type: str
    # Passed through as Action.payload; may be omitted.
    payload: Optional[str] = None
54
 
55
 
56
# Request body for POST /grade.
class GradeRequest(BaseModel):
    # Task whose current environment state is graded (required).
    task_id: str
58
-
59
-
60
@app.get("/health")
def health():
    """Report an "ok" status together with ``CustomerSupportEnv.VERSION``."""
    payload = {"status": "ok", "version": CustomerSupportEnv.VERSION}
    return payload
63
-
64
-
65
@app.post("/reset", response_model=Observation)
def reset(req: ResetRequest):
    """Reset the environment for ``req.task_id`` and return what ``reset()`` yields.

    Unknown task ids are rejected with a 404 by ``_get_env``.
    """
    return _get_env(req.task_id).reset()
70
-
71
-
72
@app.post("/step", response_model=StepResult)
def step(req: StepRequest):
    """Build an Action from the request and apply it to the task's environment.

    Raises:
        HTTPException: 400 when action construction or ``env.step`` raises
            RuntimeError; 422 for any other exception raised by either.
    """
    environment = _get_env(req.task_id)
    try:
        requested = Action(action_type=req.action_type, payload=req.payload)
        return environment.step(requested)
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=422, detail=str(err))
83
-
84
-
85
@app.get("/state", response_model=Observation)
def state(task_id: str = "task_1"):
    """Return ``state()`` of the environment cached for ``task_id``."""
    return _get_env(task_id).state()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
@app.get("/tasks")
def list_tasks():
    """Return every entry of ``TASKS`` as ``{task_id: spec.dict()}``."""
    catalogue = {}
    for task_id, spec in TASKS.items():
        catalogue[task_id] = spec.dict()
    return catalogue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
@app.post("/grade", response_model=GraderResult)
def grade_endpoint(req: GradeRequest):
    """Grade the current state of the environment for ``req.task_id``."""
    current = _get_env(req.task_id).state()
    return grade(req.task_id, current)
 
 
 
 
 
 
102
 
103
 
104
@app.get("/openenv.yaml")
def get_yaml():
    """Serve ``openenv.yaml`` from this module's directory, or a JSON 404."""
    here = os.path.dirname(__file__)
    candidate = os.path.join(here, "openenv.yaml")
    if not os.path.exists(candidate):
        return JSONResponse({"error": "openenv.yaml not found"}, status_code=404)
    return FileResponse(candidate, media_type="text/yaml")
110
-
111
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run directly.
    import uvicorn
    # The app is given as an import string ("server:app"), not as the object.
    uvicorn.run("server:app", host="0.0.0.0", port=7860, reload=False)
115
-
116
from fastapi.responses import HTMLResponse


@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a small HTML landing page that lists the main API routes."""
    # The heading emoji is restored from a mis-decoded UTF-8 byte sequence.
    return """
    <h1>Customer Support Env 🚀</h1>
    <p>API is running successfully.</p>
    <ul>
        <li>/reset</li>
        <li>/step</li>
        <li>/state</li>
        <li>/grade</li>
    </ul>
    """
 
1
  """
2
+ server.py — FastAPI/OpenEnv server wrapper for CustomerSupportEnv.
3
+
4
+ Exposes the environment as REST endpoints compatible with OpenEnv specification.
5
+ Handles session management and action validation.
6
 
7
  Endpoints:
8
+ POST /reset → Initialize new episode, return initial observation
9
+ POST /step → Apply action, return (obs, reward, done)
10
+ GET /state → Get current environment state
11
+ GET /tasks → List all tasks
12
+ POST /grade → Grade current episode
13
+ GET /health → Health check
14
+ GET /openenv.yaml → Spec file
15
  """
16
  from __future__ import annotations
17
 
18
+ import json
19
  import os
20
  import sys
21
+ import traceback
22
+ from typing import Any, Dict, Optional
23
+ from pathlib import Path
24
+
25
+ # FastAPI imports
26
+ try:
27
+ from fastapi import FastAPI, HTTPException, Request
28
+ from fastapi.responses import FileResponse, JSONResponse
29
+ from pydantic import BaseModel
30
+ import uvicorn
31
+ except ImportError as e:
32
+ print(f"[ERROR] Missing FastAPI dependency: {e}", flush=True)
33
+ print("Run: pip install fastapi uvicorn pydantic", flush=True)
34
+ sys.exit(1)
35
 
36
+ # Local env imports
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+ try:
39
+ from env.environment import CustomerSupportEnv, TASKS
40
+ from env.models import Action, ActionType, Observation, Reward
41
+ from graders.graders import grade
42
+ except ImportError as e:
43
+ print(f"[ERROR] Missing local env module: {e}", flush=True)
44
+ traceback.print_exc()
45
+ sys.exit(1)
46
+
47
+ # ── FastAPI App ──────────────────────────────────────────────────────────────
48
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(
    title="CustomerSupportEnv",
    description="OpenEnv-compatible customer support RL environment",
    version="1.0.0"
)
53
 
54
+ # ── Session Storage (in-memory for single deployment) ───────────────────────
55
# Live sessions, keyed by the ID handed out by new_session_id().
_sessions: Dict[str, Dict[str, Any]] = {}
# Number of session IDs issued so far in this process.
_session_counter = 0


def new_session_id() -> str:
    """Return the next session ID: ``session_`` plus a six-digit counter.

    Each call increments the module-level ``_session_counter`` by one, so the
    first ID issued is ``session_000001``.
    """
    global _session_counter
    issued = _session_counter + 1
    _session_counter = issued
    return "session_{:06d}".format(issued)
 
64
 
65
 
66
+ # ── Pydantic Models ──────────────────────────────────────────────────────────
67
# Request body for POST /reset.
class ResetRequest(BaseModel):
    # Task to start an episode for; defaults to "task_1".
    task_id: str = "task_1"
    # Optional seed forwarded to the environment constructor by /reset.
    seed: Optional[int] = None
70
 
71
 
72
# Request body for POST /step.
class StepRequest(BaseModel):
    # Session ID previously returned by /reset (required).
    session_id: str
    # Passed through as Action.action_type.
    action_type: str
    # Passed through as Action.payload; may be omitted.
    payload: Optional[str] = None
76
 
77
 
78
# Request body for POST /grade.
class GradeRequest(BaseModel):
    # Session ID previously returned by /reset (required).
    session_id: str
80
+
81
+
82
+ # ── Helper: Serialize observation/reward to JSON ─────────────────────────────
83
def serialize_obs(obs: Observation) -> Dict[str, Any]:
    """Copy the reported fields of ``obs`` into a plain dict.

    The result has one key per attribute listed below, in this order, each
    holding the attribute's value unchanged. A missing attribute raises
    AttributeError.
    """
    field_names = (
        "ticket_id",
        "task_id",
        "status",
        "sentiment",
        "priority",
        "category",
        "turn",
        "max_turns",
        "history",
        "kb_results",
        "kb_searched",
        "empathized",
        "clarified",
        "solution_offered",
        "escalated",
        "cumulative_reward",
        "done",
    )
    return {name: getattr(obs, name) for name in field_names}
104
+
105
+
106
def serialize_reward(reward: Reward) -> Dict[str, Any]:
    """Copy ``total``, ``breakdown`` and ``reason`` of ``reward`` into a dict."""
    payload: Dict[str, Any] = {}
    for key in ("total", "breakdown", "reason"):
        payload[key] = getattr(reward, key)
    return payload
113
+
114
+
115
+ # ── OpenEnv Endpoints ────────────────────────────────────────────────────────
116
+
117
@app.post("/reset")
async def reset(request: ResetRequest) -> JSONResponse:
    """
    Create a fresh environment, reset it, and register it as a new session.

    Args:
        request: Body carrying ``task_id`` (must be a key of ``TASKS``) and an
            optional ``seed``; 42 is used when no seed is supplied.

    Returns:
        {
            "session_id": str,
            "observation": {...},
            "info": {"task_id", "difficulty", "description"}
        }

    Raises:
        HTTPException: 400 for an unknown ``task_id``; 500 when creating or
            resetting the environment fails.
    """
    task_id = request.task_id
    # Compare against None: `seed or 42` would replace an explicit seed of 0.
    seed = 42 if request.seed is None else request.seed

    # Validated outside the try block so the 400 is not rewritten as a 500.
    if task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid task_id. Must be one of: {list(TASKS.keys())}"
        )

    try:
        # Create new environment
        env = CustomerSupportEnv(task_id=task_id, seed=seed)
        obs = env.reset()

        # Store session
        session_id = new_session_id()
        _sessions[session_id] = {
            "env": env,
            "task_id": task_id,
            "observation": obs,
            "steps": 0,
            "done": False,
        }

        task = TASKS[task_id]
        return JSONResponse(
            status_code=200,
            content={
                "session_id": session_id,
                "observation": serialize_obs(obs),
                "info": {
                    "task_id": task_id,
                    "difficulty": task.difficulty,
                    "description": task.description,
                }
            }
        )

    except HTTPException:
        # Same convention as the other endpoints: HTTP errors pass through.
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}") from e
173
+
174
+
175
@app.post("/step")
async def step(request: StepRequest) -> JSONResponse:
    """
    Apply one action to the session's environment and record the outcome.

    Args:
        request: Body carrying ``session_id`` (from /reset), ``action_type``
            and an optional ``payload``; the latter two are passed to
            ``Action``.

    Returns:
        {
            "observation": {...},
            "reward": {...},
            "done": bool,
            "info": {"step": int, "action": str}
        }

    Raises:
        HTTPException: 404 for an unknown session; 400 when the episode is
            already done; 422 when ``Action`` rejects the input with
            ValueError or TypeError; 500 for any other failure.
    """
    try:
        session_id = request.session_id
        action_type = request.action_type
        payload = request.payload

        if session_id not in _sessions:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        session = _sessions[session_id]
        env = session["env"]

        if session["done"]:
            raise HTTPException(status_code=400, detail="Episode already done. Call /reset to start new episode.")

        # A rejected action is a client error, so report it as 422 rather than
        # letting it fall through to the generic 500 handler below.
        try:
            action = Action(action_type=action_type, payload=payload)
        except (ValueError, TypeError) as e:
            raise HTTPException(status_code=422, detail=f"Invalid action: {e}") from e

        # Step environment
        result = env.step(action)

        # Update session
        session["observation"] = result.observation
        session["steps"] += 1
        session["done"] = result.observation.done

        return JSONResponse(
            status_code=200,
            content={
                "observation": serialize_obs(result.observation),
                "reward": serialize_reward(result.reward),
                "done": result.observation.done,
                "info": {
                    "step": session["steps"],
                    "action": action_type,
                }
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Step failed: {str(e)}") from e
236
+
237
+
238
@app.get("/state")
async def state_endpoint(session_id: str) -> JSONResponse:
    """
    Return the observation last stored for a session, without stepping.

    Args:
        session_id: Session ID from /reset

    Returns:
        {
            "observation": {...},
            "info": {"task_id", "steps", "done"}
        }

    Raises:
        HTTPException: 404 for an unknown session; 500 on any other failure.
    """
    try:
        record = _sessions.get(session_id)
        if record is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        body = {
            "observation": serialize_obs(record["observation"]),
            "info": {
                "task_id": record["task_id"],
                "steps": record["steps"],
                "done": record["done"],
            },
        }
        return JSONResponse(status_code=200, content=body)

    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"State query failed: {str(e)}")
276
 
277
 
278
@app.get("/tasks")
async def tasks_endpoint() -> JSONResponse:
    """
    List every entry of ``TASKS``.

    Returns:
        {
            "tasks": [
                {"id", "name", "difficulty", "description", "max_turns"},
                ...
            ]
        }

    Raises:
        HTTPException: 500 when building the list fails.
    """
    try:
        summaries = [
            {
                "id": task_id,
                "name": spec.name,
                "difficulty": spec.difficulty,
                "description": spec.description,
                "max_turns": spec.max_turns,
            }
            for task_id, spec in TASKS.items()
        ]
        return JSONResponse(status_code=200, content={"tasks": summaries})

    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Tasks query failed: {str(e)}")
316
+
317
+
318
@app.post("/grade")
async def grade_endpoint(request: GradeRequest) -> JSONResponse:
    """
    Grade the current state of a session's environment.

    Args:
        request: Body carrying the ``session_id`` from /reset.

    Returns:
        {
            "score": ...,
            "passed": ...,
            "breakdown": {...},
            "reason": ...
        }

    Raises:
        HTTPException: 404 for an unknown session; 500 on any other failure.
    """
    try:
        session_id = request.session_id
        record = _sessions.get(session_id)
        if record is None:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        # Grade the environment's own current state, not the cached observation.
        verdict = grade(record["task_id"], record["env"].state())

        body = {
            "score": verdict.score,
            "passed": verdict.passed,
            "breakdown": verdict.breakdown,
            "reason": verdict.reason,
        }
        return JSONResponse(status_code=200, content=body)

    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Grading failed: {str(e)}")
365
 
366
 
367
@app.get("/health")
async def health() -> JSONResponse:
    """Report service status, version, and the number of stored sessions."""
    report = {
        "status": "healthy",
        "service": "CustomerSupportEnv",
        "version": "1.0.0",
        "sessions_active": len(_sessions),
    }
    return JSONResponse(status_code=200, content=report)
379
 
380
 
381
@app.get("/openenv.yaml")
async def openenv_spec() -> FileResponse:
    """Serve ``openenv.yaml`` from this module's directory.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    spec_path = Path(__file__).with_name("openenv.yaml")
    if spec_path.exists():
        return FileResponse(spec_path, media_type="text/yaml")
    raise HTTPException(status_code=404, detail="openenv.yaml not found")
388
+
389
+
390
+ # ── Root endpoint ────────────────────────────────────────────────────────────
391
@app.get("/")
async def root() -> JSONResponse:
    """Describe the service: name, version, and a summary of each endpoint."""
    endpoint_summaries = {
        "POST /reset": "Initialize new episode",
        "POST /step": "Apply action",
        "GET /state": "Get current state",
        "GET /tasks": "List tasks",
        "POST /grade": "Grade episode",
        "GET /health": "Health check",
        "GET /openenv.yaml": "Specification",
    }
    body = {
        "service": "CustomerSupportEnv OpenEnv Server",
        "version": "1.0.0",
        "endpoints": endpoint_summaries,
    }
    return JSONResponse(status_code=200, content=body)
410
+
411
+
412
+ # ── Startup/Shutdown ─────────────────────────────────────────────────────────
413
@app.on_event("startup")
async def startup_event():
    """Print a one-line notice when the application starts."""
    notice = "[INFO] CustomerSupportEnv server started"
    print(notice, flush=True)
417
+
418
+
419
@app.on_event("shutdown")
async def shutdown_event():
    """Print a one-line notice when the application shuts down."""
    notice = "[INFO] CustomerSupportEnv server shutdown"
    print(notice, flush=True)
423
+
424
+
425
+ # ── Main ─────────────────────────────────────────────────────────────────────
426
if __name__ == "__main__":
    # Bind address and port come from HOST / PORT, with 0.0.0.0:7860 as fallback.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 7860))

    banner = f"[INFO] Starting server on {host}:{port}"
    print(banner, flush=True)
    uvicorn.run(app, host=host, port=port, log_level="info")