ArturG9 committed on
Commit
467c73a
1 Parent(s): 05f67cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -96
app.py CHANGED
@@ -1,109 +1,95 @@
1
 
2
 
3
- import streamlit as st
4
  import os
5
- import sys
6
- import shutil
7
- from langchain.text_splitter import TokenTextSplitter,RecursiveCharacterTextSplitter,CharacterTextSplitter
8
- from langchain.document_loaders import PyPDFLoader
9
- from langchain.document_loaders.pdf import PyPDFDirectoryLoader
10
- from langchain_community.embeddings import HuggingFaceEmbeddings
11
  from transformers import pipeline
12
- import torch
13
- from langchain.chains.query_constructor.base import AttributeInfo
14
- from langchain.vectorstores import DocArrayInMemorySearch
15
- from langchain.document_loaders import TextLoader
16
- from langchain.chains import RetrievalQA, ConversationalRetrievalChain
17
- from langchain.memory import ConversationBufferMemory
18
- from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
19
- from langchain.chains import create_history_aware_retriever, create_retrieval_chain
20
- from langchain.chains.combine_documents import create_stuff_documents_chain
21
- from langchain_core.runnables.history import RunnableWithMessageHistory
22
- from langchain_core.chat_history import BaseChatMessageHistory
23
- from langchain_community.chat_message_histories import ChatMessageHistory
24
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
25
- from langchain_community.llms import Aphrodite
26
- from typing import Callable, Dict, List, Optional, Union
27
  from langchain.vectorstores import Chroma
28
- import streamlit as st
29
- from langchain_community.llms import llamacpp
30
- from utills import split_docs, retriever_from_chroma, history_aware_retriever,chroma_db
31
- from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory
32
- from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
33
-
34
-
35
-
36
-
37
-
38
-
39
-
40
 
 
41
  script_dir = os.path.dirname(os.path.abspath(__file__))
42
  data_path = "./data/"
43
  model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')
44
  store = {}
45
 
 
46
  model_name = "sentence-transformers/all-mpnet-base-v2"
47
  model_kwargs = {'device': 'cpu'}
48
  encode_kwargs = {'normalize_embeddings': True}
49
- hf = HuggingFaceEmbeddings(
50
- model_name=model_name,
51
- model_kwargs=model_kwargs,
52
- encode_kwargs=encode_kwargs)
53
-
54
-
55
-
56
 
57
- documents = []
58
-
59
- for filename in os.listdir(data_path):
 
 
 
 
 
60
 
61
- if filename.endswith('.txt'):
62
 
63
- file_path = os.path.join(data_path, filename)
 
 
 
 
 
 
 
64
 
65
- documents = TextLoader(file_path).load()
66
-
67
- documents.extend(documents)
68
 
 
 
 
69
 
70
  docs = split_docs(documents, 450, 20)
71
- chroma_db = chroma_db(docs,hf)
72
- retriever = retriever_from_chroma(chroma_db, "mmr", 6)
73
 
 
 
 
74
 
75
- model_name = "sentence-transformers/all-mpnet-base-v2"
76
- model_kwargs = {'device': 'cpu'}
77
- encode_kwargs = {'normalize_embeddings': True}
78
- hf = HuggingFaceEmbeddings(
79
- model_name=model_name,
80
- model_kwargs=model_kwargs,
81
- encode_kwargs=encode_kwargs
82
- )
83
 
 
 
 
84
 
85
- callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
86
 
87
- llm = llamacpp.LlamaCpp(
88
- model_path= 'qwen2-0_5b-instruct-q4_0.gguf',
89
- n_gpu_layers=0,
90
- temperature=0.0,
91
- top_p=0.5,
92
- n_ctx=7000,
93
- max_tokens=350,
94
- repeat_penalty=1.7,
95
- stop=["", "Instruction:", "### Instruction:", "###<user>", "</user>"],
96
- callback_manager=callback_manager,
97
- verbose=False,
98
- )
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  contextualize_q_system_prompt = """Given a context, chat history and the latest user question
102
  which maybe reference context in the chat history, formulate a standalone question
103
  which can be understood without the chat history. Do NOT answer the question,
104
  just reformulate it if needed and otherwise return it as is."""
105
 
106
- ha_retriever = history_aware_retriever(llm, retriever, contextualize_q_system_prompt)
 
 
 
 
107
 
108
  qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
109
 
@@ -115,22 +101,30 @@ qa_prompt = ChatPromptTemplate.from_messages(
115
  ]
