Update app.py
Browse files
app.py
CHANGED
@@ -1,109 +1,95 @@
|
|
1 |
|
2 |
|
3 |
-
import streamlit as st
|
4 |
import os
|
5 |
-
import
|
6 |
-
import shutil
|
7 |
-
from langchain.text_splitter import TokenTextSplitter,RecursiveCharacterTextSplitter,CharacterTextSplitter
|
8 |
-
from langchain.document_loaders import PyPDFLoader
|
9 |
-
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
|
10 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
11 |
from transformers import pipeline
|
12 |
-
import
|
13 |
-
from langchain.
|
14 |
-
from langchain.
|
15 |
-
from langchain.document_loaders import TextLoader
|
16 |
-
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
17 |
-
from langchain.memory import ConversationBufferMemory
|
18 |
-
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
|
19 |
-
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
20 |
-
from langchain.chains.combine_documents import create_stuff_documents_chain
|
21 |
-
from langchain_core.runnables.history import RunnableWithMessageHistory
|
22 |
-
from langchain_core.chat_history import BaseChatMessageHistory
|
23 |
-
from langchain_community.chat_message_histories import ChatMessageHistory
|
24 |
-
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
25 |
-
from langchain_community.llms import Aphrodite
|
26 |
-
from typing import Callable, Dict, List, Optional, Union
|
27 |
from langchain.vectorstores import Chroma
|
28 |
-
|
29 |
-
from langchain_community.llms import llamacpp
|
30 |
-
from utills import split_docs, retriever_from_chroma, history_aware_retriever,chroma_db
|
31 |
-
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory
|
32 |
-
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
|
|
|
41 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
42 |
data_path = "./data/"
|
43 |
model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')
|
44 |
store = {}
|
45 |
|
|
|
46 |
model_name = "sentence-transformers/all-mpnet-base-v2"
|
47 |
model_kwargs = {'device': 'cpu'}
|
48 |
encode_kwargs = {'normalize_embeddings': True}
|
49 |
-
hf = HuggingFaceEmbeddings(
|
50 |
-
model_name=model_name,
|
51 |
-
model_kwargs=model_kwargs,
|
52 |
-
encode_kwargs=encode_kwargs)
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
|
62 |
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
documents.extend(documents)
|
68 |
|
|
|
|
|
|
|
69 |
|
70 |
docs = split_docs(documents, 450, 20)
|
71 |
-
chroma_db = chroma_db(docs,hf)
|
72 |
-
retriever = retriever_from_chroma(chroma_db, "mmr", 6)
|
73 |
|
|
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
model_kwargs = {'device': 'cpu'}
|
77 |
-
encode_kwargs = {'normalize_embeddings': True}
|
78 |
-
hf = HuggingFaceEmbeddings(
|
79 |
-
model_name=model_name,
|
80 |
-
model_kwargs=model_kwargs,
|
81 |
-
encode_kwargs=encode_kwargs
|
82 |
-
)
|
83 |
|
|
|
|
|
|
|
84 |
|
85 |
-
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
n_gpu_layers=0,
|
90 |
-
temperature=0.0,
|
91 |
-
top_p=0.5,
|
92 |
-
n_ctx=7000,
|
93 |
-
max_tokens=350,
|
94 |
-
repeat_penalty=1.7,
|
95 |
-
stop=["", "Instruction:", "### Instruction:", "###<user>", "</user>"],
|
96 |
-
callback_manager=callback_manager,
|
97 |
-
verbose=False,
|
98 |
-
)
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
contextualize_q_system_prompt = """Given a context, chat history and the latest user question
|
102 |
which maybe reference context in the chat history, formulate a standalone question
|
103 |
which can be understood without the chat history. Do NOT answer the question,
|
104 |
just reformulate it if needed and otherwise return it as is."""
|
105 |
|
106 |
-
|
|
|
|
|
|
|
|
|
107 |
|
108 |
qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
|
109 |
|
@@ -115,22 +101,30 @@ qa_prompt = ChatPromptTemplate.from_messages(
|
|
115 |
]
|
116 |
)
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
conversational_rag_chain = RunnableWithMessageHistory(
|
123 |
-
rag_chain,
|
124 |
-
lambda session_id: msgs,
|
125 |
-
input_messages_key="input",
|
126 |
-
history_messages_key="chat_history",
|
127 |
-
output_messages_key="answer",
|
128 |
-
)
|
129 |
|
|
|
130 |
|
|
|
|
|
|
|
131 |
|
|
|
|
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
|
|
134 |
|
135 |
def display_chat_history(chat_history):
|
136 |
"""Displays the chat history in Streamlit."""
|
@@ -139,45 +133,38 @@ def display_chat_history(chat_history):
|
|
139 |
|
140 |
def display_documents(docs, on_click=None):
|
141 |
"""Displays retrieved documents with optional click action."""
|
142 |
-
if docs:
|
143 |
-
for i, document in enumerate(docs):
|
144 |
st.write(f"**Docs {i+1}**")
|
145 |
-
st.markdown(document, unsafe_allow_html=True)
|
146 |
if on_click:
|
147 |
if st.button(f"Expand Article {i+1}"):
|
148 |
-
on_click(i)
|
149 |
|
150 |
def main(conversational_rag_chain):
|
151 |
"""Main function for the Streamlit app."""
|
152 |
-
# Initialize chat history if not already present in session state
|
153 |
msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
|
154 |
chain_with_history = conversational_rag_chain
|
155 |
|
156 |
st.title("Conversational RAG Chatbot")
|
157 |
|
158 |
-
# Display chat history
|
159 |
display_chat_history(msgs)
|
160 |
|
161 |
if prompt := st.chat_input():
|
162 |
st.chat_message("human").write(prompt)
|
163 |
|
164 |
-
# Prepare the input dictionary with the correct keys
|
165 |
input_dict = {"input": prompt, "chat_history": msgs.messages}
|
166 |
config = {"configurable": {"session_id": "any"}}
|
167 |
|
168 |
-
# Process user input and handle response
|
169 |
response = chain_with_history.invoke(input_dict, config)
|
170 |
st.chat_message("ai").write(response["answer"])
|
171 |
|
172 |
-
# Display retrieved documents (if any and present in response)
|
173 |
if "docs" in response and response["documents"]:
|
174 |
docs = response["documents"]
|
175 |
def expand_document(index):
|
176 |
-
# Implement your document expansion logic here (e.g., show extra details)
|
177 |
st.write(f"Expanding document {index+1}...")
|
178 |
-
display_documents(docs, expand_document)
|
179 |
|
180 |
-
# Update chat history in session state
|
181 |
st.session_state["chat_history"] = msgs
|
182 |
|
183 |
if __name__ == "__main__":
|
|
|
1 |
|
2 |
|
|
|
3 |
import os
|
4 |
+
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
5 |
from transformers import pipeline
|
6 |
+
from langchain import HuggingFaceEmbeddings, CallbackManager, LlamaCpp, TextLoader, create_stuff_documents_chain, create_retrieval_chain, RunnableWithMessageHistory, ChatPromptTemplate, MessagesPlaceholder, StreamlitChatMessageHistory
|
7 |
+
from langchain.prompts import PromptTemplate
|
8 |
+
from langchain.chains.question_answering import load_qa_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from langchain.vectorstores import Chroma
|
10 |
+
from langchain.retrievers import mmr_retriever
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
# Initialize variables and paths
|
13 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
14 |
data_path = "./data/"
|
15 |
model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')
|
16 |
store = {}
|
17 |
|
18 |
+
# Set up HuggingFace embeddings
|
19 |
model_name = "sentence-transformers/all-mpnet-base-v2"
|
20 |
model_kwargs = {'device': 'cpu'}
|
21 |
encode_kwargs = {'normalize_embeddings': True}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Use Streamlit's cache to avoid recomputation
|
24 |
+
@st.cache_resource
|
25 |
+
def load_embeddings():
|
26 |
+
return HuggingFaceEmbeddings(
|
27 |
+
model_name=model_name,
|
28 |
+
model_kwargs=model_kwargs,
|
29 |
+
encode_kwargs=encode_kwargs
|
30 |
+
)
|
31 |
|
32 |
+
hf = load_embeddings()
|
33 |
|
34 |
+
@st.cache_data
|
35 |
+
def load_documents(data_path):
|
36 |
+
documents = []
|
37 |
+
for filename in os.listdir(data_path):
|
38 |
+
if filename.endswith('.txt'):
|
39 |
+
file_path = os.path.join(data_path, filename)
|
40 |
+
documents.extend(TextLoader(file_path).load())
|
41 |
+
return documents
|
42 |
|
43 |
+
documents = load_documents(data_path)
|
|
|
|
|
44 |
|
45 |
+
def split_docs(documents, chunk_size, overlap):
|
46 |
+
# Your implementation here
|
47 |
+
pass
|
48 |
|
49 |
docs = split_docs(documents, 450, 20)
|
|
|
|
|
50 |
|
51 |
+
@st.cache_resource
|
52 |
+
def create_chroma_db(docs, hf):
|
53 |
+
return Chroma(docs, hf)
|
54 |
|
55 |
+
chroma_db = create_chroma_db(docs, hf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
+
@st.cache_resource
|
58 |
+
def create_retriever(chroma_db):
|
59 |
+
return mmr_retriever(chroma_db, "mmr", 6)
|
60 |
|
61 |
+
retriever = create_retriever(chroma_db)
|
62 |
|
63 |
+
# Set up LlamaCpp model
|
64 |
+
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
+
@st.cache_resource
|
67 |
+
def load_llm():
|
68 |
+
return LlamaCpp(
|
69 |
+
model_path='qwen2-0_5b-instruct-q4_0.gguf',
|
70 |
+
n_gpu_layers=0,
|
71 |
+
temperature=0.0,
|
72 |
+
top_p=0.5,
|
73 |
+
n_ctx=7000,
|
74 |
+
max_tokens=350,
|
75 |
+
repeat_penalty=1.7,
|
76 |
+
stop=["", "Instruction:", "### Instruction:", "###<user>", "</user>"],
|
77 |
+
callback_manager=callback_manager,
|
78 |
+
verbose=False,
|
79 |
+
)
|
80 |
+
|
81 |
+
llm = load_llm()
|
82 |
|
83 |
contextualize_q_system_prompt = """Given a context, chat history and the latest user question
|
84 |
which maybe reference context in the chat history, formulate a standalone question
|
85 |
which can be understood without the chat history. Do NOT answer the question,
|
86 |
just reformulate it if needed and otherwise return it as is."""
|
87 |
|
88 |
+
@st.cache_resource
|
89 |
+
def create_history_aware_retriever():
|
90 |
+
return history_aware_retriever(llm, retriever, contextualize_q_system_prompt)
|
91 |
+
|
92 |
+
ha_retriever = create_history_aware_retriever()
|
93 |
|
94 |
qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
|
95 |
|
|
|
101 |
]
|
102 |
)
|
103 |
|
104 |
+
@st.cache_resource
|
105 |
+
def create_question_answer_chain():
|
106 |
+
return create_stuff_documents_chain(llm, qa_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
question_answer_chain = create_question_answer_chain()
|
109 |
|
110 |
+
@st.cache_resource
|
111 |
+
def create_rag_chain():
|
112 |
+
return create_retrieval_chain(ha_retriever, question_answer_chain)
|
113 |
|
114 |
+
rag_chain = create_rag_chain()
|
115 |
+
msgs = StreamlitChatMessageHistory(key="special_app_key")
|
116 |
|
117 |
+
@st.cache_resource
|
118 |
+
def create_conversational_rag_chain():
|
119 |
+
return RunnableWithMessageHistory(
|
120 |
+
rag_chain,
|
121 |
+
lambda session_id: msgs,
|
122 |
+
input_messages_key="input",
|
123 |
+
history_messages_key="chat_history",
|
124 |
+
output_messages_key="answer",
|
125 |
+
)
|
126 |
|
127 |
+
conversational_rag_chain = create_conversational_rag_chain()
|
128 |
|
129 |
def display_chat_history(chat_history):
|
130 |
"""Displays the chat history in Streamlit."""
|
|
|
133 |
|
134 |
def display_documents(docs, on_click=None):
|
135 |
"""Displays retrieved documents with optional click action."""
|
136 |
+
if docs:
|
137 |
+
for i, document in enumerate(docs):
|
138 |
st.write(f"**Docs {i+1}**")
|
139 |
+
st.markdown(document, unsafe_allow_html=True)
|
140 |
if on_click:
|
141 |
if st.button(f"Expand Article {i+1}"):
|
142 |
+
on_click(i)
|
143 |
|
144 |
def main(conversational_rag_chain):
|
145 |
"""Main function for the Streamlit app."""
|
|
|
146 |
msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
|
147 |
chain_with_history = conversational_rag_chain
|
148 |
|
149 |
st.title("Conversational RAG Chatbot")
|
150 |
|
|
|
151 |
display_chat_history(msgs)
|
152 |
|
153 |
if prompt := st.chat_input():
|
154 |
st.chat_message("human").write(prompt)
|
155 |
|
|
|
156 |
input_dict = {"input": prompt, "chat_history": msgs.messages}
|
157 |
config = {"configurable": {"session_id": "any"}}
|
158 |
|
|
|
159 |
response = chain_with_history.invoke(input_dict, config)
|
160 |
st.chat_message("ai").write(response["answer"])
|
161 |
|
|
|
162 |
if "docs" in response and response["documents"]:
|
163 |
docs = response["documents"]
|
164 |
def expand_document(index):
|
|
|
165 |
st.write(f"Expanding document {index+1}...")
|
166 |
+
display_documents(docs, expand_document)
|
167 |
|
|
|
168 |
st.session_state["chat_history"] = msgs
|
169 |
|
170 |
if __name__ == "__main__":
|