sonali-tamhankar committed on
Commit
20e1dc2
1 Parent(s): 1128416

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
2
+ from langchain import PromptTemplate
3
+ from langchain import HuggingFaceHub
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+ import chainlit as cl
9
+
10
# Location of the serialized FAISS index (loaded by qa_bot).
DB_FAISS_PATH = '.'

# Prompt for the RetrievalQA chain: the retriever fills {context},
# the user's message fills {question}.
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
21
+
22
def set_custom_prompt():
    """Build the prompt template for QA retrieval.

    Returns:
        PromptTemplate: wired to the 'context' and 'question' input
        variables that the RetrievalQA 'stuff' chain expects.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
29
+
30
# Retrieval QA Chain
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a RetrievalQA chain over the given vector store.

    Args:
        llm: language model used to generate the answer.
        prompt: PromptTemplate injected into the 'stuff' chain.
        db: vector store exposing ``as_retriever``.

    Returns:
        RetrievalQA chain that also returns its source documents.
    """
    # k=2: answer from the two nearest chunks only.
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
39
+
40
# Loading the model
def load_llm():
    """Create the LLM used for answering.

    NOTE(review): despite the original "locally downloaded model" comment,
    this calls the remote HuggingFace Hub inference API (it requires a
    HUGGINGFACEHUB_API_TOKEN in the environment) — confirm which is intended.

    Returns:
        HuggingFaceHub: Llama-2-7b-chat endpoint with temperature 0.5.
    """
    return HuggingFaceHub(
        repo_id="meta-llama/Llama-2-7b-chat-hf",
        # Add "max_length" here if responses come back truncated.
        model_kwargs={"temperature": 0.5},
    )
45
+
46
# QA Model Function
def qa_bot():
    """Wire embeddings, vector store, LLM and prompt into one QA chain.

    Loads the FAISS index from DB_FAISS_PATH using MiniLM sentence
    embeddings on CPU, then returns the assembled RetrievalQA chain.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), db)
56
+
57
# output function
def final_result(query):
    """Run a single query through a freshly built QA chain.

    Args:
        query (str): the user's question.

    Returns:
        dict: chain output containing 'result' and 'source_documents'.
    """
    bot = qa_bot()
    return bot({'query': query})
62
+
63
# chainlit code
@cl.on_chat_start
async def start():
    """Build the QA chain once per session and greet the user."""
    chain = qa_bot()
    cl.user_session.set("chain", chain)

    msg = cl.Message(content="Starting the bot...")
    await msg.send()
    msg.content = "Hi, Welcome to Medical Bot. What is your query?"
    await msg.update()
73
+
74
@cl.on_message
async def main(message):
    """Answer one user message with the session's QA chain.

    Streams the final answer via a LangChain callback handler and
    appends the retrieved source documents to the reply.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    # NOTE(review): recent chainlit versions pass a cl.Message object here,
    # in which case the chain should receive message.content — confirm the
    # installed chainlit version before changing.
    res = await chain.acall(message, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]

    if sources:
        # str(sources) dumps the raw Document reprs; acceptable for this bot.
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()