hiitsesh commited on
Commit
73b708a
·
1 Parent(s): 0287ccf

feat: enhance reset functionality and grading logic with optional request body

Browse files
Files changed (3) hide show
  1. adv_rebuild.py +14 -6
  2. server/app.py +10 -5
  3. src/models.py +4 -1
adv_rebuild.py CHANGED
@@ -6,7 +6,10 @@ def write_file(path, content):
6
 
7
  models_py = """
8
  from pydantic import BaseModel, Field
9
- from typing import Dict, Literal, List
 
 
 
10
 
11
  class Observation(BaseModel):
12
  time_step: int
@@ -184,8 +187,9 @@ TASKS = {
184
  """
185
 
186
  main_py = """
187
- from fastapi import FastAPI, HTTPException
188
- from src.models import Action, TaskConfig
 
189
  from src.env import DesalEnv
190
  from src.tasks import TASKS
191
  import subprocess
@@ -198,7 +202,11 @@ def health_check():
198
  return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
199
 
200
  @app.post("/reset")
201
- def reset_env(task_id: str = "easy_spring"):
 
 
 
 
202
  if task_id not in TASKS:
203
  raise HTTPException(status_code=404, detail="Task not found")
204
  obs = env.reset(TASKS[task_id])
@@ -225,11 +233,11 @@ def list_tasks():
225
  @app.get("/grader")
226
  def grader():
227
  if env.state is None:
228
- return {"score": 0.0}
229
  # Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
230
  baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
231
  scale_factor = env.config.max_steps * 1500.0
232
- score = max(0.0, min(1.0, (env.total_reward + baseline_offset) / scale_factor))
233
  return {"score": score}
234
 
235
  @app.post("/baseline")
 
6
 
7
  models_py = """
8
  from pydantic import BaseModel, Field
9
+ from typing import Dict, Literal, List, Optional
10
+
11
+ class ResetRequest(BaseModel):
12
+ task_id: str = "easy_spring"
13
 
14
  class Observation(BaseModel):
15
  time_step: int
 
187
  """
188
 
189
  main_py = """
190
+ from fastapi import FastAPI, HTTPException, Body
191
+ from typing import Optional
192
+ from src.models import Action, TaskConfig, ResetRequest
193
  from src.env import DesalEnv
194
  from src.tasks import TASKS
195
  import subprocess
 
202
  return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
203
 
204
  @app.post("/reset")
205
+ def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
206
+ # Support both GET query params and POST JSON body for task_id
207
+ if req and req.task_id != "easy_spring":
208
+ task_id = req.task_id
209
+
210
  if task_id not in TASKS:
211
  raise HTTPException(status_code=404, detail="Task not found")
212
  obs = env.reset(TASKS[task_id])
 
233
  @app.get("/grader")
234
  def grader():
235
  if env.state is None:
236
+ return {"score": 0.001}
237
  # Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
238
  baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
239
  scale_factor = env.config.max_steps * 1500.0
240
+ score = max(0.001, min(0.999, (env.total_reward + baseline_offset) / scale_factor))
241
  return {"score": score}
242
 
243
  @app.post("/baseline")
server/app.py CHANGED
@@ -1,8 +1,9 @@
1
- from fastapi import FastAPI, HTTPException
2
- from src.models import Action, TaskConfig
3
  from src.env import DesalEnv
4
  from src.tasks import TASKS
5
  import subprocess
 
6
 
7
  app = FastAPI(title="Advanced Municipal Desalination Plant Env")
8
  env = DesalEnv()
@@ -12,7 +13,11 @@ def health_check():
12
  return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
13
 
14
  @app.post("/reset")
15
- def reset_env(task_id: str = "easy_spring"):
 
 
 
 
16
  if task_id not in TASKS:
17
  raise HTTPException(status_code=404, detail="Task not found")
18
  obs = env.reset(TASKS[task_id])
@@ -39,11 +44,11 @@ def list_tasks():
39
  @app.get("/grader")
40
  def grader():
41
  if env.state is None:
42
- return {"score": 0.0}
43
  # Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
44
  baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
45
  scale_factor = env.config.max_steps * 1500.0
46
- score = max(0.0, min(1.0, (env.total_reward + baseline_offset) / scale_factor))
47
  return {"score": score}
48
 
49
  @app.post("/baseline")
 
1
+ from fastapi import FastAPI, HTTPException, Body
2
+ from src.models import Action, TaskConfig, ResetRequest
3
  from src.env import DesalEnv
4
  from src.tasks import TASKS
5
  import subprocess
6
+ from typing import Optional
7
 
8
  app = FastAPI(title="Advanced Municipal Desalination Plant Env")
9
  env = DesalEnv()
 
13
  return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
14
 
15
  @app.post("/reset")
16
+ def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
17
+ # Support both GET query params and POST JSON body for task_id
18
+ if req and req.task_id != "easy_spring":
19
+ task_id = req.task_id
20
+
21
  if task_id not in TASKS:
22
  raise HTTPException(status_code=404, detail="Task not found")
23
  obs = env.reset(TASKS[task_id])
 
44
  @app.get("/grader")
45
  def grader():
46
  if env.state is None:
47
+ return {"score": 0.001}
48
  # Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
49
  baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
50
  scale_factor = env.config.max_steps * 1500.0
51
+ score = max(0.001, min(0.999, (env.total_reward + baseline_offset) / scale_factor))
52
  return {"score": score}
53
 
54
  @app.post("/baseline")
src/models.py CHANGED
@@ -1,5 +1,8 @@
1
  from pydantic import BaseModel, Field
2
- from typing import Dict, Literal, List
 
 
 
3
 
4
  class Observation(BaseModel):
5
  time_step: int
 
1
  from pydantic import BaseModel, Field
2
+ from typing import Dict, Literal, List, Optional
3
+
4
+ class ResetRequest(BaseModel):
5
+ task_id: str = "easy_spring"
6
 
7
  class Observation(BaseModel):
8
  time_step: int