Daniel Marques committed
Commit 99f6cbc
1 Parent(s): ef75206

fix: add websocket in handlerToken

Files changed (1):
  1. main.py +8 -6
main.py CHANGED
@@ -2,7 +2,7 @@ import os
 import glob
 import shutil
 import subprocess
-import sys
+import asyncio
 
 from typing import Any, Dict, List
 
@@ -17,7 +17,6 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
 from langchain.schema import LLMResult
@@ -59,7 +58,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
+class MyCustomSyncHandler(BaseCallbackHandler):
     def __init__(self):
         self.end = False
         self.websocket = None
@@ -73,8 +72,10 @@ class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
         self.end = True
 
     def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        print(token)
+
         if self.websocket != None:
-            self.websocket.send_text(token)
+            asyncio.run(self.websocket.send_text(token))
 
         print(token)
 
@@ -82,7 +83,7 @@ class MyCustomSyncHandler(StreamingStdOutCallbackHandler):
 
 handlerToken = MyCustomSyncHandler()
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[])
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
 
 template = """You are a helpful, respectful and honest assistant.
 Always answer in the most helpful and safe way possible without trying to make up an answer, if you don't know the answer just say "I don't know" and don't share false information or topics that were not provided in your training. Use a maximum of 15 sentences. Your answer should be as concise and clear as possible. Always say "thank you for asking!" at the end of your answer.
@@ -101,7 +102,6 @@ QA = RetrievalQA.from_chain_type(
     return_source_documents=SHOW_SOURCES,
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
-        "callbacks": [handlerToken]
     },
 )
 
@@ -260,6 +260,8 @@ async def websocket_endpoint(websocket: WebSocket):
     while True:
         handlerToken.websocket = websocket
 
+        print(handlerToken.websocket)
+
         data = await websocket.receive_text()
         res = QA(data)
         print(res)
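
For context, a minimal, self-contained sketch of the pattern this commit moves toward: a synchronous callback handler that forwards each generated token to a connected FastAPI WebSocket. The names WebSocketTokenHandler and fake_chain are hypothetical stand-ins (fake_chain replaces the repository's QA chain so the sketch runs on its own), and the sketch uses asyncio.run_coroutine_threadsafe plus asyncio.to_thread rather than asyncio.run, since asyncio.run cannot be called from a thread whose event loop is already running.

# Sketch only: forwards streamed tokens from a synchronous LangChain callback
# to a FastAPI WebSocket. fake_chain is a hypothetical stand-in for QA(...).
import asyncio
import time
from typing import Any, Optional

from fastapi import FastAPI, WebSocket
from langchain.callbacks.base import BaseCallbackHandler

app = FastAPI()


class WebSocketTokenHandler(BaseCallbackHandler):
    """Pushes each new LLM token to the client over an active WebSocket."""

    def __init__(self) -> None:
        self.websocket: Optional[WebSocket] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(token, end="", flush=True)
        if self.websocket is not None and self.loop is not None:
            # Schedule the async send on the server's event loop from this
            # worker thread; the returned future is ignored for brevity.
            asyncio.run_coroutine_threadsafe(self.websocket.send_text(token), self.loop)


handler = WebSocketTokenHandler()


def fake_chain(query: str) -> str:
    """Hypothetical stand-in for the QA chain: emits tokens through the handler."""
    answer = f"Echo: {query}"
    for token in answer.split():
        handler.on_llm_new_token(token + " ")
        time.sleep(0.05)  # simulate generation latency
    return answer


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    await websocket.accept()
    handler.websocket = websocket
    handler.loop = asyncio.get_running_loop()
    while True:
        data = await websocket.receive_text()
        # Run the blocking chain in a worker thread so the event loop stays
        # free to deliver the token sends scheduled by the handler.
        result = await asyncio.to_thread(fake_chain, data)
        await websocket.send_text(f"[done] {result}")

With the chain dispatched to a worker thread, the endpoint's event loop can flush each send_text call as tokens arrive, which is the token-streaming behaviour the handlerToken change is aiming for.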