Daniel Marques commited on
Commit
598b96c
1 Parent(s): f1368ae

feat: add websocket

Browse files
Files changed (3) hide show
  1. main.py +40 -2
  2. prompt_template_utils.py +3 -2
  3. websocket/socketManager.py +14 -1
main.py CHANGED
@@ -213,7 +213,26 @@ async def create_upload_file(file: UploadFile):
213
 
214
  @api_app.websocket("/ws/{user_id}")
215
  async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
216
- global QA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  message = {
219
  "message": f"Student {user_id} connected"
@@ -243,7 +262,26 @@ async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
243
 
244
  @api_app.websocket("/ws/{room_id}/{user_id}")
245
  async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
246
- global QA
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
  message = {
249
  "message": f"Student {user_id} connected to the classroom"
 
213
 
214
  @api_app.websocket("/ws/{user_id}")
215
  async def websocket_endpoint_student(websocket: WebSocket, user_id: str):
216
+ DB = Chroma(
217
+ persist_directory=PERSIST_DIRECTORY,
218
+ embedding_function=EMBEDDINGS,
219
+ client_settings=CHROMA_SETTINGS,
220
+ )
221
+
222
+ RETRIEVER = DB.as_retriever()
223
+
224
+ newInstanceQA = RetrievalQA.from_chain_type(
225
+ llm=LLM,
226
+ chain_type="stuff",
227
+ retriever=RETRIEVER,
228
+ return_source_documents=SHOW_SOURCES,
229
+ chain_type_kwargs={
230
+ "prompt": prompt,
231
+ "memory": memory
232
+ },
233
+ )
234
+
235
+ QA = socket_manager.get_instance_qa(user_id, newInstanceQA)
236
 
237
  message = {
238
  "message": f"Student {user_id} connected"
 
262
 
263
  @api_app.websocket("/ws/{room_id}/{user_id}")
264
  async def websocket_endpoint_room(websocket: WebSocket, room_id: str, user_id: str):
265
+ DB = Chroma(
266
+ persist_directory=PERSIST_DIRECTORY,
267
+ embedding_function=EMBEDDINGS,
268
+ client_settings=CHROMA_SETTINGS,
269
+ )
270
+
271
+ RETRIEVER = DB.as_retriever()
272
+
273
+ newInstanceQA = RetrievalQA.from_chain_type(
274
+ llm=LLM,
275
+ chain_type="stuff",
276
+ retriever=RETRIEVER,
277
+ return_source_documents=SHOW_SOURCES,
278
+ chain_type_kwargs={
279
+ "prompt": prompt,
280
+ "memory": memory
281
+ },
282
+ )
283
+
284
+ QA = socket_manager.get_instance_qa(room_id, newInstanceQA)
285
 
286
  message = {
287
  "message": f"Student {user_id} connected to the classroom"
prompt_template_utils.py CHANGED
@@ -15,8 +15,9 @@ from langchain.prompts import PromptTemplate
15
 
16
  # system_prompt = """You are a helpful assistant, and you will use the context and documents provided in the training to answer users' questions. Please read the context provided carefully before responding to questions and follow a step-by-step thought process. If you cannot answer a user's question based on the provided context, please inform the user. Do not use any other information to answer the user. Provide a detailed response based on the content of locally trained documents."""
17
 
18
- system_prompt = """It's a useful assistant who will use the context and documents provided in the training to answer users' questions.
19
- Read the context provided before answering the questions and think step by step. If you can't answer, just say "I don't know" and don't try to put together an answer to respond to the user."""
 
20
 
21
  def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):
22
  if promptTemplate_type == "llama":
 
15
 
16
  # system_prompt = """You are a helpful assistant, and you will use the context and documents provided in the training to answer users' questions. Please read the context provided carefully before responding to questions and follow a step-by-step thought process. If you cannot answer a user's question based on the provided context, please inform the user. Do not use any other information to answer the user. Provide a detailed response based on the content of locally trained documents."""
17
 
18
+ system_prompt = """It's a useful assistant that will use the context and documents provided in the training to answer users' questions.
19
+ Read the context provided before answering the questions and think step by step. Your answer cannot be more than 10 sentences long.
20
+ If you can't answer, just say "I don't know" and don't try to work out an answer to respond to the user."""
21
 
22
  def get_prompt_template(system_prompt=system_prompt, promptTemplate_type=None, history=False):
23
  if promptTemplate_type == "llama":
websocket/socketManager.py CHANGED
@@ -1,9 +1,13 @@
 
1
  import asyncio
2
  import redis.asyncio as aioredis
3
  import json
4
  from fastapi import WebSocket
5
 
6
 
 
 
 
7
  class RedisPubSubManager:
8
  """
9
  Initializes the RedisPubSubManager.
@@ -80,6 +84,7 @@ class WebSocketManager:
80
  pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
81
  """
82
  self.rooms: dict = {}
 
83
  self.pubsub_client = RedisPubSubManager()
84
 
85
  async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
@@ -96,7 +101,6 @@ class WebSocketManager:
96
  self.rooms[room_id].append(websocket)
97
  else:
98
  self.rooms[room_id] = [websocket]
99
-
100
  await self.pubsub_client.connect()
101
  pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
102
  asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
@@ -120,6 +124,7 @@ class WebSocketManager:
120
  websocket (WebSocket): WebSocket connection object.
121
  """
122
  self.rooms[room_id].remove(websocket)
 
123
 
124
  if len(self.rooms[room_id]) == 0:
125
  del self.rooms[room_id]
@@ -140,3 +145,11 @@ class WebSocketManager:
140
  for socket in all_sockets:
141
  data = message['data'].decode('utf-8')
142
  await socket.send_text(data)
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
  import asyncio
3
  import redis.asyncio as aioredis
4
  import json
5
  from fastapi import WebSocket
6
 
7
 
8
+
9
+
10
+
11
  class RedisPubSubManager:
12
  """
13
  Initializes the RedisPubSubManager.
 
84
  pubsub_client (RedisPubSubManager): An instance of the RedisPubSubManager class for pub-sub functionality.
85
  """
86
  self.rooms: dict = {}
87
+ self.qa: dict = {}
88
  self.pubsub_client = RedisPubSubManager()
89
 
90
  async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
 
101
  self.rooms[room_id].append(websocket)
102
  else:
103
  self.rooms[room_id] = [websocket]
 
104
  await self.pubsub_client.connect()
105
  pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
106
  asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
 
124
  websocket (WebSocket): WebSocket connection object.
125
  """
126
  self.rooms[room_id].remove(websocket)
127
+ self.qa.pop(room_id, None)
128
 
129
  if len(self.rooms[room_id]) == 0:
130
  del self.rooms[room_id]
 
145
  for socket in all_sockets:
146
  data = message['data'].decode('utf-8')
147
  await socket.send_text(data)
148
+
149
+ async def get_instance_qa(self, room_id: str, QA: Any):
150
+ if room_id in self.qa:
151
+ return self.qa[room_id]
152
+
153
+ self.qa[room_id] = QA
154
+ return self.qa[room_id]
155
+