Daniel Marques committed · 8a26b55
Parent(s): 2453cc0

fix: add types
main.py CHANGED
@@ -14,7 +14,7 @@ from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
-from langchain.callbacks.base import
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
 
 # from langchain.embeddings import HuggingFaceEmbeddings
@@ -31,7 +31,7 @@ class Predict(BaseModel):
 class Delete(BaseModel):
     filename: str
 
-class
+class MyCustomHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
         print(f" token: {token}")
 
@@ -44,6 +44,19 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         print("finish")
 
+class CustomHandler(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        print(f" CustomHandler: {token}")
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        class_name = serialized["name"]
+        print("CustomHandler start")
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        print("CustomHandler finish")
+
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
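A note on the new class: `CustomHandler` extends the synchronous `BaseCallbackHandler` but declares `on_llm_start`/`on_llm_end` as coroutines. LangChain's sync callback manager calls handler methods without awaiting them, so those bodies never actually run (Python just warns that a coroutine was never awaited). A minimal sketch of the conventional split, with hypothetical handler names and the same print-based behavior:

from typing import Any, Dict, List

from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import LLMResult


class SyncTokenHandler(BaseCallbackHandler):
    # Sync hooks: invoked directly by the synchronous callback manager.
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        print(f"token: {token}")


class AsyncLifecycleHandler(AsyncCallbackHandler):
    # Async hooks: only awaited when the chain runs via acall/arun.
    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print("start")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        print("finish")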
@@ -65,7 +78,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[MyCustomHandler()])
 
 template = """you are a helpful, respectful and honest assistant. When answering questions, you should only use the documents provided.
 You should only answer the topics that appear in these documents.
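`load_model` is this repo's own helper, so its internals aren't shown in the diff; the pattern it presumably wires up is a LangChain LLM constructed with streaming enabled and constructor-level callbacks, which is what makes `on_llm_new_token` fire per token. A hedged sketch using `LlamaCpp` as a stand-in backend (the actual backend and path are assumptions):

from langchain.llms import LlamaCpp  # stand-in backend, an assumption

llm = LlamaCpp(
    model_path="models/model.gguf",  # placeholder path
    streaming=True,                  # without this, on_llm_new_token never fires
    callbacks=[MyCustomHandler()],   # constructor callbacks apply to every call
)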
@@ -87,7 +100,8 @@ QA = RetrievalQA.from_chain_type(
     return_source_documents=SHOW_SOURCES,
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
-        "memory": memory
+        "memory": memory,
+        "callbacks": [CustomHandler()]
     },
 )
 
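`chain_type_kwargs` is forwarded to the chain built by `load_qa_chain`, so `prompt`, `memory`, and `callbacks` here attach to the inner combine-documents chain rather than to the `RetrievalQA` wrapper itself. The `memory` object is defined outside this hunk; a typical construction for a question/history prompt layout would be (an assumption, not shown in the diff):

# Hypothetical: the keys must match the variables the prompt template expects.
memory = ConversationBufferMemory(input_key="question", memory_key="history")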
@@ -179,8 +193,6 @@ async def predict(data: Predict):
     if user_prompt:
         res = QA(user_prompt)
 
-        print(res)
-
         answer, docs = res["result"], res["source_documents"]
 
         prompt_response_dict = {
@@ -194,17 +206,6 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-        qa_chain_response = res.stream(
-            {"query": user_prompt},
-        )
-
-        print(f"{qa_chain_response} stream")
-
-        # generated_text = ""
-        # for new_text in STREAMER:
-        #     generated_text += new_text
-        # print(generated_text)
-
         return {"response": prompt_response_dict}
     else:
         raise HTTPException(status_code=400, detail="Prompt Incorrect")
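The block removed above would have failed at runtime: `res` is the plain dict returned by `QA(user_prompt)` (hence `res["result"]` earlier), and a dict has no `.stream` method, so `res.stream(...)` raises `AttributeError`. On LangChain versions where chains implement the Runnable interface, streaming would instead be invoked on the chain itself, e.g. `QA.stream({"query": user_prompt})` (version-dependent).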
@@ -254,4 +255,13 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     while True:
         data = await websocket.receive_text()
+
+        res = QA(data)
+
+        qa_chain_response = res.stream(
+            {"query": data},
+        )
+
+        print(f"{qa_chain_response} stream")
+
         await websocket.send_text(f"Message text was: {data}")
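The same `res.stream(...)` issue reappears here: `QA(data)` returns a dict, so `res.stream({"query": data})` will raise `AttributeError` before anything is streamed to the client. A minimal sketch of token streaming over the websocket instead, assuming a per-connection `AsyncCallbackHandler` passed at call time (handler name and wiring are illustrative, not from the commit; `app` and `QA` come from the surrounding module):

from typing import Any

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler


class WebsocketTokenHandler(AsyncCallbackHandler):
    """Hypothetical handler: forwards each generated token to the client."""

    def __init__(self, websocket: WebSocket) -> None:
        self.websocket = websocket

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        await self.websocket.send_text(token)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    while True:
        data = await websocket.receive_text()
        # Run the chain asynchronously; tokens arrive in on_llm_new_token
        # while the LLM generates, then the final answer is sent once.
        res = await QA.acall(
            {"query": data}, callbacks=[WebsocketTokenHandler(websocket)]
        )
        await websocket.send_text(res["result"])

This only streams if the underlying LLM was loaded with streaming enabled (as the `stream=True` in this commit suggests) and supports async token callbacks.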