JAYASREESS commited on
Commit
ccc1c93
·
verified ·
1 Parent(s): 6571ce4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from fastapi import Request
5
+ from openenv.core.env_server.http_server import create_app
6
+ from fastapi import HTTPException
7
+
8
+ try:
9
+ from ..models import (
10
+ BaselineRequest,
11
+ BaselineScores,
12
+ GraderRequest,
13
+ GraderResponse,
14
+ GridAction,
15
+ GridObservation,
16
+ PlanningContextRequest,
17
+ PlanningContextResponse,
18
+ SimulationRequest,
19
+ SimulationResponse,
20
+ TaskListResponse,
21
+ )
22
+ from .graders import grade_episode
23
+ from .grid_environment import GridEnvironment
24
+ from .logging_utils import configure_logging
25
+ from .tasks import task_list
26
+ except ImportError:
27
+ from models import (
28
+ BaselineRequest,
29
+ BaselineScores,
30
+ GraderRequest,
31
+ GraderResponse,
32
+ GridAction,
33
+ GridObservation,
34
+ PlanningContextRequest,
35
+ PlanningContextResponse,
36
+ SimulationRequest,
37
+ SimulationResponse,
38
+ TaskListResponse,
39
+ )
40
+ from server.graders import grade_episode
41
+ from server.grid_environment import GridEnvironment
42
+ from server.logging_utils import configure_logging
43
+ from server.tasks import task_list
44
+
45
+ configure_logging()
46
+ logger = logging.getLogger(__name__)
47
+
48
+ app = create_app(
49
+ GridEnvironment,
50
+ GridAction,
51
+ GridObservation,
52
+ env_name="grid2op_env",
53
+ max_concurrent_envs=2,
54
+ )
55
+
56
+
57
+ @app.get("/tasks", response_model=TaskListResponse)
58
+ def get_tasks() -> TaskListResponse:
59
+ logger.info("Serving /tasks")
60
+ return TaskListResponse(
61
+ tasks=task_list(),
62
+ action_schema=GridAction.model_json_schema(),
63
+ )
64
+
65
+
66
+ @app.post("/grader", response_model=GraderResponse)
67
+ def post_grader(payload: GraderRequest) -> GraderResponse:
68
+ logger.info(
69
+ "Serving /grader task_id=%s steps=%s",
70
+ payload.task_id,
71
+ len(payload.episode_log),
72
+ )
73
+ return GraderResponse(
74
+ task_id=payload.task_id,
75
+ score=grade_episode(payload.task_id, payload.episode_log),
76
+ )
77
+
78
+
79
+ @app.post("/baseline", response_model=BaselineScores)
80
+ def run_baseline_route(payload: BaselineRequest, request: Request) -> BaselineScores:
81
+ from ..inference import run_baseline_suite
82
+
83
+ base_url = str(request.base_url).rstrip("/")
84
+ logger.info("Serving /baseline model=%s base_url=%s", payload.model, base_url)
85
+ return run_baseline_suite(base_url=base_url, config=payload)
86
+
87
+
88
+ @app.post("/planning_context", response_model=PlanningContextResponse)
89
+ def post_planning_context(payload: PlanningContextRequest) -> PlanningContextResponse:
90
+ env = GridEnvironment.get_active_instance(payload.episode_id)
91
+ if env is None:
92
+ raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
93
+ logger.info("Serving /planning_context episode_id=%s", payload.episode_id)
94
+ return env.get_planning_context()
95
+
96
+
97
+ @app.post("/simulate", response_model=SimulationResponse)
98
+ def post_simulate(payload: SimulationRequest) -> SimulationResponse:
99
+ env = GridEnvironment.get_active_instance(payload.episode_id)
100
+ if env is None:
101
+ raise HTTPException(status_code=404, detail=f"Unknown episode_id: {payload.episode_id}")
102
+ logger.info(
103
+ "Serving /simulate episode_id=%s candidate_count=%s",
104
+ payload.episode_id,
105
+ len(payload.actions),
106
+ )
107
+ return SimulationResponse(
108
+ episode_id=payload.episode_id,
109
+ results=env.simulate_actions(payload.actions),
110
+ )
111
+
112
+
113
+ def main(host: str = "0.0.0.0", port: int = 7860) -> None:
114
+ import argparse
115
+ import uvicorn
116
+
117
+ parser = argparse.ArgumentParser()
118
+ parser.add_argument("--host", default=host)
119
+ parser.add_argument("--port", type=int, default=port)
120
+ args = parser.parse_args()
121
+ logger.info("Starting Grid2Op FastAPI server host=%s port=%s", args.host, args.port)
122
+ uvicorn.run(app, host=args.host, port=args.port)
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()