# NOTE(review): removed non-Python artifacts ("Spaces:" / "Running") that
# appear to be scraped web-UI text, not part of this module.
import logging
from functools import partial
from typing import List, Optional, Callable, Any

from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph.state import StateGraph
from langgraph.constants import START, END

from ask_candid.tools.recommendation import (
    detect_intent_with_llm,
    determine_context,
    make_recommendation
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.tools.org_seach import has_org_name, insert_org_link
from ask_candid.tools.search import search_agent, retriever_tool
from ask_candid.agents.schema import AgentState
from ask_candid.base.config.data import DataIndices
from ask_candid.utils import html_format_docs_chat
# Module-level logger configuration.
_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def generate_with_context(
    state: AgentState,
    llm: LLM,
    user_callback: Optional[Callable[[str], Any]] = None
) -> AgentState:
    """Generate the final answer from previously retrieved context.

    Parameters
    ----------
    state : AgentState
        The current state. The last message is expected to carry the retrieved
        sources: formatted text in ``.content`` and the raw documents in
        ``.artifact``.
    llm : LLM
        Chat model used to produce the answer
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    AgentState
        The updated state with the agent response appended to messages
    """
    logger.info("---GENERATE ANSWER---")
    if user_callback is not None:
        try:
            user_callback("Writing a response...")
        except Exception as ex:  # best-effort UI notification; never abort generation
            logger.warning("User callback was passed in but failed: %s", ex)

    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]
    sources_str = last_message.content
    # NOTE(review): assumes the last message is a tool result exposing
    # `.artifact` with the retrieved documents — confirm against the
    # retrieval node's output type.
    sources_list = last_message.artifact
    sources_html = html_format_docs_chat(sources_list)
    if sources_list:
        logger.info("---ADD SOURCES---")
        state["messages"].append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    qa_system_prompt = """
    You are an assistant for question-answering tasks in the social and philanthropic sector. \n
    Use the following pieces of retrieved context to answer the question at the end. \n
    If you don't know the answer, just say that you don't know. \n
    Keep the response professional, friendly, and as concise as possible. \n
    Question: {question}
    Context: {context}
    Answer:
    """
    # BUGFIX: the human turn must be the "{question}" placeholder, not the raw
    # user text. ChatPromptTemplate treats message strings as templates, so a
    # literal question containing '{' or '}' would be parsed as template
    # variables (KeyError / prompt injection). The actual value is supplied
    # via `.invoke` below, which the original code already did.
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        ("human", "{question}"),
    ])
    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}
def add_recommendations_pipeline_(
    G: StateGraph,
    llm: LLM,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Adds execution sub-graph for recommendation engine flow. Graph changes are in-place.

    Parameters
    ----------
    G : StateGraph
        Execution graph
    llm : LLM
        Chat model used by the intent-detection node
    reformulation_node_name : str, optional
        Name of the node which reforumates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """

    def _route_on_intent(state) -> str:
        # Recommendation intents continue down the recommendation branch;
        # anything else falls back to the regular search flow.
        if state["intent"] in ("rfp", "funder"):
            return "determine_context"
        return search_node_name

    # Nodes for recommendation functionalities
    G.add_node(node="detect_intent_with_llm", action=partial(detect_intent_with_llm, llm=llm))
    G.add_node(node="determine_context", action=determine_context)
    G.add_node(node="make_recommendation", action=make_recommendation)

    # Check for recommendation query first
    # Execute until reaching END if user asks for recommendation
    G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
    G.add_conditional_edges(
        source="detect_intent_with_llm",
        path=_route_on_intent,
        path_map={
            "determine_context": "determine_context",
            search_node_name: search_node_name,
        },
    )
    G.add_edge(start_key="determine_context", end_key="make_recommendation")
    G.add_edge(start_key="make_recommendation", end_key=END)
def build_compute_graph(
    llm: LLM,
    indices: List[DataIndices],
    enable_recommendations: bool = False,
    user_callback: Optional[Callable[[str], Any]] = None
) -> StateGraph:
    """Execution graph builder, the output is the execution flow for an interaction with the assistant.

    Parameters
    ----------
    llm : LLM
        Chat model driving every LLM-backed node
    indices : List[DataIndices]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    StateGraph
        Execution graph
    """
    candid_retriever_tool = retriever_tool(indices=indices, user_callback=user_callback)

    graph = StateGraph(AgentState)

    # LLM-backed and tool nodes
    graph.add_node(
        node="reformulate",
        action=partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
    )
    graph.add_node(node="search_agent", action=partial(search_agent, llm=llm, tools=[candid_retriever_tool]))
    graph.add_node(node="retrieve", action=ToolNode([candid_retriever_tool]))
    graph.add_node(
        node="generate_with_context",
        action=partial(generate_with_context, llm=llm, user_callback=user_callback)
    )
    graph.add_node(node="has_org_name", action=partial(has_org_name, llm=llm, user_callback=user_callback))
    graph.add_node(node="insert_org_link", action=insert_org_link)

    graph.add_edge(start_key=START, end_key="reformulate")
    if enable_recommendations:
        # Splices the intent-detection sub-graph between reformulation and search.
        add_recommendations_pipeline_(
            graph, llm=llm,
            reformulation_node_name="reformulate",
            search_node_name="search_agent"
        )
    else:
        graph.add_edge(start_key="reformulate", end_key="search_agent")

    graph.add_conditional_edges(
        source="search_agent",
        path=tools_condition,
        path_map={
            "tools": "retrieve",
            END: "has_org_name",
        },
    )
    graph.add_edge(start_key="retrieve", end_key="generate_with_context")
    graph.add_edge(start_key="generate_with_context", end_key="has_org_name")
    graph.add_conditional_edges(
        source="has_org_name",
        path=lambda state: state["next"],  # `has_org_name` stores its routing decision under "next"
        path_map={
            "insert_org_link": "insert_org_link",
            END: END,
        },
    )
    graph.add_edge(start_key="insert_org_link", end_key=END)
    return graph