Fixes

Files changed:
- app.py (+1, -1)
- climateqa/engine/chains/graph_retriever.py (+2, -2)
- climateqa/engine/graph.py (+7, -4)
app.py
CHANGED

@@ -92,7 +92,7 @@ reranker = get_reranker("nano")

 # Create vectorstore and retriever
 vectorstore = get_pinecone_vectorstore(embeddings_function)
-vectorstore_graphs = Chroma(persist_directory="/home/tim/ai4s/climate_qa/
+vectorstore_graphs = Chroma(persist_directory="/home/tim/ai4s/climate_qa/climate-question-answering/data/vectorstore_owid", embedding_function=embeddings_function)

 # agent = make_graph_agent(llm,vectorstore,reranker)
 agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
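For context, the new vectorstore_graphs line loads a persisted Chroma collection of OWID graph metadata and hands it to make_graph_agent. A minimal, self-contained sketch of the same loading pattern, assuming the langchain_community Chroma wrapper (plus chromadb and sentence-transformers installed); the embeddings model and the relative path below are illustrative stand-ins for the project's embeddings_function and local directory:

```python
# Sketch: load a persisted Chroma collection the way app.py now does.
# The embeddings object stands in for the project's embeddings_function.
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

embeddings_function = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"  # illustrative model
)

# Reuse an on-disk collection instead of re-embedding the OWID graph metadata.
vectorstore_graphs = Chroma(
    persist_directory="data/vectorstore_owid",  # hypothetical relative path; the diff uses an absolute local path
    embedding_function=embeddings_function,
)

# Quick sanity check: similarity search returns Document objects with metadata.
docs = vectorstore_graphs.similarity_search("global CO2 emissions by sector", k=3)
for d in docs:
    print(d.metadata)
```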
climateqa/engine/chains/graph_retriever.py
CHANGED

@@ -50,7 +50,7 @@ def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_
    print("---- Retrieving graphs ----")

    POSSIBLE_SOURCES = ["IEA", "OWID"]
-   questions = state["
+   questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
    sources_input = state["sources_input"]

    auto_mode = "auto" in sources_input

@@ -90,7 +90,7 @@ def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_
        docs_question = retriever.get_relevant_documents(question)

        # Rerank
-       if reranker is not None:
+       if reranker is not None and docs_question!=[]:
            with suppress_output():
                docs_question = rerank_docs(reranker,docs_question,question)
        else:
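Both graph_retriever.py changes harden the node against empty inputs: the question list now falls back to the original query when remaining_questions is missing or empty, and reranking is skipped when retrieval returns nothing. A minimal sketch of those two guards in isolation; rerank_docs, suppress_output, and the state shape below are stand-ins for the project's own helpers, not its real code:

```python
# Sketch of the two guards added in graph_retriever.py, with stand-in helpers.
from contextlib import contextmanager
from typing import Any, List

@contextmanager
def suppress_output():
    # Stand-in for the project's suppress_output context manager.
    yield

def rerank_docs(reranker: Any, docs: List[Any], question: str) -> List[Any]:
    # Stand-in for the project's reranker call; here it returns docs unchanged.
    return docs

def pick_questions(state: dict) -> List[str]:
    # Fall back to the raw query when decomposition produced no sub-questions.
    remaining = state.get("remaining_questions")
    return remaining if remaining else [state["query"]]

def rerank_if_possible(reranker: Any, docs: List[Any], question: str) -> List[Any]:
    # Only rerank when a reranker is configured and retrieval returned documents.
    if reranker is not None and docs:
        with suppress_output():
            docs = rerank_docs(reranker, docs, question)
    return docs

# Example: an empty remaining_questions list falls back to the query,
# and an empty result list bypasses the reranker entirely.
state = {"remaining_questions": [], "query": "How have global temperatures changed?"}
print(pick_questions(state))             # ['How have global temperatures changed?']
print(rerank_if_possible(None, [], ""))  # []
```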
climateqa/engine/graph.py
CHANGED

@@ -18,6 +18,9 @@ from .chains.translation import make_translation_node
 from .chains.intent_categorization import make_intent_categorization_node
 from .chains.retrieve_documents import make_retriever_node
 from .chains.answer_rag import make_rag_node
+from .chains.graph_retriever import make_graph_retriever_node
+from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
+from .chains.set_defaults import set_defaults

 class GraphState(TypedDict):
     """

@@ -89,8 +92,8 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     transform_query = make_query_transform_node(llm)
     translate_query = make_translation_node(llm)
     answer_chitchat = make_chitchat_node(llm)
-    answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker)
+    # answer_ai_impact = make_ai_impact_node(llm)
+    retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker, llm)
     retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
     # answer_rag_graph = make_rag_graph_node(llm)
     answer_rag = make_rag_node(llm, with_docs=True)

@@ -108,7 +111,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     # workflow.add_node("translate_query_ai", translate_query)
     workflow.add_node("answer_chitchat", answer_chitchat)
     workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
-    workflow.add_node("answer_ai_impact", answer_ai_impact)
+    # workflow.add_node("answer_ai_impact", answer_ai_impact)
     workflow.add_node("retrieve_graphs", retrieve_graphs)
     workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
     # workflow.add_node("retrieve_graphs_ai", retrieve_graphs)

@@ -162,7 +165,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, thresh
     workflow.add_edge("answer_rag_no_docs", END)
     workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
     # workflow.add_edge("answer_chitchat", END)
-    workflow.add_edge("answer_ai_impact", END)
+    # workflow.add_edge("answer_ai_impact", END)
     workflow.add_edge("retrieve_graphs_chitchat", END)
     # workflow.add_edge("answer_ai_impact", "translate_query_ai")
     # workflow.add_edge("translate_query_ai", "transform_query_ai")
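The graph.py changes retire the answer_ai_impact branch (node and edge are commented out) and pass the llm into the document-retriever node factory; the surrounding code builds a LangGraph workflow out of such node functions. A minimal, self-contained sketch of that add_node/add_edge wiring pattern, assuming the public langgraph StateGraph API; the state schema and node functions below are toy examples, not the project's real GraphState or chains:

```python
# Sketch of the LangGraph wiring pattern used in make_graph_agent,
# reduced to a toy two-node workflow with an illustrative state schema.
from typing import TypedDict
from langgraph.graph import StateGraph, END

class ToyState(TypedDict):
    query: str
    answer: str

def retrieve(state: ToyState) -> dict:
    # A real node would query a vectorstore; here we just annotate the state.
    return {"answer": f"docs for: {state['query']}"}

def answer(state: ToyState) -> dict:
    return {"answer": state["answer"] + " -> final answer"}

workflow = StateGraph(ToyState)
workflow.add_node("retrieve", retrieve)   # mirrors workflow.add_node("retrieve_documents", ...)
workflow.add_node("answer", answer)       # mirrors workflow.add_node("answer_rag", ...)
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "answer")
workflow.add_edge("answer", END)

app = workflow.compile()
print(app.invoke({"query": "sea level rise", "answer": ""}))
```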