Daniel Marques commited on
Commit
5d66516
·
1 Parent(s): 57fab83

feat: add broadcast

Browse files
Files changed (6) hide show
  1. main.py +46 -20
  2. redisPubSubManger.py +69 -0
  3. requirements.txt +1 -1
  4. run.sh +1 -2
  5. test.txt +0 -1
  6. webSocketManger.py +70 -0
main.py CHANGED
@@ -1,13 +1,13 @@
1
- from typing import Any, Dict, Union
2
-
3
  import os
4
  import glob
5
  import shutil
6
  import subprocess
7
  import torch
 
8
 
9
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
10
  from fastapi.staticfiles import StaticFiles
 
11
 
12
  from pydantic import BaseModel
13
 
@@ -55,6 +55,8 @@ QA = RetrievalQA.from_chain_type(
55
  },
56
  )
57
 
 
 
58
  app = FastAPI(title="homepage-app")
59
  api_app = FastAPI(title="api app")
60
 
@@ -162,8 +164,6 @@ def predict(data: Predict):
162
  except Exception as e:
163
  raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
164
 
165
-
166
-
167
  @api_app.post("/save_document/")
168
  async def create_upload_file(file: UploadFile):
169
  # Get the file size (in bytes)
@@ -204,31 +204,57 @@ async def create_upload_file(file: UploadFile):
204
 
205
  return {"filename": file.filename}
206
 
207
- @api_app.websocket("/ws/{client_id}")
208
- async def websocket_endpoint(websocket: WebSocket, client_id: str):
209
  global QA
210
 
211
- await websocket.accept()
 
 
 
 
 
 
 
 
212
 
213
  try:
214
  while True:
215
- user_prompt = await websocket.receive_text()
216
- response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
217
- answer, docs = response["result"], response["source_documents"]
218
 
219
- prompt_response_dict = {
220
- "Prompt": user_prompt,
221
- "Answer": answer,
 
222
  }
223
 
224
- prompt_response_dict["Sources"] = []
225
- for document in docs:
226
- prompt_response_dict["Sources"].append(
227
- (os.path.basename(str(document.metadata["source"])), str(document.page_content))
228
- )
229
- await websocket.send_json(prompt_response_dict)
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  except WebSocketDisconnect:
232
- print('disconnect')
 
 
 
 
 
 
 
 
233
  except RuntimeError as error:
234
  print(error)
 
 
 
1
  import os
2
  import glob
3
  import shutil
4
  import subprocess
5
  import torch
6
+ import json
7
 
8
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
9
  from fastapi.staticfiles import StaticFiles
10
+ from websocket.socketManager import WebSocketManager
11
 
12
  from pydantic import BaseModel
13
 
 
55
  },
56
  )
57
 
58
+ socket_manager = WebSocketManager()
59
+
60
  app = FastAPI(title="homepage-app")
61
  api_app = FastAPI(title="api app")
62
 
 
164
  except Exception as e:
165
  raise HTTPException(status_code=500, detail=f"Error occurred: {str(e)}")
166
 
 
 
167
  @api_app.post("/save_document/")
168
  async def create_upload_file(file: UploadFile):
169
  # Get the file size (in bytes)
 
204
 
205
  return {"filename": file.filename}
206
 
207
+ @api_app.websocket("/ws/{room_id}/{user_id}")
208
+ async def websocket_endpoint(websocket: WebSocket, room_id: str, user_id: int):
209
  global QA
210
 
211
+ await socket_manager.add_user_to_room(room_id, websocket)
212
+
213
+ message = {
214
+ "user_id": user_id,
215
+ "room_id": room_id,
216
+ "message": f"User {user_id} connected to room - {room_id}"
217
+ }
218
+
219
+ await socket_manager.broadcast_to_room(room_id, json.dumps(message))
220
 
221
  try:
222
  while True:
223
+ data = await websocket.receive_text()
 
 
224
 
225
+ message = {
226
+ "user_id": user_id,
227
+ "room_id": room_id,
228
+ "message": data
229
  }
230
 
231
+ await socket_manager.broadcast_to_room(room_id, json.dumps(message))
232
+
233
+ # user_prompt = await websocket.receive_text()
234
+ # response = QA(inputs=user_prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True)
235
+ # answer, docs = response["result"], response["source_documents"]
236
+
237
+ # prompt_response_dict = {
238
+ # "Prompt": user_prompt,
239
+ # "Answer": answer,
240
+ # }
241
+
242
+ # prompt_response_dict["Sources"] = []
243
+ # for document in docs:
244
+ # prompt_response_dict["Sources"].append(
245
+ # (os.path.basename(str(document.metadata["source"])), str(document.page_content))
246
+ # )
247
+ # await websocket.send_json(prompt_response_dict)
248
 
249
  except WebSocketDisconnect:
250
+ await socket_manager.remove_user_from_room(room_id, websocket)
251
+
252
+ message = {
253
+ "user_id": user_id,
254
+ "room_id": room_id,
255
+ "message": f"User {user_id} disconnected from room - {room_id}"
256
+ }
257
+
258
+ await socket_manager.broadcast_to_room(room_id, json.dumps(message))
259
  except RuntimeError as error:
260
  print(error)
