nkcong206 committed
Commit 209ea81 · verified · 1 Parent(s): cf4099e

Update app.py
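
Summary of the change: the remote Qdrant vector store is replaced with an in-memory Chroma store. Uploaded txt files are now saved under ./Documents and loaded with TextLoader, the leaf texts plus their RAPTOR summaries are embedded into Chroma inside a cached compute_rag_chain, and the Gemini query router is dropped: the app answers with the RAG chain whenever one has been built, falling back to the plain chat model otherwise.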

Files changed (1)
  1. app.py +72 -139
app.py CHANGED
@@ -2,17 +2,15 @@ import streamlit as st
  import os
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_core.prompts import ChatPromptTemplate
+ from langchain_community.document_loaders import TextLoader
  from langchain_huggingface import HuggingFaceEmbeddings
  from langchain.prompts import PromptTemplate
 
  from langchain_core.output_parsers import StrOutputParser
 
  from langchain_core.runnables import RunnablePassthrough
- from langchain_qdrant import QdrantVectorStore
+ from langchain_chroma import Chroma
  import Raptor
- from io import StringIO
- from qdrant_client import QdrantClient
- from qdrant_client.models import Distance, VectorParams
 
  page = st.title("Chat with AskUSTH")
 
@@ -52,50 +50,6 @@ def get_embedding_model():
  if "embd" not in st.session_state:
      st.session_state.embd = get_embedding_model()
 
- @st.cache_resource
- def load_chromadb(collection_name):
-     client = QdrantClient(
-         url="https://da9fadd2-dc5a-4481-ac79-4e2677a2354b.europe-west3-0.gcp.cloud.qdrant.io",
-         api_key="X_-IVToBM07Mot4Mmzg5xNjYzc1DlIgl0VQDUNmGhI_Z-WA5FJ2ETA"
-     )
-
-     client.recreate_collection(
-         collection_name=collection_name,
-         vectors_config=VectorParams(size=768, distance=Distance.COSINE)
-     )
-     db = QdrantVectorStore(
-         client=client,
-         collection_name=collection_name,
-         embedding=st.session_state.embd,
-     )
-     return db
-
- @st.cache_resource
- def update_chromadb(collection_name):
-     client = QdrantClient(
-         url="https://da9fadd2-dc5a-4481-ac79-4e2677a2354b.europe-west3-0.gcp.cloud.qdrant.io",
-         api_key="X_-IVToBM07Mot4Mmzg5xNjYzc1DlIgl0VQDUNmGhI_Z-WA5FJ2ETA"
-     )
-
-     try:
-         client.delete_collection(collection_name=collection_name)
-     except Exception as e:
-         print(f"Warning: {e}")
-
-     client.recreate_collection(
-         collection_name=collection_name,
-         vectors_config=VectorParams(size=768, distance=Distance.COSINE)
-     )
-     db = QdrantVectorStore(
-         client=client,
-         collection_name=collection_name,
-         embedding=st.session_state.embd,
-     )
-     return db
-
- if "vector_store" not in st.session_state:
-     st.session_state.vector_store = load_chromadb("data")
-
  if "model" not in st.session_state:
      st.session_state.model = None
 
@@ -123,12 +77,64 @@ if st.session_state.gemini_api is None:
  if st.session_state.gemini_api and st.session_state.model is None:
      st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
 
+ if st.session_state.save_dir is None:
+     save_dir = "./Documents"
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+     st.session_state.save_dir = save_dir
+
+ def load_txt(file_path):
+     loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
+     doc = loader_sv.load()
+     return doc
+
+ with st.sidebar:
+     uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
+     if st.session_state.gemini_api:
+         if uploaded_files:
+             documents = []
+             uploaded_file_names = set()
+             new_docs = False
+             for uploaded_file in uploaded_files:
+                 uploaded_file_names.add(uploaded_file.name)
+                 if uploaded_file.name not in st.session_state.uploaded_files:
+                     file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
+                     with open(file_path, mode='wb') as w:
+                         w.write(uploaded_file.getvalue())
+                 else:
+                     continue
+
+                 new_docs = True
+
+                 doc = load_txt(file_path)
+
+                 documents.extend([*doc])
+
+             if new_docs:
+                 st.session_state.uploaded_files = uploaded_file_names
+                 st.session_state.rag = None
+         else:
+             st.session_state.uploaded_files = set()
+             st.session_state.rag = None
+
  def format_docs(docs):
      return "\n\n".join(doc.page_content for doc in docs)
 
  @st.cache_resource
- def rag_chain(_model, _vectorstore):
-     retriever = _vectorstore.as_retriever()
+ def compute_rag_chain(_model, _embd, docs_texts):
+     results = Raptor.recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
+     all_texts = docs_texts.copy()
+     i = 0
+     for level in sorted(results.keys()):
+         summaries = results[level][1]["summaries"].tolist()
+         all_texts.extend(summaries)
+         print(f"summary {i} -------------------------------------------------")
+         print(summaries)
+         i += 1
+     print("all_texts ______________________________________")
+     print(all_texts)
+     vectorstore = Chroma.from_texts(texts=all_texts, embedding=_embd)
+     retriever = vectorstore.as_retriever()
      template = """
      Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
      Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
@@ -139,93 +145,24 @@ def rag_chain(_model, _vectorstore):
 
      {question}
      """
      prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-     rag = (
+     rag_chain = (
          {"context": retriever | format_docs, "question": RunnablePassthrough()}
          | prompt
          | _model
          | StrOutputParser()
      )
-     return rag
+     return rag_chain
 
- if st.session_state.model is not None and st.session_state.vector_store is not None:
-     st.session_state.rag = rag_chain(st.session_state.model, st.session_state.vector_store)
-
- if "new_docs" not in st.session_state:
-     st.session_state.new_docs = False
-
- with st.sidebar:
-     uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
-     if st.session_state.model:
-         documents = []
-         uploaded_file_names = set()
-         if uploaded_files:
-             for uploaded_file in uploaded_files:
-                 uploaded_file_names.add(uploaded_file.name)
-         if uploaded_file_names != st.session_state.uploaded_files and not st.session_state.new_docs:
-             st.session_state.uploaded_files = uploaded_file_names
-             st.session_state.new_docs = True
-         if uploaded_files:
-             for uploaded_file in uploaded_files:
-                 stringio = StringIO(uploaded_file.getvalue().decode('utf-8'))
-                 read_data = str(stringio.read())
-                 documents.append(read_data)
-
- def update_rag_chain(_model, _embd, _vectorstore, docs_texts):
-     results = Raptor.recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
-     all_texts = docs_texts.copy()
-     for level in sorted(results.keys()):
-         summaries = results[level][1]["summaries"].tolist()
-         all_texts.extend(summaries)
-     _vectorstore.add_texts(texts=all_texts)
-     rag = rag_chain(_model, _vectorstore)
-     return rag
-
- def reset_rag_chain(_model, _vectorstore):
-     rag = rag_chain(_model, _vectorstore)
-     return rag
-
- if "query_router" not in st.session_state:
-     st.session_state.query_router = None
-
- @st.cache_resource
- def query_router(_model):
-     mess = ChatPromptTemplate.from_messages(
-         [
-             (
-                 "system",
-                 """Bạn là một chatbot hỗ trợ giải đáp về đại học, nhiệm vụ của bạn là phân loại câu hỏi.
-                 Nếu câu hỏi về đại học thì trả về 'university', nếu không liên quan tới tuyển sinh và sinh viên thì trả về 'other'.
-                 Bắt buộc Kết quả chỉ trả về với một trong hai lựa chọn trên.
-                 Không được trả lời thêm bất kỳ thông tin nào khác.""",
-             ),
-             ("human", "{input}"),
-         ]
-     )
-     chain = mess | _model
-     return chain
 
- if st.session_state.model is not None:
-     st.session_state.query_router = query_router(st.session_state.model)
-
- @st.dialog("Update DB")
- def update_vectorstore(_model, _embd, _vectorstore, docs):
-     docs_texts = [d for d in docs]
-     st.session_state.rag = update_rag_chain(_model, _embd, _vectorstore, docs_texts)
+ @st.dialog("Setup RAG")
+ def load_rag():
+     docs_texts = [d.page_content for d in documents]
+     st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
      st.rerun()
-
- @st.dialog("Reset DB")
- def reset_vectorstore(_model, _vectorstore):
-     st.session_state.rag = reset_rag_chain(_model, _vectorstore)
-     st.rerun()
 
- if st.session_state.new_docs:
-     st.session_state.new_docs = False
-     st.session_state.vector_store = update_chromadb("data")
-     if st.session_state.uploaded_files:
-         update_vectorstore(st.session_state.model, st.session_state.embd, st.session_state.vector_store, documents)
-     else:
-         reset_vectorstore(st.session_state.model, st.session_state.vector_store)
-
+ if st.session_state.uploaded_files and st.session_state.model is not None:
+     if st.session_state.rag is None:
+         load_rag()
+
  if st.session_state.model is not None:
      if st.session_state.llm is None:
          mess = ChatPromptTemplate.from_messages(
@@ -256,16 +193,12 @@ if st.session_state.model is not None:
      st.write(prompt)
 
  with st.chat_message("assistant"):
-     router = st.session_state.query_router.invoke(prompt)
-     switch = router.content
-     if "university" in switch:
+     if st.session_state.rag is not None:
          respone = st.session_state.rag.invoke(prompt)
-         f_response = f"RAG: {respone}"
-         st.write(f_response)
+         st.write(respone)
      else:
-         respone = st.session_state.llm.invoke(prompt)
-         f_response = f"other: {respone.content}"
-         st.write(f_response)
+         ans = st.session_state.llm.invoke(prompt)
+         respone = ans.content
+         st.write(respone)
 
-     st.session_state.chat_history.append({"role": "assistant", "content": f_response})
-
+     st.session_state.chat_history.append({"role": "assistant", "content": respone})
 
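
For readers skimming the diff, here is a minimal, self-contained sketch of the pipeline the new code converges on. It is an illustration under stated assumptions, not the app itself: the embedding and Gemini model names are placeholders (the diff does not show the body of get_embedding_model()), the Vietnamese admissions prompt is abbreviated to an English stand-in, and the RAPTOR clustering/summarization step is elided.

# Sketch of the new retrieval flow; model names are hypothetical placeholders.
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def format_docs(docs):
    # Concatenate retrieved chunks into one context string, as in app.py.
    return "\n\n".join(doc.page_content for doc in docs)

# Placeholder models; the diff does not reveal which ones the app configures.
embd = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key="YOUR_KEY")

# In the app, all_texts = uploaded document texts plus their RAPTOR summaries.
all_texts = ["USTH admission document text goes here..."]
vectorstore = Chroma.from_texts(texts=all_texts, embedding=embd)  # in-memory store
retriever = vectorstore.as_retriever()

prompt = PromptTemplate(
    template="Context: {context}\n\nQuestion: {question}",  # abbreviated stand-in
    input_variables=["context", "question"],
)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)
print(rag_chain.invoke("Hồ sơ tuyển sinh gồm những gì?"))

One design note: Chroma.from_texts here keeps the index in process memory, so the @st.cache_resource decorator on compute_rag_chain appears to be what avoids re-embedding on every Streamlit rerun, while resetting st.session_state.rag on new uploads (as the sidebar does) triggers a rebuild with the new texts.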