modelbuilderhq committed on
Commit
1305932
·
verified ·
1 Parent(s): 55db2c6

Upload folder using huggingface_hub

Browse files
supportdesk_env/server/app.py CHANGED
@@ -36,6 +36,7 @@ from typing import Any
36
 
37
  import uvicorn
38
  from fastapi import Body, HTTPException
 
39
 
40
  try:
41
  from openenv.core.env_server import http_server as openenv_http_server
@@ -70,6 +71,69 @@ app = create_app(
70
  )
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  @app.get("/tasks")
74
  def list_tasks() -> dict[str, Any]:
75
  """Expose a stable task catalog for UI, debugging, and pre-submit checks."""
@@ -137,6 +201,7 @@ def step_episode(
137
  "observation": observation.model_dump(),
138
  "reward": observation.reward,
139
  "done": observation.done,
 
140
  }
141
 
142
 
 
36
 
37
  import uvicorn
38
  from fastapi import Body, HTTPException
39
+ from fastapi.routing import APIRoute
40
 
41
  try:
42
  from openenv.core.env_server import http_server as openenv_http_server
 
71
  )
72
 
73
 
74
def _replace_route(path: str, methods: set[str]) -> None:
    """Drop matching generated routes so a score-aware handler can be registered.

    A route is removed only when all three conditions hold: it is an
    ``APIRoute``, its path equals ``path``, and its HTTP methods include
    every entry of ``methods``. All other routes are kept in their
    existing order.

    Args:
        path: Exact route path to match (for example ``"/reset"``).
        methods: HTTP methods that the route must all support to be removed.
    """

    def _is_target(route: Any) -> bool:
        # Only APIRoute instances are candidates for removal.
        if not isinstance(route, APIRoute):
            return False
        if route.path != path:
            return False
        # ``route.methods`` may be falsy; treat that as an empty method set.
        return methods.issubset(set(route.methods or set()))

    surviving_routes = [route for route in app.router.routes if not _is_target(route)]
    app.router.routes = surviving_routes
86
+
87
+
88
def _score_response(env: SupportDeskEnvironment, observation: SupportDeskObservation) -> dict[str, Any]:
    """Build the response payload with an explicit top-level ``score``.

    Args:
        env: Environment whose ``state.current_score`` is reported as ``score``.
        observation: Observation to serialise; its ``reward`` and ``done``
            attributes are also copied to the top level.

    Returns:
        A dict with the keys ``observation``, ``reward``, ``done`` and
        ``score``, in that order.
    """

    payload: dict[str, Any] = {
        "observation": observation.model_dump(),
        "reward": observation.reward,
        "done": observation.done,
    }
    # The score is read from the environment state, not from the observation.
    payload["score"] = env.state.current_score
    return payload
97
+
98
+
99
# Remove the generated POST handlers for these paths before the score-aware
# replacements are registered below.
for _generated_path in ("/reset", "/step"):
    _replace_route(_generated_path, {"POST"})
101
+
102
+
103
@app.post("/reset")
async def reset_with_score(
    request: openenv_http_server.ResetRequest = Body(default_factory=openenv_http_server.ResetRequest),
) -> dict[str, Any]:
    """Reset the environment and return the initial score at top level.

    A new ``SupportDeskEnvironment`` is created for the request and is always
    closed before the handler returns, whether or not the reset succeeds.

    Args:
        request: Reset request body; a default ``ResetRequest`` is built when
            the client sends no body.

    Returns:
        The ``_score_response`` payload for the observation returned by
        ``env.reset``.
    """

    env = SupportDeskEnvironment()
    try:
        # Forward only the fields the client explicitly set on the request.
        reset_kwargs = request.model_dump(exclude_unset=True)
        initial_observation = env.reset(**reset_kwargs)
        return _score_response(env, initial_observation)
    finally:
        env.close()
116
+
117
+
118
@app.post("/step")
async def step_with_score(request: openenv_http_server.StepRequest) -> dict[str, Any]:
    """Execute one step and return the current score at top level.

    The action is deserialised and validated before any environment is
    created. A new ``SupportDeskEnvironment`` is then created for the request
    and is always closed before the handler returns.

    Args:
        request: Step request body carrying the raw ``action`` payload.

    Returns:
        The ``_score_response`` payload for the observation returned by
        ``env.step``.

    Raises:
        HTTPException: With status 422 and the validation error details when
            the action payload fails validation.
    """

    try:
        parsed_action = openenv_http_server.deserialize_action(request.action, SupportDeskAction)
    except openenv_http_server.ValidationError as exc:
        raise HTTPException(status_code=422, detail=exc.errors()) from exc

    env = SupportDeskEnvironment()
    try:
        # Forward explicitly set request fields, excluding the action itself,
        # which is passed positionally.
        step_kwargs = request.model_dump(exclude_unset=True, exclude={"action"})
        step_observation = env.step(parsed_action, **step_kwargs)
        return _score_response(env, step_observation)
    finally:
        env.close()
135
+
136
+
137
  @app.get("/tasks")
138
  def list_tasks() -> dict[str, Any]:
139
  """Expose a stable task catalog for UI, debugging, and pre-submit checks."""
 
201
  "observation": observation.model_dump(),
202
  "reward": observation.reward,
203
  "done": observation.done,
204
+ "score": SupportDeskEnvironment.state_for_episode(episode_id).current_score,
205
  }
206
 
207
 
tests/test_supportdesk.py CHANGED
@@ -140,6 +140,9 @@ def test_http_reset_step_state_are_session_consistent():
140
 
141
  reset_response = client.post("/reset", json={"episode_id": "http-episode"})
142
  assert reset_response.status_code == 200
 
 
 
143
 
144
  step_response = client.post(
145
  "/step",
@@ -157,6 +160,9 @@ def test_http_reset_step_state_are_session_consistent():
157
  },
158
  )
159
  assert step_response.status_code == 200
 
 
 
160
 
161
  state_response = client.get("/state")
162
  assert state_response.status_code == 200
@@ -191,6 +197,9 @@ def test_http_explicit_episode_helpers_work():
191
  },
192
  )
193
  assert step_response.status_code == 200
 
 
 
194
 
195
  state_response = client.get(f"/episodes/{episode_id}/state")
196
  assert state_response.status_code == 200
 
140
 
141
  reset_response = client.post("/reset", json={"episode_id": "http-episode"})
142
  assert reset_response.status_code == 200
143
+ reset_payload = reset_response.json()
144
+ assert "score" in reset_payload
145
+ assert 0.0 < reset_payload["score"] < 1.0
146
 
147
  step_response = client.post(
148
  "/step",
 
160
  },
161
  )
162
  assert step_response.status_code == 200
163
+ step_payload = step_response.json()
164
+ assert "score" in step_payload
165
+ assert 0.0 < step_payload["score"] < 1.0
166
 
167
  state_response = client.get("/state")
168
  assert state_response.status_code == 200
 
197
  },
198
  )
199
  assert step_response.status_code == 200
200
+ step_payload = step_response.json()
201
+ assert "score" in step_payload
202
+ assert 0.0 < step_payload["score"] < 1.0
203
 
204
  state_response = client.get(f"/episodes/{episode_id}/state")
205
  assert state_response.status_code == 200