116
  )
117
 
118
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
119
- rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)
120
- msgs = StreamlitChatMessageHistory(key="special_app_key")
121
-
122
- conversational_rag_chain = RunnableWithMessageHistory(
123
- rag_chain,
124
- lambda session_id: msgs,
125
- input_messages_key="input",
126
- history_messages_key="chat_history",
127
- output_messages_key="answer",
128
- )
129
 
 
130
 
 
 
 
131
 
 
 
132
 
 
 
 
 
 
 
 
 
 
133
 
 
134
 
135
  def display_chat_history(chat_history):
136
  """Displays the chat history in Streamlit."""
@@ -139,45 +133,38 @@ def display_chat_history(chat_history):
139
 
140
  def display_documents(docs, on_click=None):
141
  """Displays retrieved documents with optional click action."""
142
- if docs: # Check if documents exist before displaying
143
- for i, document in enumerate(docs): # Iterate over docs, not documents
144
  st.write(f"**Docs {i+1}**")
145
- st.markdown(document, unsafe_allow_html=True) # Allow HTML formatting
146
  if on_click:
147
  if st.button(f"Expand Article {i+1}"):
148
- on_click(i) # Call the user-defined click function
149
 
150
  def main(conversational_rag_chain):
151
  """Main function for the Streamlit app."""
152
- # Initialize chat history if not already present in session state
153
  msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
154
  chain_with_history = conversational_rag_chain
155
 
156
  st.title("Conversational RAG Chatbot")
157
 
158
- # Display chat history
159
  display_chat_history(msgs)
160
 
161
  if prompt := st.chat_input():
162
  st.chat_message("human").write(prompt)
163
 
164
- # Prepare the input dictionary with the correct keys
165
  input_dict = {"input": prompt, "chat_history": msgs.messages}
166
  config = {"configurable": {"session_id": "any"}}
167
 
168
- # Process user input and handle response
169
  response = chain_with_history.invoke(input_dict, config)
170
  st.chat_message("ai").write(response["answer"])
171
 
172
- # Display retrieved documents (if any and present in response)
173
  if "docs" in response and response["documents"]:
174
  docs = response["documents"]
175
  def expand_document(index):
176
- # Implement your document expansion logic here (e.g., show extra details)
177
  st.write(f"Expanding document {index+1}...")
178
- display_documents(docs, expand_document) # Pass click function
179
 
180
- # Update chat history in session state
181
  st.session_state["chat_history"] = msgs
182
 
183
  if __name__ == "__main__":
 
1
 
2
 
 
3
  import os
4
+ import streamlit as st
 
 
 
 
 
5
  from transformers import pipeline
6
+ from langchain import HuggingFaceEmbeddings, CallbackManager, LlamaCpp, TextLoader, create_stuff_documents_chain, create_retrieval_chain, RunnableWithMessageHistory, ChatPromptTemplate, MessagesPlaceholder, StreamlitChatMessageHistory
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.chains.question_answering import load_qa_chain
 
 
 
 
 
 
 
 
 
 
 
 
9
  from langchain.vectorstores import Chroma
10
+ from langchain.retrievers import mmr_retriever
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Initialize variables and paths
13
  script_dir = os.path.dirname(os.path.abspath(__file__))
14
  data_path = "./data/"
15
  model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')
16
  store = {}
17
 
18
+ # Set up HuggingFace embeddings
19
  model_name = "sentence-transformers/all-mpnet-base-v2"
20
  model_kwargs = {'device': 'cpu'}
21
  encode_kwargs = {'normalize_embeddings': True}
 
 
 
 
 
 
 
22
 
23
+ # Use Streamlit's cache to avoid recomputation
24
+ @st.cache_resource
25
+ def load_embeddings():
26
+ return HuggingFaceEmbeddings(
27
+ model_name=model_name,
28
+ model_kwargs=model_kwargs,
29
+ encode_kwargs=encode_kwargs
30
+ )
31
 
32
+ hf = load_embeddings()
33
 
34
+ @st.cache_data
35
+ def load_documents(data_path):
36
+ documents = []
37
+ for filename in os.listdir(data_path):
38
+ if filename.endswith('.txt'):
39
+ file_path = os.path.join(data_path, filename)
40
+ documents.extend(TextLoader(file_path).load())
41
+ return documents
42
 
43
+ documents = load_documents(data_path)
 
 
44
 
