Daniel Marques commited on
Commit
d815dea
1 Parent(s): 25121ae

fix: add websocket in handlerToken

Browse files
Files changed (1) hide show
  1. main.py +17 -11
main.py CHANGED
@@ -5,7 +5,7 @@ import subprocess
5
 
6
  from typing import Any, Dict, List
7
 
8
- from fastapi import FastAPI, HTTPException, UploadFile, WebSocket
9
  from fastapi.staticfiles import StaticFiles
10
 
11
  from pydantic import BaseModel
@@ -58,7 +58,7 @@ RETRIEVER = DB.as_retriever()
58
  class MyCustomSyncHandler(BaseCallbackHandler):
59
  def __init__(self):
60
  self.end = False
61
- self.token = ""
62
 
63
  def on_llm_start(
64
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -69,6 +69,11 @@ class MyCustomSyncHandler(BaseCallbackHandler):
69
  self.end = True
70
 
71
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
 
 
 
 
 
72
  self.token += token
73
 
74
 
@@ -76,11 +81,8 @@ handlerToken = MyCustomSyncHandler()
76
 
77
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
78
 
79
- template = """You are a helpful, respectful and honest assistant. You should only use the documents provided in the context to answer the questions.
80
- You should only respond only topics that contains in documents use to training. Use the following pieces of context to answer the question at the end.
81
- Always answer in the most helpful and safe way possible.
82
- If you don't know the answer to a question, just say that you "I don't know", never try to make up an answer and don't share false information.
83
- Use 15 sentences maximum. The answer must be as concise as possible. Always say "thanks for asking!" at the end of the answer.
84
  Context: {context}
85
  Question: {question}
86
  """
@@ -250,12 +252,16 @@ async def websocket_endpoint(websocket: WebSocket):
250
  global QA
251
 
252
  await websocket.accept()
253
- while True:
254
- data = await websocket.receive_text()
 
255
 
256
- res = QA(data)
 
 
257
 
258
- await websocket.send_text(f"result: {res}")
 
259
 
260
 
261
 
 
5
 
6
  from typing import Any, Dict, List
7
 
8
+ from fastapi import FastAPI, HTTPException, UploadFile, WebSocket, WebSocketDisconnect
9
  from fastapi.staticfiles import StaticFiles
10
 
11
  from pydantic import BaseModel
 
58
  class MyCustomSyncHandler(BaseCallbackHandler):
59
  def __init__(self):
60
  self.end = False
61
+ self.websocket = None
62
 
63
  def on_llm_start(
64
  self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
 
69
  self.end = True
70
 
71
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
72
+ if self.websocket != None:
73
+ self.websocket.send_text(token)
74
+
75
+ print(token)
76
+
77
  self.token += token
78
 
79
 
 
81
 
82
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
83
 
84
+ template = """You are a helpful, respectful and honest assistant.
85
+ Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
 
 
 
86
  Context: {context}
87
  Question: {question}
88
  """
 
252
  global QA
253
 
254
  await websocket.accept()
255
+ try:
256
+ while True:
257
+ handlerToken.websocket = websocket
258
 
259
+ data = await websocket.receive_text()
260
+ res = QA(data)
261
+ print(res)
262
 
263
+ except WebSocketDisconnect:
264
+ await websocket.send_text(f"disconnect")
265
 
266
 
267