Commit eb56c9e • miniondenis committed
1 Parent(s): 9c32f1b

feat: add configurable

Files changed:
- config.yml +14 -0
- lib/config.py +27 -0
- lib/graph.py +54 -3
- lib/prompts.py +1 -2
- lib/runnables.py +10 -8
config.yml
ADDED
@@ -0,0 +1,14 @@
+models:
+  casual_conversation:
+    model: openchat/openchat-7b
+    temperature: 0.7
+  multiquery_retrieval:
+    model: openchat/openchat-7b
+    temperature: 0.3
+  classificator_msg:
+    model: openchat/openchat-7b
+    temperature: 0
+  rag:
+    model: cohere/command-r
+    temperature: 0
+
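Each top-level entry under models: maps a pipeline stage to a model and a sampling temperature. For reference, a quick sketch of the parsed shape with plain PyYAML (independent of the Config helper added below):

import yaml

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

# cfg["models"] is a dict of stage -> {"model": str, "temperature": float}
for stage, params in cfg["models"].items():
    print(stage, params["model"], params["temperature"])
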
lib/config.py
ADDED
@@ -0,0 +1,27 @@
+from pathlib import Path
+
+import yaml
+
+
+class Config:
+    def __init__(self, config_path: Path):
+        self.config_path = config_path
+        self._config = self._load_config()
+
+    def _load_config(self) -> dict:
+        with open(self.config_path, "r") as file:
+            return yaml.safe_load(file)
+
+    def get(self, *keys, default=None):
+        config = self._config
+        for key in keys:
+            config = config.get(key, default)
+            if config is default:
+                break
+        return config
+
+    def __getitem__(self, item: str):
+        return self._config.get(item)
+
+    def __repr__(self) -> str:
+        return f"Config({self._config})"
lib/graph.py
CHANGED
@@ -15,6 +15,25 @@ from lib.runnables import (
     message_classificator,
 )
 from langgraph.graph import END, StateGraph
+from transformers import AutoModel, AutoTokenizer
+import torch
+import torch.nn.functional as F
+
+model_name = "intfloat/multilingual-e5-large"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModel.from_pretrained(model_name)
+device = torch.device("cuda")
+model.to(device)
+SIMILARITY_TRESHHOLD = 0.8
+
+
+def get_embeddings(texts):
+    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
+    inputs.to(device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    embeddings = torch.mean(outputs.last_hidden_state, dim=1)
+    return embeddings
 
 
 class GraphState(TypedDict):
@@ -32,7 +51,6 @@ class GraphState(TypedDict):
     generation: str
     documents: List[Document]
     filtered_documets: List[Document]
-    is_fuse: bool
     count_regenerations: int
 
 
@@ -162,6 +180,40 @@ def generate(state):
     return {"documents": documents, "question": question, "generation": generation}
 
 
+def grade_documents_by_embed(state):
+    """
+    Determines whether the retrieved documents are relevant to the question.
+
+    Args:
+        state (dict): The current graph state
+
+    Returns:
+        state (dict): Updates documents key with only filtered relevant documents
+    """
+    question = state["question"]
+    documents = state["documents"]
+
+    # Score each doc
+    filtered_docs = []
+
+    query_embedding = get_embeddings([question])
+    document_embeddings = get_embeddings([doc.page_content for doc in documents])
+
+    # Calculate cosine similarity
+    similarity_scores = F.cosine_similarity(query_embedding, document_embeddings)
+    for doc, score in zip(documents, similarity_scores):
+        if score >= SIMILARITY_TRESHHOLD:
+            filtered_docs.append(doc)
+    sorted_documents = [
+        doc[0]
+        for doc in sorted(
+            zip(documents, similarity_scores), key=lambda x: x[1], reverse=True
+        )
+    ]
+    cut_off_documents = sorted_documents[:5]
+    return {"documents": cut_off_documents, "question": question}
+
+
 def grade_documents(state):
     """
     Determines whether the retrieved documents are relevant to the question.
@@ -199,7 +251,6 @@ def grade_documents(state):
                     filtered_docs.append(documents[ind_d + j])
                 else:
                     print("---GRADE: DOCUMENT NOT RELEVANT---")
-    is_fuse = len(filtered_docs) / len(documents) <= 0.5
 
     return {"documents": filtered_docs, "question": question}
 
@@ -264,7 +315,7 @@ def build_workflow():
     # Define the nodes
     workflow.add_node("start_point", start_point)
     workflow.add_node("retrieve", retrieve)  # retrieve
-    workflow.add_node("grade_documents", grade_documents)  # grade documents
+    workflow.add_node("grade_documents", grade_documents_by_embed)  # grade documents
     workflow.add_node("generate", generate)  # generate
     workflow.add_node("casual_chat", casual_chat)  # simple chat
     workflow.add_node("add_sources", add_sources)
lib/prompts.py
CHANGED
@@ -29,11 +29,10 @@ rag_assistant_prompt = PromptTemplate(
     template="""
 SYSTEM: You are an assistant for question-answering tasks.
 Use the following pieces of retrieved context to answer the question.
-Use
+Use a 'Previous messages' as a part of context.
 If you don't find the answer in the context, transform the question ans ask the user to specify his qusetion.
 
 Keep the answer concise.
-Print a most possible topic of conversation.
 Always reply in Russian, all text must be in Russian!
 
 Context: {context}
lib/runnables.py
CHANGED
@@ -1,6 +1,7 @@
 import contextlib
 
 from lib.model_builder import ModelBuilderV2
+from lib.config import Config
 from lib.prompts import (
     casual_prompt,
     grader_3_doc_prompt,
@@ -16,6 +17,7 @@ from langchain_core.chat_history import (
 from langchain_core.runnables import ConfigurableFieldSpec
 
 store = {}
+config = Config("config.yml")
 
 
 def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:
@@ -25,9 +27,9 @@ def get_session_history(user_id: str, conversation_id: str) -> BaseChatMessageHistory:
 
 
 class ModelConfig:
-    def __init__(self, model_name, temperature):
-        self.model_name = model_name
-        self.temperature = temperature
+    def __init__(self, config_key):
+        self.model_name = config.get("models", config_key, "model")
+        self.temperature = config.get("models", config_key, "temperature")
 
 
 class ConfigField:
@@ -81,10 +83,10 @@ def create_model_builder(config):
     # llm.release()  # Assuming ModelBuilderV2 has a release method to clear resources
 
 
-casual_config = ModelConfig("openchat/openchat-7b", 0.7)
-
-rag_config = ModelConfig("cohere/command-r", 0)
-classificator_msg_config = ModelConfig("openchat/openchat-7b", 0)
+casual_config = ModelConfig("casual_conversation")
+multiquery_config = ModelConfig("multiquery_retrieval")
+rag_config = ModelConfig("rag")
+classificator_msg_config = ModelConfig("classificator_msg")
 
 history_config = [USER_ID_FIELD, CONVERSATION_ID_FIELD]
 
@@ -96,7 +98,7 @@ with create_model_builder(casual_config) as llm:
     | StrOutputParser()
 )
 
-with create_model_builder(
+with create_model_builder(multiquery_config) as llm:
     retrieval_grader_3 = grader_3_doc_prompt | llm | JsonOutputParser()
 
 with create_model_builder(rag_config) as llm: