moriire commited on
Commit
40c5643
·
verified ·
1 Parent(s): 62785d4

Create vmodel

Browse files
Files changed (1) hide show
  1. app/vmodel +94 -0
app/vmodel ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
2
+ from langchain.prompts import PromptTemplate
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_community.llms import CTransformers
6
+ from langchain.chains import RetrievalQA
7
+ import chainlit as cl
8
+
9
+ DB_FAISS_PATH = 'vectorstore/db_faiss'
10
+
11
+ custom_prompt_template = """Use the following pieces of information to answer the user's question.
12
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
13
+
14
+ Context: {context}
15
+ Question: {question}
16
+
17
+ Only return the helpful answer below and nothing else.
18
+ Helpful answer:
19
+ """
20
+
21
+ def set_custom_prompt():
22
+ """
23
+ Prompt template for QA retrieval for each vectorstore
24
+ """
25
+ prompt = PromptTemplate(template=custom_prompt_template,
26
+ input_variables=['context', 'question'])
27
+ return prompt
28
+
29
+ #Retrieval QA Chain
30
+ def retrieval_qa_chain(llm, prompt, db):
31
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
32
+ chain_type='stuff',
33
+ retriever=db.as_retriever(search_kwargs={'k': 2}),
34
+ return_source_documents=True,
35
+ chain_type_kwargs={'prompt': prompt}
36
+ )
37
+ return qa_chain
38
+
39
+ #Loading the model
40
+ def load_llm():
41
+ # Load the locally downloaded model here
42
+ llm = CTransformers(
43
+ model = "TheBloke/Llama-2-7B-Chat-GGML",
44
+ model_type="llama",
45
+ max_new_tokens = 512,
46
+ temperature = 0.5
47
+ )
48
+ return llm
49
+
50
+ #QA Model Function
51
+ def qa_bot():
52
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
53
+ model_kwargs={'device': 'cpu'})
54
+ db = FAISS.load_local(DB_FAISS_PATH, embeddings)
55
+ llm = load_llm()
56
+ qa_prompt = set_custom_prompt()
57
+ qa = retrieval_qa_chain(llm, qa_prompt, db)
58
+
59
+ return qa
60
+
61
+ #output function
62
+ def final_result(query):
63
+ qa_result = qa_bot()
64
+ response = qa_result({'query': query})
65
+ return response
66
+
67
+ #chainlit code
68
+ @cl.on_chat_start
69
+ async def start():
70
+ chain = qa_bot()
71
+ msg = cl.Message(content="Starting the bot...")
72
+ await msg.send()
73
+ msg.content = "Hi, Welcome to Medical Bot. What is your query?"
74
+ await msg.update()
75
+
76
+ cl.user_session.set("chain", chain)
77
+
78
+ @cl.on_message
79
+ async def main(message: cl.Message):
80
+ chain = cl.user_session.get("chain")
81
+ cb = cl.AsyncLangchainCallbackHandler(
82
+ stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
83
+ )
84
+ cb.answer_reached = True
85
+ res = await chain.acall(message.content, callbacks=[cb])
86
+ answer = res["result"]
87
+ sources = res["source_documents"]
88
+
89
+ if sources:
90
+ answer += f"\nSources:" + str(sources)
91
+ else:
92
+ answer += "\nNo sources found"
93
+
94
+ await cl.Message(content=answer).send()