Sakil committed on
Commit
5dd447e
1 Parent(s): 477e44b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain import PromptTemplate
4
+ from langchain.embeddings import HuggingFaceEmbeddings
5
+ from langchain.vectorstores import FAISS
6
+ from langchain.llms import CTransformers
7
+ from langchain.chains import RetrievalQA
8
+ import chainlit as cl
9
+
10
# Path to the pre-built FAISS index on disk; qa_bot() loads it read-only.
# NOTE(review): the index is assumed to have been created with the same
# embedding model used below — confirm against the ingestion script.
DB_FAISS_PATH = 'vectorstore/db_faiss'

# Prompt sent to the LLM by the RetrievalQA chain. The chain fills in
# {context} with the retrieved documents and {question} with the user query
# (see set_custom_prompt()'s input_variables).
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
19
+
20
def set_custom_prompt():
    """Build the PromptTemplate used by the retrieval QA chain.

    The template expects two variables: the retrieved ``context`` and the
    user's ``question``.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
27
+
28
# Retrieval QA Chain
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a RetrievalQA chain over the given vector store.

    Uses the 'stuff' chain type, retrieves the top 2 documents per query,
    and returns the source documents alongside the generated answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
37
+
38
# Loading the model
def load_llm():
    """Load the locally downloaded quantized Llama-2 chat model.

    Uses CTransformers to run the GGML model on CPU; generation is capped
    at 512 new tokens with temperature 0.5.
    """
    model_config = {
        'model': "llama-2-7b-chat.ggmlv3.q8_0.bin",
        'model_type': "llama",
        'max_new_tokens': 512,
        'temperature': 0.5,
    }
    return CTransformers(**model_config)
48
+
49
# QA Model Function
def qa_bot():
    """Wire embeddings, FAISS store, LLM and prompt into one QA chain.

    Loads the sentence-transformers embedding model on CPU, opens the
    FAISS index at DB_FAISS_PATH, and returns the assembled RetrievalQA
    chain ready to answer queries.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    vector_store = FAISS.load_local(DB_FAISS_PATH, embeddings)
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
59
+
60
def main():
    """Streamlit entry point: render the chat UI and answer questions.

    Fix vs. original: clicking "Ask" with an empty/whitespace-only input
    no longer sends an empty query to the LLM chain; the user is warned
    instead.
    """
    st.title("AI ChatBot LLM")

    # NOTE(review): this rebuilds the embeddings/FAISS/LLM stack on every
    # Streamlit rerun, which is expensive — consider st.cache_resource once
    # the deployed Streamlit version is confirmed to support it.
    qa_result = qa_bot()

    user_input = st.text_input("Enter your question:")

    if st.button("Ask"):
        # Guard against empty queries before invoking the chain.
        if not user_input.strip():
            st.warning("Please enter a question.")
            return

        response = qa_result({'query': user_input})
        answer = response["result"]
        sources = response["source_documents"]

        st.write("Answer:", answer)
        if sources:
            st.write("Sources:", sources)
        else:
            st.write("No sources found")
77
+
78
# Script entry point: launch the Streamlit app when executed directly
# (e.g. via ``streamlit run app.py``).
if __name__ == "__main__":
    main()