Silicon Valley - Admin commited on
Commit
76ca3e9
1 Parent(s): cce3030

Refactor server.py and update requirements.txt for enhanced functionality and logging

Browse files

- Simplified server.py by removing unused imports and restructuring configuration settings.
- Introduced new data models for read and write operations, improving API clarity.
- Enhanced logging to include remote addresses for better traceability of requests and responses.
- Updated requirements.txt to include uvicorn for ASGI support, ensuring compatibility with the new server structure.

Files changed (3) hide show
  1. broker.py +73 -0
  2. requirements.txt +2 -1
  3. server.py +91 -59
broker.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import uuid
3
+ from typing import AsyncGenerator, Dict, Tuple, Any
4
+ from dataclasses import dataclass
5
+
6
+ class SessionDoesNotExist(Exception):
7
+ pass
8
+
9
+ class SessionAlreadyExists(Exception):
10
+ pass
11
+
12
+ class ClientError(Exception):
13
+ def __init__(self, message):
14
+ super().__init__(message)
15
+ self.message = message
16
+
17
+ @dataclass
18
+ class ClientRequest:
19
+ request_id: str
20
+ data: Any
21
+
22
+ @dataclass
23
+ class ClientResponse:
24
+ request_id: str
25
+ error: bool
26
+ data: Any
27
+
28
+ class SessionBroker:
29
+ def __init__(self):
30
+ self.sessions: Dict[str, asyncio.Queue] = {}
31
+ self.pending_responses: Dict[Tuple[str, str], asyncio.Future] = {}
32
+
33
+ async def send_request(self, session_id: str, data: Any, timeout: int = 60) -> Any:
34
+ if session_id not in self.sessions:
35
+ raise SessionDoesNotExist()
36
+
37
+ request_id = str(uuid.uuid4())
38
+ future = asyncio.get_event_loop().create_future()
39
+ self.pending_responses[(session_id, request_id)] = future
40
+
41
+ await self.sessions[session_id].put(ClientRequest(request_id=request_id, data=data))
42
+
43
+ try:
44
+ return await asyncio.wait_for(future, timeout)
45
+ except asyncio.TimeoutError:
46
+ raise
47
+ finally:
48
+ if (session_id, request_id) in self.pending_responses:
49
+ del self.pending_responses[(session_id, request_id)]
50
+
51
+ async def receive_response(self, session_id: str, response: ClientResponse) -> None:
52
+ if (session_id, response.request_id) in self.pending_responses:
53
+ future = self.pending_responses.pop((session_id, response.request_id))
54
+ if not future.done():
55
+ if response.error:
56
+ future.set_exception(ClientError(message=response.data))
57
+ else:
58
+ future.set_result(response.data)
59
+
60
+ async def subscribe(self, session_id: str) -> AsyncGenerator[ClientRequest, None]:
61
+ if session_id in self.sessions:
62
+ raise SessionAlreadyExists()
63
+
64
+ queue = asyncio.Queue()
65
+ self.sessions[session_id] = queue
66
+
67
+ try:
68
+ while True:
69
+ yield await queue.get()
70
+ finally:
71
+ if session_id in self.sessions:
72
+ del self.sessions[session_id]
73
+ self.pending_responses = {k: v for k, v in self.pending_responses.items() if k[0] != session_id}
requirements.txt CHANGED
@@ -7,4 +7,5 @@ websockets==11.0.3
7
  python-json-logger==2.0.7
8
  prometheus-client==0.17.1
9
  pydantic==1.10.13
10
- python-dotenv==1.0.0
 
 
7
  python-json-logger==2.0.7
8
  prometheus-client==0.17.1
9
  pydantic==1.10.13
10
+ python-dotenv==1.0.0
11
+ uvicorn==0.24.0
server.py CHANGED
@@ -1,45 +1,32 @@
1
  # server.py
2
- import asyncio
3
- import uuid
4
- from typing import AsyncGenerator, Dict, Tuple, Any, Optional
5
- from dataclasses import dataclass
6
- from quart import Quart, websocket, request, Response
7
- from quart_schema import QuartSchema, validate_request, validate_response
8
- from quart_cors import cors
9
  import importlib.metadata
10
  import secrets
11
  import logging
12
- import os
 
 
 
 
 
 
13
 
14
  from broker import SessionBroker, SessionDoesNotExist, ClientRequest, ClientResponse, ClientError
15
 
