Daniel Marques commited on
Commit
c287747
1 Parent(s): ea6c52b

fix: add websocketClient

Browse files
Files changed (1) hide show
  1. main.py +8 -7
main.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import glob
3
  import shutil
4
  import subprocess
 
 
5
  from typing import Any, Dict, List
6
 
7
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket
@@ -31,11 +33,12 @@ class Predict(BaseModel):
31
  class Delete(BaseModel):
32
  filename: str
33
 
34
- class MyCustomHandler(BaseCallbackHandler):
35
- async def on_llm_new_token(self, token: str, **kwargs) -> None:
36
- global websocketClient
37
 
38
- await websocketClient.send_text(f"Message text was: {token}")
 
 
 
39
 
40
  print(f" token: {token}")
41
 
@@ -246,12 +249,10 @@ async def create_upload_file(file: UploadFile):
246
  @api_app.websocket("/ws")
247
  async def websocket_endpoint(websocket: WebSocket):
248
  global QA
249
- global websocketClient
250
 
251
  await websocket.accept()
252
  while True:
253
- global websocketClient
254
- websocketClient = websocket;
255
 
256
  data = await websocket.receive_text()
257
 
 
2
  import glob
3
  import shutil
4
  import subprocess
5
+ import contextvars
6
+
7
  from typing import Any, Dict, List
8
 
9
  from fastapi import FastAPI, HTTPException, UploadFile, WebSocket
 
33
  class Delete(BaseModel):
34
  filename: str
35
 
36
+ websocketClient = contextvars.ContextVar("websocketClient")
 
 
37
 
38
+ class MyCustomHandler(AsyncCallbackHandler):
39
+ async def on_llm_new_token(self, token: str, **kwargs) -> None:
40
+ ws = websocketClient.get()
41
+ await ws.send_text(f"Message text was: {token}")
42
 
43
  print(f" token: {token}")
44
 
 
249
  @api_app.websocket("/ws")
250
  async def websocket_endpoint(websocket: WebSocket):
251
  global QA
 
252
 
253
  await websocket.accept()
254
  while True:
255
+ websocketClient.set(websocket);
 
256
 
257
  data = await websocket.receive_text()
258