45
+ def split_docs(documents, chunk_size, overlap):
46
+ # Your implementation here
47
+ pass
48
 
49
  docs = split_docs(documents, 450, 20)
 
 
50
 
51
+ @st.cache_resource
52
+ def create_chroma_db(docs, hf):
53
+ return Chroma(docs, hf)
54
 
55
+ chroma_db = create_chroma_db(docs, hf)
 
 
 
 
 
 
 
56
 
57
+ @st.cache_resource
58
+ def create_retriever(chroma_db):
59
+ return mmr_retriever(chroma_db, "mmr", 6)
60
 
61
+ retriever = create_retriever(chroma_db)
62
 
63
+ # Set up LlamaCpp model
64
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
 
 
 
 
 
 
 
 
 
65
 
66
+ @st.cache_resource
67
+ def load_llm():
68
+ return LlamaCpp(
69
+ model_path='qwen2-0_5b-instruct-q4_0.gguf',
70
+ n_gpu_layers=0,
71
+ temperature=0.0,
72
+ top_p=0.5,
73
+ n_ctx=7000,
74
+ max_tokens=350,
75
+ repeat_penalty=1.7,
76
+ stop=["", "Instruction:", "### Instruction:", "###<user>", "</user>"],
77
+ callback_manager=callback_manager,
78
+ verbose=False,
79
+ )
80
+
81
+ llm = load_llm()
82
 
83
  contextualize_q_system_prompt = """Given a context, chat history and the latest user question
84
  which maybe reference context in the chat history, formulate a standalone question
85
  which can be understood without the chat history. Do NOT answer the question,
86
  just reformulate it if needed and otherwise return it as is."""
87
 
88
+ @st.cache_resource
89
+ def create_history_aware_retriever():
90
+ return history_aware_retriever(llm, retriever, contextualize_q_system_prompt)
91
+
92
+ ha_retriever = create_history_aware_retriever()
93
 
94
  qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
95
 
 
101
  ]
102
  )
103
 
104
+ @st.cache_resource
105
+ def create_question_answer_chain():
106
+ return create_stuff_documents_chain(llm, qa_prompt)
 
 
 
 
 
 
 
 
107
 
108
+ question_answer_chain = create_question_answer_chain()
109
 
110
+ @st.cache_resource
111
+ def create_rag_chain():
112
+ return create_retrieval_chain(ha_retriever, question_answer_chain)
113
 
114
+ rag_chain = create_rag_chain()
115
+ msgs = StreamlitChatMessageHistory(key="special_app_key")
116
 
117
+ @st.cache_resource
118
+ def create_conversational_rag_chain():
119
+ return RunnableWithMessageHistory(
120
+ rag_chain,
121
+ lambda session_id: msgs,
122
+ input_messages_key="input",
123
+ history_messages_key="chat_history",
124
+ output_messages_key="answer",
125
+ )
126
 
127
+ conversational_rag_chain = create_conversational_rag_chain()
128
 
129
  def display_chat_history(chat_history):
130
  """Displays the chat history in Streamlit."""
 
133
 
134
  def display_documents(docs, on_click=None):
135
  """Displays retrieved documents with optional click action."""
136
+ if docs:
137
+ for i, document in enumerate(docs):
138
  st.write(f"**Docs {i+1}**")
139
+ st.markdown(document, unsafe_allow_html=True)
140
  if on_click:
141
  if st.button(f"Expand Article {i+1}"):
142
+ on_click(i)
143
 
144
  def main(conversational_rag_chain):
145
  """Main function for the Streamlit app."""
 
146
  msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
147
  chain_with_history = conversational_rag_chain
148
 
149
  st.title("Conversational RAG Chatbot")
150
 
 
151
  display_chat_history(msgs)
152
 
153
  if prompt := st.chat_input():
154
  st.chat_message("human").write(prompt)
155
 
 
156
  input_dict = {"input": prompt, "chat_history": msgs.messages}
157
  config = {"configurable": {"session_id": "any"}}
158
 
 
159
  response = chain_with_history.invoke(input_dict, config)
160
  st.chat_message("ai").write(response["answer"])
161
 
 
162
  if "docs" in response and response["documents"]:
163
  docs = response["documents"]
164
  def expand_document(index):
 
165
  st.write(f"Expanding document {index+1}...")
166
+ display_documents(docs, expand_document)
167
 
 
168
  st.session_state["chat_history"] = msgs
169
 
170
  if __name__ == "__main__":