16
- # Configuraci贸n para Hugging Face Spaces
17
- PORT = int(os.getenv('PORT', 7860))
18
- TIMEOUT: int = int(os.getenv('TIMEOUT', 60))
19
- LOG_LEVEL: int = getattr(logging, os.getenv('LOG_LEVEL', 'INFO'))
20
- MAX_MESSAGE_SIZE: int = int(os.getenv('MAX_MESSAGE_SIZE', 16 * 1024 * 1024))
21
- RATE_LIMIT: int = int(os.getenv('RATE_LIMIT', 100))
22
- SESSION_TIMEOUT: int = int(os.getenv('SESSION_TIMEOUT', 3600))
23
-
24
- # Configurar logging
25
- logging.basicConfig(
26
- level=LOG_LEVEL,
27
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
28
- )
29
-
30
- # Crear aplicaci贸n con CORS habilitado
31
  app = Quart(__name__)
32
- app = cors(app,
33
- allow_origin=["https://*.hf.space", "https://*.huggingface.co"],
34
- allow_methods=["GET", "POST", "OPTIONS"],
35
- allow_headers=["Content-Type"],
36
- max_age=3600
37
- )
38
  QuartSchema(app)
 
 
39
 
40
  broker = SessionBroker()
41
 
42
- # Definici贸n de modelos de datos
43
  @dataclass
44
  class Status:
45
  status: str
@@ -54,12 +41,31 @@ class Command:
54
  session_id: str
55
  command: str
56
 
 
 
 
 
 
 
 
 
 
 
 
57
  @dataclass
58
  class CommandResponse:
59
  return_code: int
60
  stdout: str
61
  stderr: str
62
 
 
 
 
 
 
 
 
 
63
  @dataclass
64
  class ErrorResponse:
65
  error: str
@@ -73,21 +79,24 @@ async def status() -> Status:
73
  @app.websocket('/session')
74
  async def session_handler():
75
  session_id = secrets.token_hex()
76
- app.logger.info(f"New session: {session_id}")
77
  await websocket.send_as(Session(session_id=session_id), Session)
78
 
79
- task = asyncio.ensure_future(_receive(session_id))
80
  try:
 
81
  async for request in broker.subscribe(session_id):
82
- app.logger.info(f"Sending request {request.request_id} to client.")
83
  await websocket.send_as(request, ClientRequest)
84
  finally:
85
- task.cancel()
 
 
86
 
87
  async def _receive(session_id: str) -> None:
88
  while True:
89
  response = await websocket.receive_as(ClientResponse)
90
- app.logger.info(f"Received response for session {session_id}: {response}")
91
  await broker.receive_response(session_id, response)
92
 
93
  @app.post('/command')
@@ -96,33 +105,56 @@ async def _receive(session_id: str) -> None:
96
  @validate_response(ErrorResponse, 500)
97
  async def command(data: Command) -> Tuple[CommandResponse | ErrorResponse, int]:
98
  try:
