miniondenis committed
Commit eb56c9e
1 Parent(s): 9c32f1b

feat: add configurable

Files changed (5):
  1. config.yml +14 -0
  2. lib/config.py +27 -0
  3. lib/graph.py +54 -3
  4. lib/prompts.py +1 -2
  5. lib/runnables.py +10 -8
config.yml ADDED
@@ -0,0 +1,14 @@
+ models:
+   casual_conversation:
+     model: openchat/openchat-7b
+     temperature: 0.7
+   multiquery_retrieval:
+     model: openchat/openchat-7b
+     temperature: 0.3
+   classificator_msg:
+     model: openchat/openchat-7b
+     temperature: 0
+   rag:
+     model: cohere/command-r
+     temperature: 0
+
lib/config.py ADDED
@@ -0,0 +1,27 @@
+ from pathlib import Path
+
+ import yaml
+
+
+ class Config:
+     def __init__(self, config_path: Path):
+         self.config_path = config_path
+         self._config = self._load_config()
+
+     def _load_config(self) -> dict:
+         with open(self.config_path, "r") as file:
+             return yaml.safe_load(file)
+
+     def get(self, *keys, default=None):
+         # Walk the nested dict one key at a time; stop early on a missing key
+         config = self._config
+         for key in keys:
+             config = config.get(key, default)
+             if config is default:
+                 break
+         return config
+
+     def __getitem__(self, item: str):
+         return self._config.get(item)
+
+     def __repr__(self) -> str:
+         return f"Config({self._config})"
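A quick usage sketch of the accessor above, resolved against the config.yml added in this commit (max_tokens is a hypothetical key shown only to illustrate the default fallback):

    from pathlib import Path

    from lib.config import Config

    config = Config(Path("config.yml"))

    config.get("models", "rag", "model")                    # -> "cohere/command-r"
    config.get("models", "rag", "temperature")              # -> 0
    config.get("models", "rag", "max_tokens", default=512)  # missing key -> 512
    config["models"]                                        # raw nested dict
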
lib/graph.py CHANGED
@@ -15,6 +15,25 @@ from lib.runnables import (
      message_classificator,
  )
  from langgraph.graph import END, StateGraph
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import torch.nn.functional as F
+
+ model_name = "intfloat/multilingual-e5-large"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModel.from_pretrained(model_name)
+ device = torch.device("cuda")
+ model.to(device)
+ SIMILARITY_THRESHOLD = 0.8
+
+
+ def get_embeddings(texts):
+     # Mean-pool the last hidden state into a single vector per input text
+     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
+     inputs.to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     embeddings = torch.mean(outputs.last_hidden_state, dim=1)
+     return embeddings


  class GraphState(TypedDict):
@@ -32,7 +51,6 @@ class GraphState(TypedDict):
      generation: str
      documents: List[Document]
      filtered_documets: List[Document]
-     is_fuse: bool
      count_regenerations: int


@@ -162,6 +180,40 @@ def generate(state):
      return {"documents": documents, "question": question, "generation": generation}


+ def grade_documents_by_embed(state):
+     """
+     Determines whether the retrieved documents are relevant to the question.
+
+     Args:
+         state (dict): The current graph state
+
+     Returns:
+         state (dict): Updates documents key with only filtered relevant documents
+     """
+     question = state["question"]
+     documents = state["documents"]
+
+     # Embed the question and every document, then score by cosine similarity
+     query_embedding = get_embeddings([question])
+     document_embeddings = get_embeddings([doc.page_content for doc in documents])
+     similarity_scores = F.cosine_similarity(query_embedding, document_embeddings)
+
+     # Keep documents above the threshold, ordered by score, capped at the top 5
+     scored_documents = sorted(
+         zip(documents, similarity_scores), key=lambda x: x[1], reverse=True
+     )
+     filtered_docs = [
+         doc for doc, score in scored_documents if score >= SIMILARITY_THRESHOLD
+     ]
+     return {"documents": filtered_docs[:5], "question": question}
+
+
  def grade_documents(state):
      """
      Determines whether the retrieved documents are relevant to the question.
@@ -199,7 +251,6 @@ def grade_documents(state):
              filtered_docs.append(documents[ind_d + j])
          else:
              print("---GRADE: DOCUMENT NOT RELEVANT---")
-     is_fuse = len(filtered_docs) / len(documents) <= 0.5

      return {"documents": filtered_docs, "question": question}

@@ -264,7 +315,7 @@ def build_workflow():
      # Define the nodes
      workflow.add_node("start_point", start_point)
      workflow.add_node("retrieve", retrieve)  # retrieve
-     workflow.add_node("grade_documents", grade_documents)  # grade documents
+     workflow.add_node("grade_documents", grade_documents_by_embed)  # grade documents
      workflow.add_node("generate", generate)  # generate
      workflow.add_node("casual_chat", casual_chat)  # simple chat
      workflow.add_node("add_sources", add_sources)
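A rough sketch of driving the new embedding grader directly; it assumes a CUDA device is available, the e5 model can be downloaded, and that Document is the same class lib/graph.py already imports (the import path below is an assumption):

    from langchain_core.documents import Document  # assumed import path
    from lib.graph import grade_documents_by_embed

    state = {
        "question": "How do I return a purchased item?",
        "documents": [
            Document(page_content="Returns are accepted within 14 days of delivery."),
            Document(page_content="The office is open Monday through Friday."),
        ],
    }
    result = grade_documents_by_embed(state)
    # result["documents"] holds at most 5 documents whose cosine similarity
    # to the question is >= 0.8, ordered from most to least similar
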
lib/prompts.py CHANGED
@@ -29,11 +29,10 @@ rag_assistant_prompt = PromptTemplate(
      template="""
      SYSTEM: You are an assistant for question-answering tasks.
      Use the following pieces of retrieved context to answer the question.
-     Use previous messages then current message higly likely
+     Use the 'Previous messages' as part of the context.
      If you don't find the answer in the context, transform the question and ask the user to specify their question.

      Keep the answer concise.
-     Print a most possible topic of conversation.
      Always reply in Russian, all text must be in Russian!

      Context: {context}
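For reference, a sketch of rendering the updated template (assuming its input variables include context and question; only {context} is visible in this hunk):

    from lib.prompts import rag_assistant_prompt

    rendered = rag_assistant_prompt.format(
        context="Previous messages: ...\nRetrieved passages: ...",
        question="...",
    )
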
lib/runnables.py CHANGED
@@ -1,6 +1,7 @@
  import contextlib

  from lib.model_builder import ModelBuilderV2
+ from lib.config import Config
  from lib.prompts import (
      casual_prompt,
      grader_3_doc_prompt,
@@ -16,6 +17,7 @@ from langchain_core.chat_history import (
  from langchain_core.runnables import ConfigurableFieldSpec

  store = {}
+ config = Config("config.yml")


  def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:
@@ -25,9 +27,9 @@ def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:


  class ModelConfig:
-     def __init__(self, model_name, temperature=0.7):
-         self.model_name = model_name
-         self.temperature = temperature
+     def __init__(self, config_key):
+         self.model_name = config.get("models", config_key, "model")
+         self.temperature = config.get("models", config_key, "temperature")


  class ConfigField:
@@ -81,10 +83,10 @@ def create_model_builder(config):
      # llm.release()  # Assuming ModelBuilderV2 has a release method to clear resources


- casual_config = ModelConfig("openchat/openchat-7b", 0.7)
- retrieval_config = ModelConfig("cohere/command-r")
- rag_config = ModelConfig("mistralai/mixtral-8x22b-instruct")
- classificator_msg_config = ModelConfig("openchat/openchat-7b")
+ casual_config = ModelConfig("casual_conversation")
+ multiquery_config = ModelConfig("multiquery_retrieval")
+ rag_config = ModelConfig("rag")
+ classificator_msg_config = ModelConfig("classificator_msg")

  history_config = [USER_ID_FIELD, CONVERSATION_ID_FIELD]

@@ -96,7 +98,7 @@ with create_model_builder(casual_config) as llm:
      | StrOutputParser()
  )

- with create_model_builder(retrieval_config) as llm:
+ with create_model_builder(multiquery_config) as llm:
      retrieval_grader_3 = grader_3_doc_prompt | llm | JsonOutputParser()

  with create_model_builder(rag_config) as llm:
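
Taken together, wiring up a new model now takes a config.yml entry plus a ModelConfig lookup. A minimal sketch, where the summarizer section and summarize_prompt are hypothetical:

    # config.yml (hypothetical addition):
    # models:
    #   summarizer:
    #     model: openchat/openchat-7b
    #     temperature: 0.2

    from lib.runnables import ModelConfig, create_model_builder

    summarizer_config = ModelConfig("summarizer")  # reads models.summarizer.* from config.yml
    with create_model_builder(summarizer_config) as llm:
        summarize_chain = summarize_prompt | llm  # summarize_prompt: a hypothetical PromptTemplate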