Daniel Marques committed
Commit 8a26b55
1 Parent(s): 2453cc0

fix: add types

Files changed (1)
  1. main.py +27 -17
main.py CHANGED
@@ -14,7 +14,7 @@ from langchain.chains import RetrievalQA
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferMemory
-from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
 
 # from langchain.embeddings import HuggingFaceEmbeddings
@@ -31,7 +31,7 @@ class Predict(BaseModel):
 class Delete(BaseModel):
     filename: str
 
-class MyCustomAsyncHandler(AsyncCallbackHandler):
+class MyCustomHandler(BaseCallbackHandler):
     def on_llm_new_token(self, token: str, **kwargs) -> None:
         print(f" token: {token}")
 
@@ -44,6 +44,19 @@ class MyCustomAsyncHandler(AsyncCallbackHandler):
     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         print("finish")
 
+class CustomHandler(BaseCallbackHandler):
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        print(f" CustomHandler: {token}")
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        class_name = serialized["name"]
+        print("CustomHandler start")
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        print("CustomHandler finish")
+
 # if torch.backends.mps.is_available():
 #     DEVICE_TYPE = "mps"
 # elif torch.cuda.is_available():
@@ -65,7 +78,7 @@ DB = Chroma(
 
 RETRIEVER = DB.as_retriever()
 
-LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[MyCustomAsyncHandler()])
+LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[MyCustomHandler()])
 
 template = """you are a helpful, respectful and honest assistant. When answering questions, you should only use the documents provided.
 You should only answer the topics that appear in these documents.
@@ -87,7 +100,8 @@ QA = RetrievalQA.from_chain_type(
     return_source_documents=SHOW_SOURCES,
     chain_type_kwargs={
         "prompt": QA_CHAIN_PROMPT,
-        "memory": memory
+        "memory": memory,
+        "callbacks": [CustomHandler()]
     },
 )
 
@@ -179,8 +193,6 @@ async def predict(data: Predict):
     if user_prompt:
         res = QA(user_prompt)
 
-        print(res)
-
         answer, docs = res["result"], res["source_documents"]
 
         prompt_response_dict = {
@@ -194,17 +206,6 @@ async def predict(data: Predict):
                 (os.path.basename(str(document.metadata["source"])), str(document.page_content))
             )
 
-        qa_chain_response = res.stream(
-            {"query": user_prompt},
-        )
-
-        print(f"{qa_chain_response} stream")
-
-        # generated_text = ""
-        # for new_text in STREAMER:
-        #     generated_text += new_text
-        #     print(generated_text)
-
         return {"response": prompt_response_dict}
     else:
         raise HTTPException(status_code=400, detail="Prompt Incorrect")
@@ -254,4 +255,13 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket.accept()
     while True:
         data = await websocket.receive_text()
+
+        res = QA(data)
+
+        qa_chain_response = res.stream(
+            {"query": data},
+        )
+
+        print(f"{qa_chain_response} stream")
+
         await websocket.send_text(f"Message text was: {data}")
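
A note on the handler change above: in LangChain of this vintage, BaseCallbackHandler is the synchronous base class, so its hooks are expected to be plain def methods; async def hooks belong on AsyncCallbackHandler, the import this commit removes. The committed CustomHandler mixes async hooks onto the sync base. A minimal all-sync sketch for comparison (TokenPrinter is an illustrative name, not part of the commit):

    from typing import Any, Dict, List

    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.schema import LLMResult


    class TokenPrinter(BaseCallbackHandler):
        """Illustrative sync handler: every hook is a plain def."""

        def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
        ) -> None:
            print("generation started")

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            # Fires once per generated token when the model streams.
            print(token, end="", flush=True)

        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            print("\ngeneration finished")

Such a handler can be attached at model load, as the commit does with load_model(..., callbacks=[MyCustomHandler()]), or per call, assuming the chain call accepts a callbacks argument: res = QA(user_prompt, callbacks=[TokenPrinter()]).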
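
Separately, the new websocket body is likely to fail at runtime: QA(data) returns a result dict, and a dict has no .stream method, so res.stream({"query": data}) would raise AttributeError. A sketch of one way to push tokens over the socket instead, assuming the chain accepts per-call callbacks as above; QueueHandler and the queue wiring are illustrative, not part of the commit:

    import asyncio
    from typing import Any

    from fastapi import FastAPI, WebSocket
    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.schema import LLMResult

    app = FastAPI()
    # QA: the RetrievalQA chain constructed in main.py (assumed in scope).


    class QueueHandler(BaseCallbackHandler):
        """Illustrative handler: forwards tokens from the chain's worker thread."""

        def __init__(self, queue: "asyncio.Queue[str | None]", loop: asyncio.AbstractEventLoop) -> None:
            self.queue = queue
            self.loop = loop

        def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            # Runs on the executor thread; hop back to the event loop safely.
            self.loop.call_soon_threadsafe(self.queue.put_nowait, token)

        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            self.loop.call_soon_threadsafe(self.queue.put_nowait, None)  # end sentinel


    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        await websocket.accept()
        while True:
            data = await websocket.receive_text()
            loop = asyncio.get_running_loop()
            queue: "asyncio.Queue[str | None]" = asyncio.Queue()
            handler = QueueHandler(queue, loop)
            # Run the blocking chain off the event loop; error handling omitted.
            task = loop.run_in_executor(None, lambda: QA(data, callbacks=[handler]))
            while True:
                token = await queue.get()
                if token is None:
                    break
                await websocket.send_text(token)
            res = await task  # full result dict, e.g. res["result"]
            await websocket.send_text(f"Message text was: {data}")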