99
- response_data = await broker.send_request(
100
- session_id=data.session_id,
101
- data={'action': 'command', 'command': data.command},
102
- timeout=TIMEOUT
103
- )
104
- response = CommandResponse(**response_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  return response, 200
106
  except SessionDoesNotExist:
107
- app.logger.warning(f"Invalid session ID: {data.session_id}")
108
- return ErrorResponse(error='Session does not exist.'), 500
109
  except ClientError as e:
110
- return ErrorResponse(error=e.message), 500
111
  except asyncio.TimeoutError:
112
- return ErrorResponse(error='Timeout waiting for client response.'), 500
113
 
114
- # Ejecutar aplicaci贸n
115
  def run():
116
- app.run(
117
- host='0.0.0.0',
118
- port=PORT,
119
- debug=False
120
- )
121
-
122
- # Agregar un endpoint de health check
123
- @app.route("/health")
124
- async def health_check():
125
- return {"status": "healthy"}
126
-
127
- if __name__ == "__main__":
128
- run()
 
1
  # server.py
2
+ from dataclasses import dataclass, asdict
 
 
 
 
 
 
3
  import importlib.metadata
4
  import secrets
5
  import logging
6
+ import asyncio
7
+ import json
8
+ from typing import Tuple
9
+
10
+ from quart import Quart, websocket, request
11
+ from quart_schema import QuartSchema, validate_request, validate_response
12
+ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
13
 
14
  from broker import SessionBroker, SessionDoesNotExist, ClientRequest, ClientResponse, ClientError
15
 
16
+ # Configuraci贸n
17
+ TIMEOUT: int = 40
18
+ LOG_LEVEL: int = logging.DEBUG
19
+ TRUSTED_HOSTS: list[str] = ["127.0.0.1"]
20
+
21
+ # Create app
 
 
 
 
 
 
 
 
 
22
  app = Quart(__name__)
 
 
 
 
 
 
23
  QuartSchema(app)
24
+ app.asgi_app = ProxyHeadersMiddleware(app.asgi_app, trusted_hosts=TRUSTED_HOSTS)
25
+ app.logger.setLevel(LOG_LEVEL)
26
 
27
  broker = SessionBroker()
28
 
29
+ # Modelos de datos
30
  @dataclass
31
  class Status:
32
  status: str
 
41
  session_id: str
42
  command: str
43
 
44
+ @dataclass
45
+ class Read:
46
+ session_id: str
47
+ path: str
48
+
49
+ @dataclass
50
+ class Write:
51
+ session_id: str
52
+ path: str
53
+ content: str
54
+
55
  @dataclass
56
  class CommandResponse:
57
  return_code: int
58
  stdout: str
59
  stderr: str
60
 
61
+ @dataclass
62
+ class ReadResponse:
63
+ content: str
64
+
65
+ @dataclass
66
+ class WriteResponse:
67
+ size: int
68
+
69
  @dataclass
70
  class ErrorResponse:
71
  error: str
 
79
  @app.websocket('/session')
80
  async def session_handler():
81
  session_id = secrets.token_hex()
82
+ app.logger.info(f"{websocket.remote_addr} - NEW SESSION - {session_id}")
83
  await websocket.send_as(Session(session_id=session_id), Session)
84
 
85
+ task = None
86
  try:
87
+ task = asyncio.ensure_future(_receive(session_id))
88
  async for request in broker.subscribe(session_id):
89
+ app.logger.info(f"{websocket.remote_addr} - REQUEST - {session_id} - {json.dumps(asdict(request))}")
90
  await websocket.send_as(request, ClientRequest)
91
  finally:
92
+ if task is not None:
93
+ task.cancel()
94
+ await task
95
 
96
  async def _receive(session_id: str) -> None:
97
  while True:
98
  response = await websocket.receive_as(ClientResponse)
99
+ app.logger.info(f"{websocket.remote_addr} - RESPONSE - {session_id} - {json.dumps(asdict(response))}")
100
  await broker.receive_response(session_id, response)
101
 
102
  @app.post('/command')
 
105
  @validate_response(ErrorResponse, 500)
106
  async def command(data: Command) -> Tuple[CommandResponse | ErrorResponse, int]:
107
  try:
108
+ response = CommandResponse(**await broker.send_request(
109
+ data.session_id,
110
+ {'action': 'command', 'command': data.command},
111
+ timeout=TIMEOUT))
112
+ return response, 200
113
+ except SessionDoesNotExist:
114
+ app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {repr(data.session_id)}")
115
+ return ErrorResponse('Session does not exist.'), 500
116
+ except ClientError as e:
117
+ return ErrorResponse(e.message), 500
118
+ except asyncio.TimeoutError:
119
+ return ErrorResponse('Timeout when waiting for client.'), 500
120
+
121
+ @app.post('/read')
122
+ @validate_request(Read)
123
+ @validate_response(ReadResponse, 200)
124
+ @validate_response(ErrorResponse, 500)
125
+ async def read(data: Read) -> Tuple[ReadResponse | ErrorResponse, int]:
126
+ try:
127
+ response = ReadResponse(**await broker.send_request(
128
+ data.session_id,
129
+ {'action': 'read', 'path': data.path},
130
+ timeout=TIMEOUT))
131
+ return response, 200
132
+ except SessionDoesNotExist:
133
+ app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {repr(data.session_id)}")
134
+ return ErrorResponse('Session does not exist.'), 500
135
+ except ClientError as e:
136
+ return ErrorResponse(e.message), 500
137
+ except asyncio.TimeoutError:
138
+ return ErrorResponse('Timeout when waiting for client.'), 500
139
+
140
+ @app.post('/write')
141
+ @validate_request(Write)
142
+ @validate_response(WriteResponse, 200)
143
+ @validate_response(ErrorResponse, 500)
144
+ async def write(data: Write) -> Tuple[WriteResponse | ErrorResponse, int]:
145
+ try:
146
+ response = WriteResponse(**await broker.send_request(
147
+ data.session_id,
148
+ {'action': 'write', 'path': data.path, 'content': data.content},
149
+ timeout=TIMEOUT))
150
  return response, 200
151
  except SessionDoesNotExist:
152
+ app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {repr(data.session_id)}")
153
+ return ErrorResponse('Session does not exist.'), 500
154
  except ClientError as e:
155
+ return ErrorResponse(e.message), 500
156
  except asyncio.TimeoutError:
157
+ return ErrorResponse('Timeout when waiting for client.'), 500
158
 
 
159
  def run():
160
+ app.run(host='0.0.0.0', port=7860)