miniondenis committed on
Commit
9c32f1b
1 Parent(s): 5ec36b8

refactor: refactor

app.py CHANGED
@@ -10,401 +10,24 @@ from langchain_core.runnables import ConfigurableFieldSpec
10
  from langchain.schema import Document
11
  from langchain.prompts import PromptTemplate
12
  from langchain_core.output_parsers import JsonOutputParser
13
- from langchain.retrievers.multi_query import MultiQueryRetriever
14
  from langchain_core.output_parsers import StrOutputParser
15
  from langchain_pinecone import PineconeVectorStore
16
  from typing_extensions import TypedDict
17
  from typing import Dict, List
18
- from langgraph.graph import END, StateGraph
19
  import warnings
20
 
21
- from lib.embedding import EmbeddingBuilder, build_embedding
22
- from lib.model_builder import ModelBuilder, ModelBuilderV2
23
  from lib.gradio_custom_theme import DarkTheme
24
- warnings.filterwarnings('ignore')
25
- from dotenv import load_dotenv
26
-
27
- load_dotenv()
28
- store = {}
29
-
30
- def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:
31
- if (user_id, conversation_id) not in store:
32
- store[(user_id, conversation_id)] = ChatMessageHistory()
33
- return store[(user_id, conversation_id)]
34
-
35
-
36
- def combine_vectors(vectors):
37
- result = []
38
- vec1_count = len(vectors["vector1"])
39
- # vec2_count = len(vectors["vector2"])
40
- for i in range(vec1_count):
41
- if i < vec1_count:
42
- result.append(vectors['vector1'][i])
43
- # if i < vec2_count:
44
- # result.append(vectors['vector2'][i])
45
- return result
46
-
47
-
48
- class GraphState(TypedDict):
49
- """
50
- Represents the state of our graph.
51
 
52
- Attributes:
53
- question: question
54
- generation: LLM generation
55
- web_search: whether to add search
56
- documents: list of documents
57
- """
58
-
59
- question: str
60
- generation: str
61
- documents: List[Dict]
62
- filtered_documets: List[Dict]
63
- is_fuse: bool
64
- count_regenerations: int
65
-
66
- class FAISSBuilder:
67
- def __enter__(self):
68
- # Initialize resources
69
- with EmbeddingBuilder("intfloat/multilingual-e5-large") as rag_emb:
70
- faiss_db = FAISS.load_local("data/faiss_nk_28_05", rag_emb, allow_dangerous_deserialization=True)
71
- return faiss_db.as_retriever()
72
-
73
- def __exit__(self, exc_type, exc_value, traceback):
74
- # Cleanup resources if necessary
75
- pass
76
 
 
 
77
 
78
- class PineConeBuilder:
79
- def __init__(self, index_name, embedding_model):
80
- self.index_name = index_name
81
- self.embedding_model = embedding_model
82
 
83
- def __enter__(self):
84
- with EmbeddingBuilder(self.embedding_model) as embeddings:
85
- pc_db = PineconeVectorStore.from_existing_index(self.index_name, embeddings)
86
- return pc_db.as_retriever()
87
-
88
- def __exit__(self, exc_type, exc_value, traceback):
89
- # Cleanup resources if necessary
90
- pass
91
 
92
  def deploy():
93
-
94
- casual_prompt = PromptTemplate(
95
- template="""Just answer a question as casual chatter. \n
96
- Here is the user question: {question} \n
97
- Always reply in Russian. Leave clear answer, without any addition.
98
- Chat history:
99
- {chat_history}
100
- Answer:
101
- """,
102
- input_variables=["question", "chat_history"],
103
- )
104
- with ModelBuilderV2("openchat/openchat-7b", 0.7) as llm:
105
- casual_llm = RunnableWithMessageHistory(
106
- casual_prompt | llm,
107
- get_session_history,
108
- input_messages_key="question",
109
- history_messages_key="chat_history",
110
- history_factory_config=[
111
- ConfigurableFieldSpec(
112
- id="user_id",
113
- annotation=str,
114
- name="User ID",
115
- description="Unique identifier for the user.",
116
- default="default_user",
117
- is_shared=True,
118
- ),
119
- ConfigurableFieldSpec(
120
- id="conversation_id",
121
- annotation=str,
122
- name="Conversation ID",
123
- description="Unique identifier for the conversation.",
124
- default="default_session",
125
- is_shared=True,
126
- ),
127
- ],
128
- ) | StrOutputParser()
129
-
130
-
131
- prompt = PromptTemplate(
132
- template="""You are a grader assessing relevance of a retrieved documents to a user question. \n
133
- Here is the first retrieved document: \n\n {document_1} \n\n
134
- Here is the second retrieved document: \n\n {document_2} \n\n
135
- Here is the third retrieved document: \n\n {document_3} \n\n
136
- Here is the user question: {question} \n
137
- If the document contains keywords related to the user question, grade it as relevant. \n
138
- It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
139
- For each! document give a score from 0 to 1, score to indicate whether the document is relevant to the question. \n
140
- Provide the scores as a JSON list that contains an objects with single key 'score' and no premable or explanation.""",
141
- input_variables=["question", "document_1", "document_2", "document_3"],
142
- )
143
- with ModelBuilderV2("cohere/command-r") as llm:
144
- retrieval_grader_3_docs = prompt | llm | JsonOutputParser()
145
-
146
- template = """
147
- SYSTEM: You are an assistant for question-answering tasks.
148
- Use the following pieces of retrieved context to answer the question.
149
- Use previous messages then current message higly likely
150
- If you don't find the answer in the context, transform the question ans ask the user to specify his qusetion.
151
-
152
- Keep the answer concise.
153
- Print a most possible topic of conversation.
154
- Always reply in Russian, all text must be in Russian!
155
-
156
- Context: {context}
157
-
158
- Previous messages: {chat_history}
159
-
160
- Question: {question}
161
-
162
- Answer:
163
-
164
- """
165
- prompt = ChatPromptTemplate.from_template(template)
166
-
167
- with ModelBuilderV2("mistralai/mixtral-8x22b-instruct") as llm:
168
- # Post-processing
169
- def format_docs(docs):
170
- return "\n\n".join(doc.page_content for doc in docs)
171
-
172
- # Chain
173
- rag_chain = RunnableWithMessageHistory(
174
- prompt | llm,
175
- get_session_history,
176
- input_messages_key="question",
177
- history_messages_key="chat_history",
178
- history_factory_config=[
179
- ConfigurableFieldSpec(
180
- id="user_id",
181
- annotation=str,
182
- name="User ID",
183
- description="Unique identifier for the user.",
184
- default="default_user",
185
- is_shared=True,
186
- ),
187
- ConfigurableFieldSpec(
188
- id="conversation_id",
189
- annotation=str,
190
- name="Conversation ID",
191
- description="Unique identifier for the conversation.",
192
- default="default_session",
193
- is_shared=True,
194
- ),
195
- ],
196
- ) | StrOutputParser()
197
-
198
- classification_conversation_template = """
199
- SYSTEM: You are message classificator. You classify message into classes:
200
- - TAX: tax payment or other similar jura topic;
201
- - CASUAL: casual conversation, any generic question that can't fit in the conversation;
202
- - SPAM: any rude, useless messages.
203
- Here is a question: {question}
204
- Provide the answer as a JSON object that contains an object with key 'message_type' and with key 'system_message' with short remark about message and no other premable or explanation.
205
- """
206
-
207
- with ModelBuilderV2("openchat/openchat-7b") as decider_llm:
208
- message_classificator = PromptTemplate.from_template(classification_conversation_template) | decider_llm | JsonOutputParser()
209
-
210
- def retrieve(state):
211
- """
212
- Retrieve documents
213
-
214
- Args:
215
- state (dict): The current graph state
216
-
217
- Returns:
218
- state (dict): New key added to state, documents, that contains retrieved documents
219
- """
220
- print("---RETRIEVE---")
221
- question = state["question"]
222
-
223
- # Retrieval
224
- with FAISSBuilder() as faiss_retriever:
225
- with ModelBuilderV2("openchat/openchat-7b") as mq_llm:
226
- retriever = MultiQueryRetriever.from_llm(retriever=faiss_retriever, llm=mq_llm)
227
- documents = retriever.get_relevant_documents(question)
228
- return {"documents": documents, "question": question}
229
-
230
-
231
- def start_point(state):
232
- """
233
- Start point, just return state
234
-
235
- Args:
236
- state (dict): The current graph state
237
-
238
- Returns:
239
- state (dict): The current graph state
240
- """
241
- return state
242
-
243
- def casual_chat(state):
244
- """
245
- Define type of message
246
-
247
- Args:
248
- state (dict): The current graph state
249
-
250
- Returns:
251
- state (dict): New key added to state, generation, that contains message with casual answer
252
- """
253
- question = state["question"]
254
- print("---CASUAL CHAT---")
255
- generation = casual_llm.invoke({"question": question}, config={"configurable": {"conversation_id": "default_session", "user_id": "deafault_user"}})
256
- state['generation'] = generation
257
-
258
- return state
259
-
260
-
261
- def define_message_type(state):
262
- """
263
- Define type of message
264
-
265
- Args:
266
- state (dict): The current graph state
267
-
268
- Returns:
269
-
270
- """
271
- print("---MESSAGE CLASSIFICATION---")
272
- question = state["question"]
273
- msg_type_obj = message_classificator.invoke({"question": question})
274
- print(f"---MESSAGE TYPE: {msg_type_obj['message_type']} SYSTEM MESSAGE---\n {msg_type_obj['system_message']}")
275
- msg_type = msg_type_obj['message_type']
276
-
277
- if msg_type == "TAX":
278
- return "retrieve"
279
- # if msg_type == "":
280
- return "casual_chat"
281
- return "__end__"
282
-
283
- def generate(state):
284
- """
285
- Generate answer
286
-
287
- Args:
288
- state (dict): The current graph state
289
-
290
- Returns:
291
- state (dict): New key added to state, generation, that contains LLM generation, based on documents
292
- """
293
- print("---GENERATE---")
294
- question = state["question"]
295
- documents = state["documents"]
296
-
297
- # RAG generation
298
- generation = rag_chain.invoke({"context": documents, "question": question}, config={"configurable": {"conversation_id": "default_session", "user_id": "deafault_user"}},)
299
- return {"documents": documents, "question": question, "generation": generation}
300
-
301
-
302
- def grade_documents(state):
303
- """
304
- Determines whether the retrieved documents are relevant to the question.
305
-
306
- Args:
307
- state (dict): The current graph state
308
-
309
- Returns:
310
- state (dict): Updates documents key with only filtered relevant documents
311
- """
312
-
313
- print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
314
- question = state["question"]
315
- documents = state["documents"]
316
-
317
- # Score each doc
318
- filtered_docs = []
319
- count_docs = len(documents)
320
- for ind_d in range(0, count_docs, 3):
321
- d_1 = documents[ind_d] if ind_d < count_docs else None
322
- d_2 = documents[ind_d + 1] if ind_d + 1 < count_docs else None
323
- d_3 = documents[ind_d + 2] if ind_d + 2 < count_docs else None
324
- scores = retrieval_grader_3_docs.invoke(
325
- {"question": question, "document_1": d_1, "document_2": d_2, "document_3": d_3}
326
- )
327
- for j in range(len(scores)):
328
- grade = scores[j]["score"]
329
- if grade > 0.7:
330
- print(f"---GRADE: DOCUMENT RELEVANT--- GRADE: {grade}")
331
- filtered_docs.append(documents[ind_d + j])
332
- else:
333
- print("---GRADE: DOCUMENT NOT RELEVANT---")
334
- is_fuse = len(filtered_docs) / len(documents) <= 0.5
335
-
336
- return {"documents": filtered_docs, "question": question}
337
-
338
-
339
- def make_collapsable_source_message(doc: Dict):
340
- file_path = doc.metadata.get("file_name", "")
341
- file_name = file_path.replace(".pdf", "")
342
- chapter_title = doc.metadata.get("chapter_title", None)
343
- page_num = doc.metadata.get("first_page_num", None)
344
- title = f"""
345
- {file_name}
346
- {f": {chapter_title} " if chapter_title is not None else ""}
347
- {f"Стр. {page_num} " if page_num is not None else ""}
348
- """.replace("\n", " ")
349
- content = doc.page_content.replace("\n\n", "\n")
350
-
351
- if page_num is None:
352
- message = rf"""
353
- <details>
354
- <summary>{title}</summary>
355
- {str(content)}
356
- </details>
357
- """
358
- else:
359
- base_url = "http://localhost:5000/sta"
360
- url = f"{base_url}?file={file_path}&#page={page_num}&zoom=90&toolbar=0"
361
- message = f"""
362
- <a class="open_pdf" href='{url}' onclick="return openPdf('{url}')">{title}</a>
363
- """
364
-
365
- return message
366
-
367
-
368
- def add_sources(state):
369
- """
370
- Determines whether the retrieved documents are relevant to the question.
371
-
372
- Args:
373
- state (dict): The current graph state
374
-
375
- Returns:
376
- state (dict): Add collapsable sources
377
- """
378
- question = state["question"]
379
- documents = state["documents"]
380
- generation = state["generation"]
381
-
382
- sources_message = "<i></i>".join(map(make_collapsable_source_message, documents))
383
- extended_generation_message = f"{generation} {sources_message}"
384
- return {"documents": documents, "question": question, "generation": extended_generation_message}
385
-
386
-
387
- ### Edges
388
- workflow = StateGraph(GraphState)
389
-
390
- # Define the nodes
391
- workflow.add_node("start_point", start_point)
392
- workflow.add_node("retrieve", retrieve) # retrieve
393
- workflow.add_node("grade_documents", grade_documents) # grade documents
394
- workflow.add_node("generate", generate) # generate
395
- workflow.add_node("casual_chat", casual_chat) # simple chat
396
- workflow.add_node("add_sources", add_sources)
397
- # Build graph
398
- workflow.set_entry_point("start_point")
399
- workflow.add_conditional_edges("start_point", define_message_type)
400
- workflow.add_edge("retrieve", "grade_documents")
401
- workflow.add_edge("grade_documents", "generate")
402
- workflow.add_edge("generate", "add_sources")
403
- workflow.add_edge("add_sources", END)
404
- workflow.add_edge("casual_chat", END)
405
-
406
- # Compile
407
- app = workflow.compile()
408
 
409
  pdf_open_js = """
410
  <script>
@@ -417,54 +40,48 @@ def deploy():
417
  </script>
418
  """
419
 
420
- def print_source_documents(documents):
421
- return "\n\n".join([f"Взято из файла: {doc.metadata['file_name']} \n Metadata: {doc.metadata}" for doc in documents])
422
-
423
-
424
  dark_theme = DarkTheme()
425
  with gr.Blocks(head=pdf_open_js, fill_height=True, theme=dark_theme) as demo:
426
  with gr.Row():
427
  with gr.Column(scale=1):
428
- chatbot_rag = gr.Chatbot(label=f"RAG: llama3 + документы", height=740, sanitize_html=False, show_copy_button=True)
429
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=None, placeholder="Введите сообщение...", show_label=False, scale=4)
430
  with gr.Column(scale=1.5):
431
- pdf_output = gr.HTML("<iframe id='opener' width='100%' height='740px' src=''></iframe>")
 
 
432
  # clear = gr.Button("Clear")
433
 
434
  def user_rag(history, message):
435
  if message["text"] is not None:
436
  history.append((message["text"], None))
437
  return history, gr.update(value=None, interactive=False)
438
-
439
 
440
  def bot_rag(history):
441
  result = app.invoke({"question": history[-1][0]})
442
- form_answer = result["generation"].strip()
443
  history[-1][1] = form_answer
444
  return history
445
 
446
- # def bot_llm(history):
447
- # result = casual_llm.invoke({"question": history[-1][0]}, config={"configurable": {"conversation_id": "default_session_standalone", "user_id": "deafault_user_standalone"}})
448
- # history[-1][1] = result.strip()
449
- # return history
450
-
451
- chat_input.submit(user_rag, [chatbot_rag, chat_input], [chatbot_rag, chat_input], queue=False).then(
452
- bot_rag, chatbot_rag, chatbot_rag
453
- ).then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
454
- # pdf_output.change(lambda x: gr.HTML(pdf_open_js), chatbot_rag, pdf_output, queue=False)
455
-
456
- # chat_input.submit(user_llm, [chatbot_llm, chat_input], [chatbot_llm, chat_input], queue=False).then(
457
- # bot_llm, chatbot_llm, chatbot_llm
458
- # ).then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
459
- # clear.click(lambda: None, None, chatbot_rag, queue=False)
460
- # clear.click(lambda: None, None, chatbot_llm, queue=False)
461
 
462
  demo.launch(share=True)
463
 
464
 
465
  if __name__ == "__main__":
466
- # parser = argparse.ArgumentParser(description='Deploy llm chat')
467
- # parser.add_argument('--model_name', metavar='M', type=str,
468
- # help='model name as: openai/gpt-3.5-turbo')
469
-
470
- deploy()
 
10
  from langchain.schema import Document
11
  from langchain.prompts import PromptTemplate
12
  from langchain_core.output_parsers import JsonOutputParser
 
13
  from langchain_core.output_parsers import StrOutputParser
14
  from langchain_pinecone import PineconeVectorStore
15
  from typing_extensions import TypedDict
16
  from typing import Dict, List
 
17
  import warnings
18
 
 
 
19
  from lib.gradio_custom_theme import DarkTheme
20
+ from lib.graph import build_workflow
21

22
 
23
+ warnings.filterwarnings("ignore")
24
+ from dotenv import load_dotenv
25
 
26
+ load_dotenv()
27

28
 
29
  def deploy():
30
+ app = build_workflow()
31
 
32
  pdf_open_js = """
33
  <script>
 
40
  </script>
41
  """
42

43
  dark_theme = DarkTheme()
44
  with gr.Blocks(head=pdf_open_js, fill_height=True, theme=dark_theme) as demo:
45
  with gr.Row():
46
  with gr.Column(scale=1):
47
+ chatbot_rag = gr.Chatbot(
48
+ label=f"RAG: llama3 + документы",
49
+ height=740,
50
+ sanitize_html=False,
51
+ show_copy_button=True,
52
+ )
53
+ chat_input = gr.MultimodalTextbox(
54
+ interactive=True,
55
+ file_types=None,
56
+ placeholder="Введите сообщение...",
57
+ show_label=False,
58
+ scale=4,
59
+ )
60
  with gr.Column(scale=1.5):
61
+ pdf_output = gr.HTML(
62
+ "<iframe id='opener' width='100%' height='740px' src=''></iframe>"
63
+ )
64
  # clear = gr.Button("Clear")
65
 
66
  def user_rag(history, message):
67
  if message["text"] is not None:
68
  history.append((message["text"], None))
69
  return history, gr.update(value=None, interactive=False)
 
70
 
71
  def bot_rag(history):
72
  result = app.invoke({"question": history[-1][0]})
73
+ form_answer = result["generation"].strip()
74
  history[-1][1] = form_answer
75
  return history
76
 
77
+ chat_input.submit(
78
+ user_rag, [chatbot_rag, chat_input], [chatbot_rag, chat_input], queue=False
79
+ ).then(bot_rag, chatbot_rag, chatbot_rag).then(
80
+ lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
81
+ )
82
 
83
  demo.launch(share=True)
84
 
85
 
86
  if __name__ == "__main__":
87
+ deploy()
lib/embedding.py CHANGED
@@ -2,13 +2,17 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
2
  from dotenv import load_dotenv
3
 
4
  load_dotenv()
 
 
5
  def build_embedding(model_name: str):
6
- embedding = HuggingFaceEmbeddings(model_name=model_name, \
7
- # model_kwargs={"device": "cuda"}, \
8
- encode_kwargs={"normalize_embeddings": True})
 
9
  embedding.show_progress = True
10
  return embedding
11
 
 
12
  class EmbeddingBuilder:
13
  def __init__(self, model_name: str, device: str = "cpu"):
14
  self.model_name = model_name
@@ -16,12 +20,14 @@ class EmbeddingBuilder:
16
 
17
  def __enter__(self):
18
  # Initialize resources
19
- embedding = HuggingFaceEmbeddings(model_name=self.model_name, \
20
- model_kwargs={"device": self.device}, \
21
- encode_kwargs={"normalize_embeddings": True})
 
 
22
  embedding.show_progress = True
23
  return embedding
24
 
25
  def __exit__(self, exc_type, exc_value, traceback):
26
  # Cleanup resources if necessary
27
- pass
 
2
  from dotenv import load_dotenv
3
 
4
  load_dotenv()
5
+
6
+
7
  def build_embedding(model_name: str):
8
+ embedding = HuggingFaceEmbeddings(
9
+ model_name=model_name,  # model_kwargs={"device": "cuda"}
10
+ encode_kwargs={"normalize_embeddings": True},
11
+ )
12
  embedding.show_progress = True
13
  return embedding
14
 
15
+
16
  class EmbeddingBuilder:
17
  def __init__(self, model_name: str, device: str = "cpu"):
18
  self.model_name = model_name
 
20
 
21
  def __enter__(self):
22
  # Initialize resources
23
+ embedding = HuggingFaceEmbeddings(
24
+ model_name=self.model_name,
25
+ model_kwargs={"device": self.device},
26
+ encode_kwargs={"normalize_embeddings": True},
27
+ )
28
  embedding.show_progress = True
29
  return embedding
30
 
31
  def __exit__(self, exc_type, exc_value, traceback):
32
  # Cleanup resources if necessary
33
+ pass
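EmbeddingBuilder is consumed as a context manager by the new lib/vectorestores.py and lib/graph.py; a minimal usage sketch, assuming the same model name used there (the sample sentence is only an illustration):

    from lib.embedding import EmbeddingBuilder

    # Builds a normalized multilingual-e5-large embedder on CPU (the default device).
    with EmbeddingBuilder("intfloat/multilingual-e5-large") as emb:
        vectors = emb.embed_documents(["пример текста для индексации"])  # illustrative input
        print(len(vectors[0]))  # embedding dimensionality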
lib/gradio_custom_theme.py CHANGED
@@ -14,16 +14,12 @@ class DarkTheme(Base):
14
  spacing_size: sizes.Size | str = sizes.spacing_md,
15
  radius_size: sizes.Size | str = sizes.radius_md,
16
  text_size: sizes.Size | str = sizes.text_lg,
17
- font: fonts.Font
18
- | str
19
- | Iterable[fonts.Font | str] = (
20
  fonts.GoogleFont("Quicksand"),
21
  "ui-sans-serif",
22
  "sans-serif",
23
  ),
24
- font_mono: fonts.Font
25
- | str
26
- | Iterable[fonts.Font | str] = (
27
  fonts.GoogleFont("Roboto"),
28
  "ui-monospace",
29
  "monospace",
 
14
  spacing_size: sizes.Size | str = sizes.spacing_md,
15
  radius_size: sizes.Size | str = sizes.radius_md,
16
  text_size: sizes.Size | str = sizes.text_lg,
17
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
 
 
18
  fonts.GoogleFont("Quicksand"),
19
  "ui-sans-serif",
20
  "sans-serif",
21
  ),
22
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
 
 
23
  fonts.GoogleFont("Roboto"),
24
  "ui-monospace",
25
  "monospace",
lib/graph.py ADDED
@@ -0,0 +1,279 @@
1
+ from typing import Dict, List
2
+ from typing_extensions import TypedDict
3
+
4
+ from langchain_core.documents import Document
5
+
6
+ from lib.model_builder import ModelBuilderV2
7
+ from lib.vectorestores import FAISSBuilder
8
+ from lib.model_builder import ModelBuilderV2
9
+ from lib.vectorestores import FAISSBuilder
10
+ from langchain.retrievers.multi_query import MultiQueryRetriever
11
+ from lib.runnables import (
12
+ casual_llm,
13
+ retrieval_grader_3,
14
+ rag_chain,
15
+ message_classificator,
16
+ )
17
+ from langgraph.graph import END, StateGraph
18
+
19
+
20
+ class GraphState(TypedDict):
21
+ """
22
+ Represents the state of our graph.
23
+
24
+ Attributes:
25
+ question: question
26
+ generation: LLM generation
27
+ web_search: whether to add search
28
+ documents: list of documents
29
+ """
30
+
31
+ question: str
32
+ generation: str
33
+ documents: List[Document]
34
+ filtered_documents: List[Document]
35
+ is_fuse: bool
36
+ count_regenerations: int
37
+
38
+
39
+ def combine_vectors(vectors):
40
+ result = []
41
+ vec1_count = len(vectors["vector1"])
42
+ # vec2_count = len(vectors["vector2"])
43
+ for i in range(vec1_count):
44
+ if i < vec1_count:
45
+ result.append(vectors["vector1"][i])
46
+ # if i < vec2_count:
47
+ # result.append(vectors['vector2'][i])
48
+ return result
49
+
50
+
51
+ def retrieve(state):
52
+ """
53
+ Retrieve documents
54
+
55
+ Args:
56
+ state (dict): The current graph state
57
+
58
+ Returns:
59
+ state (dict): New key added to state, documents, that contains retrieved documents
60
+ """
61
+ print("---RETRIEVE---")
62
+ question = state["question"]
63
+
64
+ # Retrieval
65
+ with FAISSBuilder() as faiss_retriever:
66
+ with ModelBuilderV2("openchat/openchat-7b") as mq_llm:
67
+ retriever = MultiQueryRetriever.from_llm(
68
+ retriever=faiss_retriever, llm=mq_llm
69
+ )
70
+ documents = retriever.get_relevant_documents(question)
71
+ return {"documents": documents, "question": question}
72
+
73
+
74
+ def start_point(state):
75
+ """
76
+ Start point, just return state
77
+
78
+ Args:
79
+ state (dict): The current graph state
80
+
81
+ Returns:
82
+ state (dict): The current graph state
83
+ """
84
+ return state
85
+
86
+
87
+ def casual_chat(state):
88
+ """
89
+ Generate a casual answer to the user message
90
+
91
+ Args:
92
+ state (dict): The current graph state
93
+
94
+ Returns:
95
+ state (dict): New key added to state, generation, that contains message with casual answer
96
+ """
97
+ question = state["question"]
98
+ print("---CASUAL CHAT---")
99
+ generation = casual_llm.invoke(
100
+ {"question": question},
101
+ config={
102
+ "configurable": {
103
+ "conversation_id": "default_session",
104
+ "user_id": "deafault_user",
105
+ }
106
+ },
107
+ )
108
+ state["generation"] = generation
109
+
110
+ return state
111
+
112
+
113
+ def define_message_type(state):
114
+ """
115
+ Define type of message
116
+
117
+ Args:
118
+ state (dict): The current graph state
119
+
120
+ Returns:
121
+ str: Name of the next node to route to ("retrieve", "casual_chat" or "__end__")
122
+ """
123
+ print("---MESSAGE CLASSIFICATION---")
124
+ question = state["question"]
125
+ msg_type_obj = message_classificator.invoke({"question": question})
126
+ print(
127
+ f"---MESSAGE TYPE: {msg_type_obj['message_type']} SYSTEM MESSAGE---\n {msg_type_obj['system_message']}"
128
+ )
129
+ msg_type = msg_type_obj["message_type"]
130
+
131
+ if msg_type == "TAX":
132
+ return "retrieve"
133
+ # if msg_type == "":
134
+ return "casual_chat"
135
+ return "__end__"
136
+
137
+
138
+ def generate(state):
139
+ """
140
+ Generate answer
141
+
142
+ Args:
143
+ state (dict): The current graph state
144
+
145
+ Returns:
146
+ state (dict): New key added to state, generation, that contains LLM generation, based on documents
147
+ """
148
+ print("---GENERATE---")
149
+ question = state["question"]
150
+ documents = state["documents"]
151
+
152
+ # RAG generation
153
+ generation = rag_chain.invoke(
154
+ {"context": documents, "question": question},
155
+ config={
156
+ "configurable": {
157
+ "conversation_id": "default_session",
158
+ "user_id": "deafault_user",
159
+ }
160
+ },
161
+ )
162
+ return {"documents": documents, "question": question, "generation": generation}
163
+
164
+
165
+ def grade_documents(state):
166
+ """
167
+ Determines whether the retrieved documents are relevant to the question.
168
+
169
+ Args:
170
+ state (dict): The current graph state
171
+
172
+ Returns:
173
+ state (dict): Updates documents key with only filtered relevant documents
174
+ """
175
+
176
+ print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
177
+ question = state["question"]
178
+ documents = state["documents"]
179
+
180
+ # Score each doc
181
+ filtered_docs = []
182
+ count_docs = len(documents)
183
+ for ind_d in range(0, count_docs, 3):
184
+ d_1 = documents[ind_d] if ind_d < count_docs else None
185
+ d_2 = documents[ind_d + 1] if ind_d + 1 < count_docs else None
186
+ d_3 = documents[ind_d + 2] if ind_d + 2 < count_docs else None
187
+ scores = retrieval_grader_3.invoke(
188
+ {
189
+ "question": question,
190
+ "document_1": d_1,
191
+ "document_2": d_2,
192
+ "document_3": d_3,
193
+ }
194
+ )
195
+ for j in range(len(scores)):
196
+ grade = scores[j]["score"]
197
+ if grade > 0.7:
198
+ print(f"---GRADE: DOCUMENT RELEVANT--- GRADE: {grade}")
199
+ filtered_docs.append(documents[ind_d + j])
200
+ else:
201
+ print("---GRADE: DOCUMENT NOT RELEVANT---")
202
+ is_fuse = len(filtered_docs) / len(documents) <= 0.5
203
+
204
+ return {"documents": filtered_docs, "question": question}
205
+
206
+
207
+ def make_collapsable_source_message(doc: Document):
208
+ file_path = doc.metadata.get("file_name", "")
209
+ file_name = file_path.replace(".pdf", "")
210
+ chapter_title = doc.metadata.get("chapter_title", None)
211
+ page_num = doc.metadata.get("first_page_num", None)
212
+ title = f"""
213
+ {file_name}
214
+ {f": {chapter_title} " if chapter_title is not None else ""}
215
+ {f"Стр. {page_num} " if page_num is not None else ""}
216
+ """.replace(
217
+ "\n", " "
218
+ )
219
+ content = doc.page_content.replace("\n\n", "\n")
220
+
221
+ if page_num is None:
222
+ message = rf"""
223
+ <details>
224
+ <summary>{title}</summary>
225
+ {str(content)}
226
+ </details>
227
+ """
228
+ else:
229
+ base_url = "http://localhost:5000/sta"
230
+ url = f"{base_url}?file={file_path}&#page={page_num}&zoom=90&toolbar=0"
231
+ message = f"""
232
+ <a class="open_pdf" href='{url}' onclick="return openPdf('{url}')">{title}</a>
233
+ """
234
+
235
+ return message
236
+
237
+
238
+ def add_sources(state):
239
+ """
240
+ Appends collapsible source references to the generated answer.
241
+
242
+ Args:
243
+ state (dict): The current graph state
244
+
245
+ Returns:
246
+ state (dict): Generation extended with collapsible source links
247
+ """
248
+ question = state["question"]
249
+ documents = state["documents"]
250
+ generation = state["generation"]
251
+
252
+ sources_message = "<i></i>".join(map(make_collapsable_source_message, documents))
253
+ extended_generation_message = f"{generation} {sources_message}"
254
+ return {
255
+ "documents": documents,
256
+ "question": question,
257
+ "generation": extended_generation_message,
258
+ }
259
+
260
+
261
+ def build_workflow():
262
+ workflow = StateGraph(GraphState)
263
+
264
+ # Define the nodes
265
+ workflow.add_node("start_point", start_point)
266
+ workflow.add_node("retrieve", retrieve) # retrieve
267
+ workflow.add_node("grade_documents", grade_documents) # grade documents
268
+ workflow.add_node("generate", generate) # generate
269
+ workflow.add_node("casual_chat", casual_chat) # simple chat
270
+ workflow.add_node("add_sources", add_sources)
271
+ # Build graph
272
+ workflow.set_entry_point("start_point")
273
+ workflow.add_conditional_edges("start_point", define_message_type)
274
+ workflow.add_edge("retrieve", "grade_documents")
275
+ workflow.add_edge("grade_documents", "generate")
276
+ workflow.add_edge("generate", "add_sources")
277
+ workflow.add_edge("add_sources", END)
278
+ workflow.add_edge("casual_chat", END)
279
+ return workflow.compile()
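For orientation, a minimal sketch of how the compiled graph is driven, mirroring deploy() and bot_rag() in app.py (the question string is only an example):

    from lib.graph import build_workflow

    # Compile the LangGraph state machine once at startup.
    app = build_workflow()

    # A question is classified, optionally routed through retrieval and grading,
    # and the final answer lands in the "generation" key of the result state.
    result = app.invoke({"question": "Как уплатить налог на имущество?"})
    print(result["generation"].strip())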
lib/model_builder.py CHANGED
@@ -3,16 +3,24 @@ from langchain_openai import ChatOpenAI
3
  from dotenv import load_dotenv
4
 
5
  load_dotenv()
6
- VSEGPT_KEY = os.getenv('VSEGPT_KEY')
7
- OPENAI_BASE = os.getenv('OPENAI_BASE')
 
8
 
9
  class ModelBuilder:
10
  def createVseGptModel(model, temperature):
11
- return ChatOpenAI(temperature=temperature, model_name=model, \
12
- api_key=VSEGPT_KEY, base_url = OPENAI_BASE)
13
 
14
  class ModelBuilderV2:
15
- def __init__(self, model_name: str, temperature=0, api_key=VSEGPT_KEY, base_url=OPENAI_BASE):
 
 
16
  self.model_name = model_name
17
  self.temperature = temperature
18
  self.api_key = api_key
@@ -20,9 +28,13 @@ class ModelBuilderV2:
20
 
21
  def __enter__(self):
22
  # Initialize resources
23
- return ChatOpenAI(temperature=self.temperature, model_name=self.model_name, \
24
- api_key=VSEGPT_KEY, base_url = OPENAI_BASE)
 
 
 
 
25
 
26
  def __exit__(self, exc_type, exc_value, traceback):
27
  # Cleanup resources if necessary
28
- pass
 
3
  from dotenv import load_dotenv
4
 
5
  load_dotenv()
6
+ VSEGPT_KEY = os.getenv("VSEGPT_KEY")
7
+ OPENAI_BASE = os.getenv("OPENAI_BASE")
8
+
9
 
10
  class ModelBuilder:
11
  def createVseGptModel(model, temperature):
12
+ return ChatOpenAI(
13
+ temperature=temperature,
14
+ model_name=model,
15
+ api_key=VSEGPT_KEY,
16
+ base_url=OPENAI_BASE,
17
+ )
18
+
19
 
20
  class ModelBuilderV2:
21
+ def __init__(
22
+ self, model_name: str, temperature=0, api_key=VSEGPT_KEY, base_url=OPENAI_BASE
23
+ ):
24
  self.model_name = model_name
25
  self.temperature = temperature
26
  self.api_key = api_key
 
28
 
29
  def __enter__(self):
30
  # Initialize resources
31
+ return ChatOpenAI(
32
+ temperature=self.temperature,
33
+ model_name=self.model_name,
34
+ api_key=self.api_key,
35
+ base_url=self.base_url,
36
+ )
37
 
38
  def __exit__(self, exc_type, exc_value, traceback):
39
  # Cleanup resources if necessary
40
+ pass
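ModelBuilderV2 is used as a context manager throughout lib/runnables.py and lib/graph.py; a minimal sketch with a model name and temperature copied from this commit (the prompt text is illustrative):

    from lib.model_builder import ModelBuilderV2

    # Yields a ChatOpenAI client configured from VSEGPT_KEY / OPENAI_BASE in .env.
    with ModelBuilderV2("openchat/openchat-7b", 0.7) as llm:
        reply = llm.invoke("Поздоровайся одним предложением.")  # illustrative prompt
        print(reply.content)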
lib/prompts.py ADDED
@@ -0,0 +1,61 @@
1
+ from langchain_core.prompts import PromptTemplate
2
+
3
+
4
+ casual_prompt = PromptTemplate(
5
+ template="""Just answer a question as casual chatter. \n
6
+ Here is the user question: {question} \n
7
+ Always reply in Russian. Give a clear answer, without any additions.
8
+ Chat history:
9
+ {chat_history}
10
+ Answer:
11
+ """,
12
+ input_variables=["question", "chat_history"],
13
+ )
14
+
15
+ grader_3_doc_prompt = PromptTemplate(
16
+ template="""You are a grader assessing relevance of a retrieved documents to a user question. \n
17
+ Here is the first retrieved document: \n\n {document_1} \n\n
18
+ Here is the second retrieved document: \n\n {document_2} \n\n
19
+ Here is the third retrieved document: \n\n {document_3} \n\n
20
+ Here is the user question: {question} \n
21
+ If the document contains keywords related to the user question, grade it as relevant. \n
22
+ It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
23
+ For each(!) document give a score from 0 to 1 indicating whether the document is relevant to the question. \n
24
+ Provide the scores as a JSON list of objects, each with the single key 'score', and no preamble or explanation.""",
25
+ input_variables=["question", "document_1", "document_2", "document_3"],
26
+ )
27
+
28
+ rag_assistant_prompt = PromptTemplate(
29
+ template="""
30
+ SYSTEM: You are an assistant for question-answering tasks.
31
+ Use the following pieces of retrieved context to answer the question.
32
+ Use previous messages as well as the current message when highly relevant.
33
+ If you don't find the answer in the context, transform the question and ask the user to specify their question.
34
+
35
+ Keep the answer concise.
36
+ Print a most possible topic of conversation.
37
+ Always reply in Russian, all text must be in Russian!
38
+
39
+ Context: {context}
40
+
41
+ Previous messages: {chat_history}
42
+
43
+ Question: {question}
44
+
45
+ Answer:
46
+
47
+ """,
48
+ input_variables=["context", "chat_history", "question"],
49
+ )
50
+
51
+ classificator_question_prompt = PromptTemplate(
52
+ template="""
53
+ SYSTEM: You are a message classifier. You classify messages into classes:
54
+ - TAX: tax payment or a similar juridical topic;
55
+ - CASUAL: casual conversation, any generic question that can't fit in the conversation;
56
+ - SPAM: any rude, useless messages.
57
+ Here is a question: {question}
58
+ Provide the answer as a JSON object with the key 'message_type' and the key 'system_message' containing a short remark about the message, and no other preamble or explanation.
59
+ """,
60
+ input_variables=["question"],
61
+ )
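These are plain PromptTemplates, so they can also be rendered outside a chain; a small sketch using the classifier template (the question is only an example):

    from lib.prompts import classificator_question_prompt

    # Fills the {question} slot and returns the final prompt string sent to the model.
    text = classificator_question_prompt.format(question="Как оплатить транспортный налог?")
    print(text)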
lib/runnables.py ADDED
@@ -0,0 +1,114 @@
1
+ import contextlib
2
+
3
+ from lib.model_builder import ModelBuilderV2
4
+ from lib.prompts import (
5
+ casual_prompt,
6
+ grader_3_doc_prompt,
7
+ rag_assistant_prompt,
8
+ classificator_question_prompt,
9
+ )
10
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
11
+ from langchain_core.runnables.history import RunnableWithMessageHistory
12
+ from langchain_core.chat_history import (
13
+ BaseChatMessageHistory,
14
+ InMemoryChatMessageHistory,
15
+ )
16
+ from langchain_core.runnables import ConfigurableFieldSpec
17
+
18
+ store = {}
19
+
20
+
21
+ def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:
22
+ if (user_id, conversation_id) not in store:
23
+ store[(user_id, conversation_id)] = InMemoryChatMessageHistory()
24
+ return store[(user_id, conversation_id)]
25
+
26
+
27
+ class ModelConfig:
28
+ def __init__(self, model_name, temperature=0.7):
29
+ self.model_name = model_name
30
+ self.temperature = temperature
31
+
32
+
33
+ class ConfigField:
34
+ def __init__(self, id, annotation, name, description, default, is_shared):
35
+ self.id = id
36
+ self.annotation = annotation
37
+ self.name = name
38
+ self.description = description
39
+ self.default = default
40
+ self.is_shared = is_shared
41
+
42
+
43
+ USER_ID_FIELD = ConfigurableFieldSpec(
44
+ id="user_id",
45
+ annotation=str,
46
+ name="User ID",
47
+ description="Unique identifier for the user.",
48
+ default="default_user",
49
+ is_shared=True,
50
+ )
51
+
52
+ CONVERSATION_ID_FIELD = ConfigurableFieldSpec(
53
+ id="conversation_id",
54
+ annotation=str,
55
+ name="Conversation ID",
56
+ description="Unique identifier for the conversation.",
57
+ default="default_session",
58
+ is_shared=True,
59
+ )
60
+
61
+
62
+ def create_runnable_with_history(
63
+ prompt, llm, input_messages_key, history_messages_key, history_factory_config
64
+ ):
65
+ return RunnableWithMessageHistory(
66
+ prompt | llm,
67
+ get_session_history,
68
+ input_messages_key=input_messages_key,
69
+ history_messages_key=history_messages_key,
70
+ history_factory_config=history_factory_config,
71
+ )
72
+
73
+
74
+ @contextlib.contextmanager
75
+ def create_model_builder(config):
76
+ with ModelBuilderV2(config.model_name, config.temperature) as llm:
77
+ yield llm
78
+ # try:
79
+ # yield llm
80
+ # finally:
81
+ # llm.release() # Assuming ModelBuilderV2 has a release method to clear resources
82
+
83
+
84
+ casual_config = ModelConfig("openchat/openchat-7b", 0.7)
85
+ retrieval_config = ModelConfig("cohere/command-r")
86
+ rag_config = ModelConfig("mistralai/mixtral-8x22b-instruct")
87
+ classificator_msg_config = ModelConfig("openchat/openchat-7b")
88
+
89
+ history_config = [USER_ID_FIELD, CONVERSATION_ID_FIELD]
90
+
91
+ with create_model_builder(casual_config) as llm:
92
+ casual_llm = (
93
+ create_runnable_with_history(
94
+ casual_prompt, llm, "question", "chat_history", history_config
95
+ )
96
+ | StrOutputParser()
97
+ )
98
+
99
+ with create_model_builder(retrieval_config) as llm:
100
+ retrieval_grader_3 = grader_3_doc_prompt | llm | JsonOutputParser()
101
+
102
+ with create_model_builder(rag_config) as llm:
103
+ rag_chain = (
104
+ create_runnable_with_history(
105
+ rag_assistant_prompt, llm, "question", "chat_history", history_config
106
+ )
107
+ | StrOutputParser()
108
+ )
109
+
110
+
111
+ with create_model_builder(classificator_msg_config) as decider_llm:
112
+ message_classificator = (
113
+ classificator_question_prompt | decider_llm | JsonOutputParser()
114
+ )
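The history-aware chains defined here expect the two configurable fields declared above at invocation time; a minimal sketch mirroring casual_chat() in lib/graph.py (the IDs are just the defaults, the question is illustrative):

    from lib.runnables import casual_llm

    answer = casual_llm.invoke(
        {"question": "Привет! Чем занимаешься?"},  # illustrative question
        config={
            "configurable": {
                "user_id": "default_user",              # matches USER_ID_FIELD default
                "conversation_id": "default_session",   # matches CONVERSATION_ID_FIELD default
            }
        },
    )
    print(answer)  # StrOutputParser already returns plain text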
lib/vectorestores.py ADDED
@@ -0,0 +1,40 @@
1
+ from langchain_pinecone import PineconeVectorStore
2
+ from lib.embedding import EmbeddingBuilder
3
+ from langchain_community.vectorstores import FAISS
4
+
5
+
6
+ class FAISSBuilder:
7
+ def __init__(
8
+ self,
9
+ embedding_model: str = "intfloat/multilingual-e5-large",
10
+ local_path: str = "data/faiss_nk_28_05",
11
+ ):
12
+ self.embedding_model = embedding_model
13
+ self.local_path = local_path
14
+
15
+ def __enter__(self):
16
+ # Initialize resources
17
+ with EmbeddingBuilder(self.embedding_model) as rag_emb:
18
+ faiss_db = FAISS.load_local(
19
+ self.local_path, rag_emb, allow_dangerous_deserialization=True
20
+ )
21
+ return faiss_db.as_retriever()
22
+
23
+ def __exit__(self, exc_type, exc_value, traceback):
24
+ # Cleanup resources if necessary
25
+ pass
26
+
27
+
28
+ class PineConeBuilder:
29
+ def __init__(self, index_name, embedding_model):
30
+ self.index_name = index_name
31
+ self.embedding_model = embedding_model
32
+
33
+ def __enter__(self):
34
+ with EmbeddingBuilder(self.embedding_model) as embeddings:
35
+ pc_db = PineconeVectorStore.from_existing_index(self.index_name, embeddings)
36
+ return pc_db.as_retriever()
37
+
38
+ def __exit__(self, exc_type, exc_value, traceback):
39
+ # Cleanup resources if necessary
40
+ pass
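Both builders yield a ready retriever when used as context managers; a minimal sketch based on retrieve() in lib/graph.py (the query is only an example):

    from lib.vectorestores import FAISSBuilder

    # Loads data/faiss_nk_28_05 with the multilingual-e5-large embedder (the defaults above).
    with FAISSBuilder() as retriever:
        docs = retriever.get_relevant_documents("налог на имущество организаций")
        for doc in docs:
            print(doc.metadata.get("file_name", ""), "-", doc.page_content[:80])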