Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
•
a638c0e
1
Parent(s):
86dcbf6
fix: add type Union
Browse files
main.py
CHANGED
@@ -43,33 +43,34 @@ EMBEDDINGS = HuggingFaceInstructEmbeddings(model_name=EMBEDDING_MODEL_NAME, mode
|
|
43 |
DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
|
44 |
RETRIEVER = DB.as_retriever()
|
45 |
|
46 |
-
|
47 |
redisClient = redis.Redis(host='localhost', port=6379, db=0)
|
48 |
|
49 |
-
|
50 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
51 |
-
def __init__(self):
|
52 |
self.message = ''
|
|
|
53 |
|
54 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
55 |
self.message += token
|
56 |
-
redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
57 |
|
58 |
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
59 |
self.message = "end"
|
60 |
-
redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
61 |
|
62 |
def on_llm_error(
|
63 |
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
64 |
) -> Any:
|
65 |
self.message = "end"
|
66 |
-
redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
67 |
|
68 |
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
69 |
self.message = "end"
|
70 |
-
redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
71 |
|
72 |
-
|
|
|
|
|
73 |
|
74 |
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
|
75 |
|
@@ -84,7 +85,6 @@ QA = RetrievalQA.from_chain_type(
|
|
84 |
},
|
85 |
)
|
86 |
|
87 |
-
|
88 |
app = FastAPI(title="homepage-app")
|
89 |
api_app = FastAPI(title="api app")
|
90 |
|
@@ -242,7 +242,7 @@ async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
|
242 |
pubsub.subscribe(f'{client_id}')
|
243 |
|
244 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
245 |
-
executor.submit(QA(inputs=prompt, return_only_outputs=True,
|
246 |
|
247 |
i = 0
|
248 |
for item in pubsub.listen():
|
|
|
43 |
DB = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=EMBEDDINGS, client_settings=CHROMA_SETTINGS)
|
44 |
RETRIEVER = DB.as_retriever()
|
45 |
|
|
|
46 |
redisClient = redis.Redis(host='localhost', port=6379, db=0)
|
47 |
|
|
|
48 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
49 |
+
def __init__(self, redisClient):
|
50 |
self.message = ''
|
51 |
+
self.redisClient = redisClient
|
52 |
|
53 |
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
54 |
self.message += token
|
55 |
+
self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
56 |
|
57 |
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
58 |
self.message = "end"
|
59 |
+
self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
60 |
|
61 |
def on_llm_error(
|
62 |
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
63 |
) -> Any:
|
64 |
self.message = "end"
|
65 |
+
self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
66 |
|
67 |
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
68 |
self.message = "end"
|
69 |
+
self.redisClient.publish(f'{kwargs["tags"][0]}', self.message)
|
70 |
|
71 |
+
handleCallback = MyCustomSyncHandler(redisClient)
|
72 |
+
|
73 |
+
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handleCallback])
|
74 |
|
75 |
prompt, memory = get_prompt_template(promptTemplate_type="llama", history=True)
|
76 |
|
|
|
85 |
},
|
86 |
)
|
87 |
|
|
|
88 |
app = FastAPI(title="homepage-app")
|
89 |
api_app = FastAPI(title="api app")
|
90 |
|
|
|
242 |
pubsub.subscribe(f'{client_id}')
|
243 |
|
244 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
245 |
+
executor.submit(QA(inputs=prompt, return_only_outputs=True, tags=f'{client_id}', include_run_info=True))
|
246 |
|
247 |
i = 0
|
248 |
for item in pubsub.listen():
|