Elia Wäfler committed on
Commit
2c73bfa
1 Parent(s): 951e0ac

renamed frontend to app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from dotenv import load_dotenv
3
+ from PyPDF2 import PdfReader
4
+ from langchain import embeddings
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.vectorstores import faiss
9
+ from langchain.chat_models import ChatOpenAI
10
+ from langchain.memory import ConversationBufferMemory
11
+ from langchain.chains import ConversationalRetrievalChain
12
+ from html_templates import css, bot_template, user_template
13
+ from langchain.llms import HuggingFaceHub
14
+ import os
15
+ import pickle
16
+ from datetime import datetime
17
+
18
+
19
def get_pdf_text(pdf_docs):
    """Concatenate the extracted text of every page of the uploaded PDFs.

    Args:
        pdf_docs: iterable of file-like objects readable by PyPDF2's PdfReader.

    Returns:
        str: all page texts joined into one string ("" when no docs/pages).
    """
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() can yield None/empty for pages without a text
            # layer; `or ""` prevents a TypeError on concatenation.
            text += page.extract_text() or ""
    return text
26
+
27
+
28
def get_text_chunks(text):
    """Split *text* into overlapping chunks suitable for embedding.

    Chunks are newline-separated, at most 1000 characters long with a
    200-character overlap, measured by len().
    """
    splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
37
+
38
+
39
def get_vectorstore(text_chunks):
    """Build a FAISS vector store over *text_chunks* using OpenAI embeddings."""
    # Renamed local so it no longer shadows the module-level `embeddings` import.
    embedder = OpenAIEmbeddings()
    # embedder = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
    return FAISS.from_texts(texts=text_chunks, embedding=embedder)
44
+
45
+
46
def get_conversation_chain(vectorstore):
    """Create a conversational retrieval chain backed by *vectorstore*.

    Chat history is kept in a ConversationBufferMemory under the
    'chat_history' key so follow-up questions carry context.
    """
    chat_llm = ChatOpenAI()
    # chat_llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature":0.5, "max_length":512})
    buffer = ConversationBufferMemory(
        memory_key='chat_history', return_messages=True)
    chain = ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=vectorstore.as_retriever(),
        memory=buffer,
    )
    return chain
58
+
59
+
60
def handle_userinput(user_question):
    """Send *user_question* to the conversation chain and render the chat.

    Messages alternate user/assistant: even indices are the user's turns,
    odd indices the bot's. Side effects only — writes HTML to the Streamlit
    page and updates st.session_state.chat_history; returns None.
    """
    response = st.session_state.conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']

    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            # User turn.
            st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            # Bot turn. (Removed a leftover debug `print(message)` that
            # spammed the server console on every rerun.)
            st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
            # Show source-document info when the message carries one.
            if hasattr(message, 'source') and message.source:
                st.write(f"Source Document: {message.source}", unsafe_allow_html=True)
75
+
76
+
77
def safe_vec_store():
    """Pickle the session's FAISS store to vectorstore/vectores<timestamp>.pkl.

    NOTE(review): pickled files are only safe to reload when produced by
    this app — confirm the matching load path never reads untrusted files.
    """
    os.makedirs('vectorstore', exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d%H%M')
    target = os.path.join('vectorstore', 'vectores' + stamp + '.pkl')

    # Serialize the whole FAISS object in one shot.
    with open(target, 'wb') as handle:
        pickle.dump(st.session_state.vectorstore, handle)
86
+
87
+
88
+
89
def _ingest_documents(docs, filenames):
    """Embed *docs*, merge any uploaded/previous vector stores, and publish
    the result plus a fresh conversation chain into st.session_state."""
    loaded_vec_store = None
    for filename in filenames:
        if ".pkl" in filename:
            # A previously saved vector store was uploaded; unpickle it.
            # NOTE(review): pickle.load is unsafe on untrusted uploads —
            # only accept files this app itself produced.
            file_path = os.path.join('vectorstore', filename)
            with open(file_path, 'rb') as f:
                loaded_vec_store = pickle.load(f)
    raw_text = get_pdf_text(docs)
    text_chunks = get_text_chunks(raw_text)
    vec = get_vectorstore(text_chunks)
    if loaded_vec_store:
        vec.merge_from(loaded_vec_store)
        st.warning("loaded vectorstore")
    if "vectorstore" in st.session_state:
        vec.merge_from(st.session_state.vectorstore)
        st.warning("merged to existing")
    st.session_state.vectorstore = vec
    st.session_state.conversation = get_conversation_chain(vec)
    st.success("data loaded")


def main():
    """Streamlit entry point: upload PDFs, build a RAG index, and chat over it."""
    load_dotenv()
    st.set_page_config(page_title="Doc Verify RAG", page_icon=":hospital:")
    st.write(css, unsafe_allow_html=True)

    st.subheader("Your documents")
    pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
    filenames = [file.name for file in pdf_docs if file is not None]

    if st.button("Process"):
        with st.spinner("Processing"):
            _ingest_documents(pdf_docs, filenames)

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Doc Verify RAG :hospital:")
    user_question = st.text_input("Ask a question about your documents:")
    if user_question:
        handle_userinput(user_question)

    with st.sidebar:
        # Fixed UI-label typo: "Instrucitons" -> "Instructions".
        st.subheader("Classification Instructions")
        classifier_docs = st.file_uploader("Upload your instructions here and click on 'Process'", accept_multiple_files=True)
        classifier_filenames = [file.name for file in classifier_docs if file is not None]

        if st.button("Process Classification"):
            with st.spinner("Processing"):
                # Bug fix: the original embedded `pdf_docs` here, silently
                # ignoring the classification uploads — use classifier_docs.
                _ingest_documents(classifier_docs, classifier_filenames)

        # Save and Load Embeddings
        if st.button("Save Embeddings"):
            if "vectorstore" in st.session_state:
                safe_vec_store()
                # st.session_state.vectorstore.save_local("faiss_index")
                st.sidebar.success("safed")
            else:
                st.sidebar.warning("No embeddings to save. Please process documents first.")

        if st.button("Load Embeddings"):
            st.warning("this function is not in use, just upload the vectorstore")
167
+
168
+
169
# Script entry point: launch the Streamlit UI when run directly
# (e.g. `streamlit run app.py`).
if __name__ == '__main__':
    main()