Daniel Marques commited on
Commit
abff149
1 Parent(s): e04f8da

fix: add websocket in handlerToken

Browse files
Files changed (1) hide show
  1. main.py +8 -21
main.py CHANGED
@@ -20,8 +20,6 @@ from langchain.memory import ConversationBufferMemory
20
  from langchain.callbacks.base import BaseCallbackHandler
21
  from langchain.schema import LLMResult
22
 
23
- from varstate import State
24
-
25
  # from langchain.embeddings import HuggingFaceEmbeddings
26
  from load_models import load_model
27
 
@@ -58,9 +56,9 @@ DB = Chroma(
58
  RETRIEVER = DB.as_retriever()
59
 
60
  class MyCustomSyncHandler(BaseCallbackHandler):
61
- def __init__(self, state):
62
  self.end = False
63
- self.state = state
64
 
65
  def on_llm_start(
66
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -77,14 +75,8 @@ class MyCustomSyncHandler(BaseCallbackHandler):
77
 
78
  print(token)
79
 
80
-
81
  # Create State
82
-
83
- tokenMessageLLM = State()
84
-
85
- get, update = tokenMessageLLM.create('')
86
-
87
- handlerToken = MyCustomSyncHandler(update)
88
 
89
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
90
 
@@ -253,8 +245,8 @@ async def create_upload_file(file: UploadFile):
253
 
254
  return {"filename": file.filename}
255
 
256
- @api_app.websocket("/ws")
257
- async def websocket_endpoint(websocket: WebSocket):
258
  global QA
259
 
260
  await websocket.accept()
@@ -265,16 +257,11 @@ async def websocket_endpoint(websocket: WebSocket):
265
  while True:
266
  prompt = await websocket.receive_text()
267
 
268
- statusProcess = handlerToken.end
269
 
270
- if (oldReceiveText != prompt) and (statusProcess == False) :
271
  oldReceiveText = prompt
272
- res = QA(prompt)
273
-
274
- print(statusProcess);
275
-
276
- tokenState = get()
277
- await websocket.send_text(f"token: {tokenState}")
278
 
279
  except WebSocketDisconnect:
280
  print('disconnect')
 
20
  from langchain.callbacks.base import BaseCallbackHandler
21
  from langchain.schema import LLMResult
22
 
 
 
23
  # from langchain.embeddings import HuggingFaceEmbeddings
24
  from load_models import load_model
25
 
 
56
  RETRIEVER = DB.as_retriever()
57
 
58
  class MyCustomSyncHandler(BaseCallbackHandler):
59
+ def __init__(self):
60
  self.end = False
61
+ self.callback = None
62
 
63
  def on_llm_start(
64
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
 
75
 
76
  print(token)
77
 
 
78
  # Create State
79
+ handlerToken = MyCustomSyncHandler()
 
 
 
 
 
80
 
81
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
82
 
 
245
 
246
  return {"filename": file.filename}
247
 
248
+ @api_app.websocket("/ws/{client_id}")
249
+ async def websocket_endpoint(websocket: WebSocket, client_id: int):
250
  global QA
251
 
252
  await websocket.accept()
 
257
  while True:
258
  prompt = await websocket.receive_text()
259
 
260
+ handlerToken.callback = websocket.send_text;
261
 
262
+ if (oldReceiveText != prompt):
263
  oldReceiveText = prompt
264
+ asyncio.run(QA(prompt))
 
 
 
 
 
265
 
266
  except WebSocketDisconnect:
267
  print('disconnect')