ThisIs-Developer committed
Commit da12358
Parent(s): 0dcdc5b

Upload model.py

Files changed (1)
  1. Chainlit/model.py +98 -0
Chainlit/model.py ADDED
@@ -0,0 +1,98 @@
+ import asyncio
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain.prompts import PromptTemplate
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.llms import CTransformers
+ from langchain.chains import RetrievalQA
+ import chainlit as cl
+
+ # Path to the FAISS index built by the ingestion step
+ DB_FAISS_PATH = 'vectorstores/db_faiss'
+
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+ Context: {context}
+ Question: {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer:
+ """
+
+ def set_custom_prompt():
+     """
+     Prompt template for QA retrieval for each vectorstore
+     """
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=['context', 'question'])
+     return prompt
+
+ # Retrieval QA Chain
+ def retrieval_qa_chain(llm, prompt, db):
+     qa_chain = RetrievalQA.from_chain_type(llm=llm,
+                                            chain_type='stuff',
+                                            retriever=db.as_retriever(search_kwargs={'k': 2}),
+                                            return_source_documents=True,
+                                            chain_type_kwargs={'prompt': prompt}
+                                            )
+     return qa_chain
+
+ # Loading the model
+ def load_llm():
+     # Load the locally downloaded model here
+     llm = CTransformers(
+         model="TheBloke/Llama-2-7B-Chat-GGML",
+         model_type="llama",
+         max_new_tokens=512,
+         temperature=0.5
+     )
+     return llm
+
+ # QA Model Function
+ async def qa_bot():
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                        model_kwargs={'device': 'cpu'})
+     db = FAISS.load_local(DB_FAISS_PATH, embeddings)
+     llm = load_llm()
+     qa_prompt = set_custom_prompt()
+     qa = retrieval_qa_chain(llm, qa_prompt, db)
+
+     return qa
+
+ # Output function
+ async def final_result(query):
+     qa_result = await qa_bot()
+     # Use the chain's async entry point; calling the chain directly
+     # returns a plain dict, which cannot be awaited
+     response = await qa_result.acall({'query': query})
+     return response
+
+ # chainlit code
+ @cl.on_chat_start
+ async def start():
+     chain = await qa_bot()
+     # msg = cl.Message(content="Starting the bot...")
+     # await msg.send()
+     # msg.content = "Hi, Welcome to Medical Bot. What is your query?"
+     # await msg.update()
+
+     cl.user_session.set("chain", chain)
+
+ @cl.on_message
+ async def main(message):
+     chain = cl.user_session.get("chain")
+     cb = cl.AsyncLangchainCallbackHandler(
+         stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
+     )
+     cb.answer_reached = True
+     res = await chain.acall(message.content, callbacks=[cb])
+     answer = res["result"]
+     sources = res["source_documents"]
+
+     if sources:
+         answer += "\nSources: " + str(sources)
+     else:
+         answer += "\nNo sources found"
+
+     await cl.Message(content=answer).send()
+
+ if __name__ == "__main__":
+     # The Chainlit UI is launched with `chainlit run model.py`; running this
+     # file directly performs a one-off command-line query as a smoke test.
+     print(asyncio.run(final_result("Sample query")))
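
model.py imports PyPDFLoader and DirectoryLoader without using them; they belong to the ingestion step that builds the FAISS index expected at vectorstores/db_faiss. That script is not part of this commit, but a minimal sketch of it might look like the following, assuming the source PDFs sit in a hypothetical data/ directory and reusing the same MiniLM embedding model that model.py loads at query time:

from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

DATA_PATH = 'data/'  # assumed location of the source PDFs
DB_FAISS_PATH = 'vectorstores/db_faiss'

def create_vector_db():
    # Load every PDF found in the data directory
    loader = DirectoryLoader(DATA_PATH, glob='*.pdf', loader_cls=PyPDFLoader)
    documents = loader.load()
    # Split into overlapping chunks sized for retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    # Embed with the same model the bot uses when answering
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.from_documents(texts, embeddings)
    db.save_local(DB_FAISS_PATH)

if __name__ == "__main__":
    create_vector_db()

Once the index exists, the bot itself is started with chainlit run model.py -w (the -w flag enables auto-reload while editing).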