Daniel Marques committed on
Commit f82125b
1 Parent(s): ea3e72d

fix: add handle

Files changed (1)
main.py +8 -22
main.py CHANGED
@@ -38,26 +38,12 @@ websocketClient = contextvars.ContextVar("websocketClient")
 
 class MyCustomSyncHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
-        print(f"Sync handler being called in a `thread_pool_executor`: token: {token}")
-
-
-class MyCustomAsyncHandler(AsyncCallbackHandler):
-    """Async callback handler that can be used to handle callbacks from langchain."""
-
-    async def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> None:
-        """Run when chain starts running."""
-        print("zzzz....")
-        await asyncio.sleep(0.3)
-        class_name = serialized["name"]
-        print("Hi! I just woke up. Your llm is starting")
-
-    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Run when chain ends running."""
-        print("zzzz....")
-        await asyncio.sleep(0.3)
-        print("Hi! I just woke up. Your llm is ending")
+        ws = websocketClient.get()
+
+        ws.receive_text(token)
+
+        print(f"token: {token}")
+
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
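The new handler pulls the client socket out of the `websocketClient` contextvar and pushes each generated token to it. Two caveats are worth flagging: Starlette's `WebSocket.receive_text()` takes no arguments and reads *from* the client, so `ws.receive_text(token)` would raise a `TypeError` on the first token; pushing a token out is `send_text(token)`, and that is a coroutine a synchronous callback cannot await directly. A minimal sketch of the presumed intent, assuming the owning event loop is stashed in a second, hypothetical contextvar (`eventLoop`, which is not in this commit):

import asyncio
import contextvars

from langchain.callbacks.base import BaseCallbackHandler

websocketClient = contextvars.ContextVar("websocketClient")
eventLoop = contextvars.ContextVar("eventLoop")  # hypothetical; not in the commit


class MyCustomSyncHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        ws = websocketClient.get()
        # send_text() is a coroutine owned by the server's event loop; schedule it
        # there, since this callback may run in a worker thread.
        asyncio.run_coroutine_threadsafe(ws.send_text(token), eventLoop.get())
        print(f"token: {token}")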
@@ -79,7 +65,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()])
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[MyCustomSyncHandler()])
 
 template = """you are a helpful, respectful and honest assistant.
 Your name is Katara llma. You should only use the source documents provided to answer the questions.
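With the async handler gone, `load_model` now receives only the sync handler. For context, a hedged sketch of how a callbacks list is typically attached to a streaming LangChain LLM; `LlamaCpp` and the model path are assumptions here, not `load_model`'s actual internals:

from langchain.llms import LlamaCpp

llm = LlamaCpp(
    model_path="models/llama.bin",  # hypothetical path
    streaming=True,                 # emit on_llm_new_token per generated token
    callbacks=[MyCustomSyncHandler()],
)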
@@ -263,6 +249,6 @@ async def websocket_endpoint(websocket: WebSocket):
 
     data = await websocket.receive_text()
 
-    QA(data)
+    res = QA(data)
 
     await websocket.send_text(f"Message text was:")
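The endpoint now keeps the chain's result, though the outgoing f-string still interpolates nothing. For the handler above to find the socket, `websocketClient.set(websocket)` must run in the connection's context before `QA(data)`; that wiring is not shown in this diff. A hedged sketch, assuming FastAPI and the hypothetical `eventLoop` contextvar from above:

@app.websocket("/ws")  # route path is an assumption
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        data = await websocket.receive_text()
        websocketClient.set(websocket)             # make the socket visible to the callback
        eventLoop.set(asyncio.get_running_loop())  # hypothetical companion contextvar
        # Run the blocking chain in a worker thread (contextvars propagate), keeping
        # the loop free to flush the send_text() calls the callback schedules.
        res = await asyncio.to_thread(QA, data)
        await websocket.send_text(f"Message text was:")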
 