Daniel Marques committed on
Commit a638c0e
1 Parent(s): 86dcbf6

fix: add type Union

Files changed (1)
  1. main.py +10 -10
main.py CHANGED
@@ -43,33 +43,34 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
 DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
 RETRIEVER = DB.as_retriever()
 
-
 redisClient = redis.Redis(host='localhost', port=6379, db=0)
 
-
 class MyCustomSyncHandler(BaseCallbackHandler):
-    def __init__(self):
+    def __init__(self, redisClient):
         self.message = ''
+        self.redisClient = redisClient
 
     def on_llm_new_token(self, token: str, **kwargs) -> Any:
         self.message += token
-        redisClient.publish(f'{kwargs["tags"][0]}', self.message)
+        self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
         self.message = "end"
-        redisClient.publish(f'{kwargs["tags"][0]}', self.message)
+        self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
 
     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> Any:
         self.message = "end"
-        redisClient.publish(f'{kwargs["tags"][0]}', self.message)
+        self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
 
     def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
         self.message = "end"
-        redisClient.publish(f'{kwargs["tags"][0]}', self.message)
+        self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True)
+handleCallback = MyCustomSyncHandler(redisClient)
+
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handleCallback])
 
 prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
 
@@ -84,7 +85,6 @@ QA = RetrievalQA.from_chain_type(
     },
 )
 
-
 app = FastAPI(title="homepage-app")
 api_app = FastAPI(title="api app")
 
@@ -242,7 +242,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
     pubsub.subscribe(f'{client_id}')
 
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        executor.submit(QA(inputs=prompt, return_only_outputs=True, callbacks=[MyCustomSyncHandler()], tags=f'{client_id}', include_run_info=True))
+        executor.submit(QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True))
 
     i = 0
     for item in pubsub.listen():
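For reference, below is a minimal sketch (not part of this commit) of the Redis pub/sub flow the change relies on: the handler now publishes each growing partial answer through the injected self.redisClient to a channel named after the request tag, and the websocket endpoint subscribes to f'{client_id}' and forwards what it receives until the "end" sentinel published by on_llm_end. The sketch assumes a Redis server on localhost:6379 (the values used in the diff); the channel name '42' and the hard-coded token list are illustrative stand-ins for the real client_id and the LLM's streamed tokens.

import redis

redis_client = redis.Redis(host='localhost', port=6379, db=0)

client_id = '42'  # illustrative; in the app this is the websocket's client_id

# Subscribe first (as the websocket endpoint does) so no published message is missed.
pubsub = redis_client.pubsub()
pubsub.subscribe(client_id)

# Stand-in for MyCustomSyncHandler: publish the accumulated partial answer per token,
# then the "end" sentinel that on_llm_end / on_chain_end would send.
message = ''
for token in ['Hello', ' ', 'world']:
    message += token
    redis_client.publish(client_id, message)
redis_client.publish(client_id, 'end')

# Stand-in for the endpoint's pubsub.listen() loop: forward partials until "end".
for item in pubsub.listen():
    if item['type'] != 'message':
        continue  # skip the subscribe confirmation
    data = item['data'].decode()
    if data == 'end':
        break  # the handler signals completion
    print('partial answer:', data)

In the actual application the publishing side runs inside the QA call submitted to the ThreadPoolExecutor, while the subscribing loop runs in the websocket handler; the shared channel name is what ties one request's tokens to the right client.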