redisPubSubManger.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import redis.asyncio as aioredis
3
+ import json
4
+ from fastapi import WebSocket
5
+
6
+
7
+ class RedisPubSubManager:
8
+ """
9
+ Initializes the RedisPubSubManager.
10
+
11
+ Args:
12
+ host (str): Redis server host.
13
+ port (int): Redis server port.
14
+ """
15
+
16
+ def __init__(self, host='localhost', port=6379):
17
+ self.redis_host = host
18
+ self.redis_port = port
19
+ self.pubsub = None
20
+
21
+ async def _get_redis_connection(self) -> aioredis.Redis:
22
+ """
23
+ Establishes a connection to Redis.
24
+
25
+ Returns:
26
+ aioredis.Redis: Redis connection object.
27
+ """
28
+ return aioredis.Redis(host=self.redis_host,
29
+ port=self.redis_port,
30
+ auto_close_connection_pool=False)
31
+
32
+ async def connect(self) -> None:
33
+ """
34
+ Connects to the Redis server and initializes the pubsub client.
35
+ """
36
+ self.redis_connection = await self._get_redis_connection()
37
+ self.pubsub = self.redis_connection.pubsub()
38
+
39
+ async def _publish(self, room_id: str, message: str) -> None:
40
+ """
41
+ Publishes a message to a specific Redis channel.
42
+
43
+ Args:
44
+ room_id (str): Channel or room ID.
45
+ message (str): Message to be published.
46
+ """
47
+ await self.redis_connection.publish(room_id, message)
48
+
49
+ async def subscribe(self, room_id: str) -> aioredis.Redis:
50
+ """
51
+ Subscribes to a Redis channel.
52
+
53
+ Args:
54
+ room_id (str): Channel or room ID to subscribe to.
55
+
56
+ Returns:
57
+ aioredis.ChannelSubscribe: PubSub object for the subscribed channel.
58
+ """
59
+ await self.pubsub.subscribe(room_id)
60
+ return self.pubsub
61
+
62
+ async def unsubscribe(self, room_id: str) -> None:
63
+ """
64
+ Unsubscribes from a Redis channel.
65
+
66
+ Args:
67
+ room_id (str): Channel or room ID to unsubscribe from.
68
+ """
69
+ await self.pubsub.unsubscribe(room_id)
requirements.txt CHANGED
@@ -29,7 +29,7 @@ uvicorn
29
  fastapi
30
  websockets
31
  pydantic
32
- redis
33
 
34
  # Streamlit related
35
  streamlit
 
29
  fastapi
30
  websockets
31
  pydantic
32
+ aioredis
33
 
34
  # Streamlit related
35
  streamlit
run.sh CHANGED
@@ -1,5 +1,4 @@
1
  # Redis Support uncomment this lines
2
- # redis-cli --version
3
- # nohup redis-server &
4
 
5
  uvicorn "main:app" --port 7860 --host 0.0.0.0
 
1
  # Redis Support uncomment this lines
2
+ nohup redis-server &
 
3
 
4
  uvicorn "main:app" --port 7860 --host 0.0.0.0
test.txt DELETED
@@ -1 +0,0 @@
1
- dkdaniz is an avatar of instagram, create by daniel marques
 
 
webSocketManger.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class WebSocketManager:
2
+ def __init__(self):
3
+ """
4
+ Initializes the WebSocketManager.
5
+
6
+ Attributes:
7
+ rooms (dict): A dictionary to store WebSocket connections in different rooms.
8
+ pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
9
+ """
10
+ self.rooms: dict = {}
11
+ self.pubsub_client = RedisPubSubManager()
12
+
13
+ async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
14
+ """
15
+ Adds a user's WebSocket connection to a room.
16
+
17
+ Args:
18
+ room_id (str): Room ID or channel name.
19
+ websocket (WebSocket): WebSocket connection object.
20
+ """
21
+ await websocket.accept()
22
+
23
+ if room_id in self.rooms:
24
+ self.rooms[room_id].append(websocket)
25
+ else:
26
+ self.rooms[room_id] = [websocket]
27
+
28
+ await self.pubsub_client.connect()
29
+ pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
30
+ asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
31
+
32
+ async def broadcast_to_room(self, room_id: str, message: str) -> None:
33
+ """
34
+ Broadcasts a message to all connected WebSockets in a room.
35
+
36
+ Args:
37
+ room_id (str): Room ID or channel name.
38
+ message (str): Message to be broadcasted.
39
+ """
40
+ await self.pubsub_client._publish(room_id, message)
41
+
42
+ async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
43
+ """
44
+ Removes a user's WebSocket connection from a room.
45
+
46
+ Args:
47
+ room_id (str): Room ID or channel name.
48
+ websocket (WebSocket): WebSocket connection object.
49
+ """
50
+ self.rooms[room_id].remove(websocket)
51
+
52
+ if len(self.rooms[room_id]) == 0:
53
+ del self.rooms[room_id]
54
+ await self.pubsub_client.unsubscribe(room_id)
55
+
56
+ async def _pubsub_data_reader(self, pubsub_subscriber):
57
+ """
58
+ Reads and broadcasts messages received from Redis PubSub.
59
+
60
+ Args:
61
+ pubsub_subscriber (aioredis.ChannelSubscribe): PubSub object for the subscribed channel.
62
+ """
63
+ while True:
64
+ message = await pubsub_subscriber.get_message(ignore_subscribe_messages=True)
65
+ if message is not None:
66
+ room_id = message['channel'].decode('utf-8')
67
+ all_sockets = self.rooms[room_id]
68
+ for socket in all_sockets:
69
+ data = message['data'].decode('utf-8')
70
+ await socket.send_text(data)