Daniel Marques commited on
Commit
1fd336c
1 Parent(s): 3a84e24

fix: add websocket in handlerToken

Browse files
Files changed (1) hide show
  1. main.py +19 -20
main.py CHANGED
@@ -41,23 +41,19 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
41
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
42
  RETRIEVER = DB.as_retriever()
43
 
 
 
 
 
44
  class MyCustomSyncHandler(BaseCallbackHandler):
45
- def on_llm_start(
46
- self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
47
- ) -> None:
48
- print(f'on_llm_start self {self}')
49
- print(f'on_llm_start kwargs {prompts}')
50
- print(f'on_llm_start token {kwargs}')
51
-
52
- def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
53
- print(f'on_llm_end self {self}')
54
- print(f'on_llm_end kwargs {response}')
55
- print(f'on_llm_end token {kwargs}')
56
 
57
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
58
- print(f'on_llm_new_token self {self}')
59
- print(f'on_llm_new_token kwargs {kwargs}')
60
- print(f'on_llm_new_token token {token}')
 
61
 
62
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
63
 
@@ -74,10 +70,6 @@ QA = RetrievalQA.from_chain_type(
74
  },
75
  )
76
 
77
- redisClient = redis.Redis(host='localhost', port=6379, db=0)
78
-
79
- redisClient.set('foo', 'bar')
80
-
81
  app = FastAPI(title="homepage-app")
82
  api_app = FastAPI(title="api app")
83
 
@@ -232,6 +224,14 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
232
  try:
233
  while True:
234
  prompt = await websocket.receive_text()
 
 
 
 
 
 
 
 
235
  QA(
236
  inputs=prompt,
237
  return_only_outputs=True,
@@ -240,9 +240,8 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
240
  include_run_info=True
241
  )
242
 
243
- response = redisClient.get('foo')
244
 
245
- await websocket.send_text(response)
246
 
247
  except WebSocketDisconnect:
248
  print('disconnect')
 
41
  DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
42
  RETRIEVER = DB.as_retriever()
43
 
44
+
45
+ redisClient = redis.Redis(host='localhost', port=6379, db=0)
46
+
47
+
48
  class MyCustomSyncHandler(BaseCallbackHandler):
49
+ def __init__(self):
50
+ self.message = ''
 
 
 
 
 
 
 
 
 
51
 
52
  def on_llm_new_token(self, token: str, **kwargs) -> Any:
53
+ message += token
54
+ redisClient.publish(f'{kwargs["tags"][0]}', message)
55
+
56
+ print(message)
57
 
58
  LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
59
 
 
70
  },
71
  )
72
 
 
 
 
 
73
  app = FastAPI(title="homepage-app")
74
  api_app = FastAPI(title="api app")
75
 
 
224
  try:
225
  while True:
226
  prompt = await websocket.receive_text()
227
+
228
+ pubsub = redisClient.pubsub()
229
+ pubsub.subscribe(f'{client_id}')
230
+
231
+ for item in pubsub.listen():
232
+ if item['type'] == 'message':
233
+ await websocket.send_text(item["data"])
234
+
235
  QA(
236
  inputs=prompt,
237
  return_only_outputs=True,
 
240
  include_run_info=True
241
  )
242
 
 
243
 
244
+
245
 
246
  except WebSocketDisconnect:
247
  print('disconnect')