SiraH committed
Commit c1cc993
1 Parent(s): 4d76407

Upload app.py

Files changed (1): app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
+ import os
+ import streamlit as st
+ import re
+ from tempfile import NamedTemporaryFile
+ import time
+ import pathlib
+ #from PyPDF2 import PdfReader
+
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain_community.llms import LlamaCpp
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import LLMChain
+ from langchain.callbacks.manager import CallbackManager
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import RetrievalQA
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
+ from langchain_community.document_loaders import TextLoader
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.memory import ConversationBufferWindowMemory
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory.chat_message_histories.streamlit import StreamlitChatMessageHistory
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.llms import HuggingFaceHub
+
+
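+ # App flow: describe the demo in the sidebar, accept a PDF upload, index its
+ # pages into a FAISS vector store, and answer questions via a RetrievalQA chain.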
+ # sidebar contents
+ with st.sidebar:
+     st.title('DOC-QA DEMO')
+     st.markdown('''
+     ## About
+     Details of this application:
+     - LLM model: Phi-2 (4-bit quantized)
+     - Hardware resources: Hugging Face Space, 8 vCPU, 32 GB RAM
+     ''')
+
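+ # Split loaded pages into ~1,000-character chunks with a 200-character overlap
+ # so retrieved passages keep some surrounding context.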
+ def split_docs(documents, chunk_size=1000):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=200)
+     sp_docs = text_splitter.split_documents(documents)
+     return sp_docs
+
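+ # Load the 4-bit Phi-2 GGUF model through llama.cpp. @st.cache_resource keeps a
+ # single instance alive across Streamlit reruns; the .gguf file is expected to
+ # sit in the working directory next to app.py.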
+ @st.cache_resource
+ def load_llama2_llamaCpp():
+     core_model_name = "phi-2.Q4_K_M.gguf"
+     #n_gpu_layers = 32
+     n_batch = 512
+     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+     llm = LlamaCpp(
+         model_path=core_model_name,
+         #n_gpu_layers=n_gpu_layers,
+         n_batch=n_batch,
+         callback_manager=callback_manager,
+         verbose=True,
+         n_ctx=4096,        # context window size
+         temperature=0.1,   # keep answers near-deterministic
+         max_tokens=128,    # cap on tokens generated per answer
+     )
+     return llm
+
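+ # Prompt for the "stuff" QA chain: {context} receives the retrieved chunks and
+ # {question} the user's query.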
+ def set_custom_prompt():
+     custom_prompt_template = """Use the following pieces of information from the context to answer the user's question.
+     If you don't know the answer, don't try to make up an answer.
+     Context : {context}
+     Question : {question}
+     Please answer the question in a concise and straightforward manner.
+     Helpful answer:
+     """
+     prompt = PromptTemplate(template=custom_prompt_template,
+                             input_variables=['context', 'question'])
+     return prompt
+
+
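+ # CPU-only sentence embeddings (thenlper/gte-base), likewise cached across reruns.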
+ @st.cache_resource
+ def load_embeddings():
+     embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-base",
+                                        model_kwargs={'device': 'cpu'})
+     return embeddings
+
+
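+ # Main page: upload a PDF, build the vector index, then chat over the document.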
+ def main():
+     data = []
+     sp_docs_list = []
+     msgs = StreamlitChatMessageHistory(key="langchain_messages")
+     print(msgs)
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+
+     # repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
+     # llm = HuggingFaceHub(
+     #     repo_id=repo_id, model_kwargs={"temperature": 0.1, "max_length": 128})
+
+     llm = load_llama2_llamaCpp()
+     qa_prompt = set_custom_prompt()
+     embeddings = load_embeddings()
+
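+     # Ingest: persist the upload to a temporary .pdf on disk so that PyPDFLoader
+     # can open it, then load the document and split it into pages.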
+     uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
+     if uploaded_file is not None:
+         os.makedirs('PDF', exist_ok=True)  # the target dir must exist before NamedTemporaryFile can use it
+         with NamedTemporaryFile(dir='PDF', suffix='.pdf', delete=False) as f:
+             f.write(uploaded_file.getbuffer())
+             print(f.name)
+             #filename = f.name
+             loader = PyPDFLoader(f.name)
+             pages = loader.load_and_split()
+             data.extend(pages)
+             #st.write(pages)
+         os.unlink(f.name)  # the with-block has closed the file; delete it once parsed
+     if len(data) > 0:
+         sp_docs = split_docs(documents=data)
+         st.write(f"This document has {len(sp_docs)} chunks")
+         sp_docs_list.extend(sp_docs)
+
+     try:
+         db = FAISS.from_documents(sp_docs_list, embeddings)
+         memory = ConversationBufferMemory(memory_key="chat_history",
+                                           return_messages=True,
+                                           input_key="query",
+                                           output_key="result")
+         qa_chain = RetrievalQA.from_chain_type(
+             llm=llm,
+             chain_type="stuff",
+             retriever=db.as_retriever(search_kwargs={'k': 3}),
+             return_source_documents=True,
+             memory=memory,
+             chain_type_kwargs={"prompt": qa_prompt})
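+         # Replay stored messages so the conversation survives Streamlit reruns.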
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+         # Accept user input
+         if query := st.chat_input("What is up?"):
+             # Display user message in chat message container
+             with st.chat_message("user"):
+                 st.markdown(query)
+             # Add user message to chat history
+             st.session_state.messages.append({"role": "user", "content": query})
+
+             start = time.time()
+
+             response = qa_chain({'query': query})
+
+             with st.chat_message("assistant"):
+                 st.markdown(response['result'])
+
+             end = time.time()
+             st.write("Response time:", int(end - start), "sec")
+             print(response)
+
+             # Add assistant response to chat history
+             st.session_state.messages.append({"role": "assistant", "content": response['result']})
+
+             with st.expander("See the related documents"):
+                 for count, doc in enumerate(response['source_documents']):
+                     st.write(str(count + 1) + ":", doc)
+
+         clear_button = st.button("Start new convo")
+         if clear_button:
+             st.session_state.messages = []
+             qa_chain.memory.chat_memory.clear()
+
+     except Exception:
+         st.write("Please upload your PDF file.")
+
+
+ if __name__ == '__main__':
+     main()