Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
•
1fd336c
1
Parent(s):
3a84e24
fix: add websocket in handlerToken
Browse files
main.py
CHANGED
@@ -41,23 +41,19 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
|
|
41 |
DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
|
42 |
RETRIEVER = DB.as_retriever()
|
43 |
|
|
|
|
|
|
|
|
|
44 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
45 |
-
def
|
46 |
-
self
|
47 |
-
) -> None:
|
48 |
-
print(f'on_llm_start self {self}')
|
49 |
-
print(f'on_llm_start kwargs {prompts}')
|
50 |
-
print(f'on_llm_start token {kwargs}')
|
51 |
-
|
52 |
-
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
53 |
-
print(f'on_llm_end self {self}')
|
54 |
-
print(f'on_llm_end kwargs {response}')
|
55 |
-
print(f'on_llm_end token {kwargs}')
|
56 |
|
57 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
61 |
|
62 |
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
|
63 |
|
@@ -74,10 +70,6 @@ QA = RetrievalQA.from_chain_type(
|
|
74 |
},
|
75 |
)
|
76 |
|
77 |
-
redisClient = redis.Redis(host='localhost', port=6379, db=0)
|
78 |
-
|
79 |
-
redisClient.set('foo', 'bar')
|
80 |
-
|
81 |
app = FastAPI(title="homepage-app")
|
82 |
api_app = FastAPI(title="api app")
|
83 |
|
@@ -232,6 +224,14 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
|
232 |
try:
|
233 |
while True:
|
234 |
prompt = await websocket.receive_text()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
QA(
|
236 |
inputs=prompt,
|
237 |
return_only_outputs=True,
|
@@ -240,9 +240,8 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
|
240 |
include_run_info=True
|
241 |
)
|
242 |
|
243 |
-
response = redisClient.get('foo')
|
244 |
|
245 |
-
|
246 |
|
247 |
except WebSocketDisconnect:
|
248 |
print('disconnect')
|
|
|
41 |
DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
|
42 |
RETRIEVER = DB.as_retriever()
|
43 |
|
44 |
+
|
45 |
+
redisClient = redis.Redis(host='localhost', port=6379, db=0)
|
46 |
+
|
47 |
+
|
48 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
49 |
+
def __init__(self):
|
50 |
+
self.message = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
53 |
+
message += token
|
54 |
+
redisClient.publish(f'{kwargs["tags"][0]}', message)
|
55 |
+
|
56 |
+
print(message)
|
57 |
|
58 |
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
|
59 |
|
|
|
70 |
},
|
71 |
)
|
72 |
|
|
|
|
|
|
|
|
|
73 |
app = FastAPI(title="homepage-app")
|
74 |
api_app = FastAPI(title="api app")
|
75 |
|
|
|
224 |
try:
|
225 |
while True:
|
226 |
prompt = await websocket.receive_text()
|
227 |
+
|
228 |
+
pubsub = redisClient.pubsub()
|
229 |
+
pubsub.subscribe(f'{client_id}')
|
230 |
+
|
231 |
+
for item in pubsub.listen():
|
232 |
+
if item['type'] == 'message':
|
233 |
+
await websocket.send_text(item["data"])
|
234 |
+
|
235 |
QA(
|
236 |
inputs=prompt,
|
237 |
return_only_outputs=True,
|
|
|
240 |
include_run_info=True
|
241 |
)
|
242 |
|
|
|
243 |
|
244 |
+
|
245 |
|
246 |
except WebSocketDisconnect:
|
247 |
print('disconnect')
|