brainsqueeze's picture
UI callbacks and style changes
cc80c3d verified
from typing import List, Optional, Callable, Any
from functools import partial
import logging
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.language_models.llms import LLM
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph.state import StateGraph
from langgraph.constants import START, END
from ask_candid.tools.recommendation import (
detect_intent_with_llm,
determine_context,
make_recommendation
)
from ask_candid.tools.question_reformulation import reformulate_question_using_history
from ask_candid.tools.org_seach import has_org_name, insert_org_link
from ask_candid.tools.search import search_agent, retriever_tool
from ask_candid.agents.schema import AgentState
from ask_candid.base.config.data import DataIndices
from ask_candid.utils import html_format_docs_chat
# Log record layout used for the root logger: "[LEVEL] (timestamp) :: message".
_LOG_FORMAT = "[%(levelname)s] (%(asctime)s) :: %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

# Module-level logger; records at INFO and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def generate_with_context(
    state: AgentState,
    llm: LLM,
    user_callback: Optional[Callable[[str], Any]] = None
) -> AgentState:
    """Generate an answer to the user's question from the retrieved context.

    The last entry of ``state["messages"]`` is read for its ``content`` (used as the context
    passed to the prompt) and its ``artifact`` (the list of sources). When sources are present,
    an HTML rendering of them is appended to ``state["messages"]`` in-place before the answer
    is generated.

    Parameters
    ----------
    state : AgentState
        The current state; must provide the ``"messages"`` and ``"user_input"`` keys
    llm : LLM
        Language model used to write the answer
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None.
        A failing callback is logged and otherwise ignored.

    Returns
    -------
    AgentState
        State update holding the agent response under ``"messages"`` and the unchanged
        ``"user_input"``
    """

    logger.info("---GENERATE ANSWER---")
    if user_callback is not None:
        # The callback is a best-effort UI notification; it must never break answer generation.
        try:
            user_callback("Writing a response...")
        except Exception as ex:
            logger.warning("User callback was passed in but failed: %s", ex)

    messages = state["messages"]
    question = state["user_input"]
    last_message = messages[-1]

    sources_str = last_message.content
    sources_list = last_message.artifact

    if sources_list:
        logger.info("---ADD SOURCES---")
        # Only render the HTML block when there is something to show.
        sources_html = html_format_docs_chat(sources_list)
        messages.append(BaseMessage(content=sources_html, type="HTML"))

    # Prompt
    # NOTE: the literal is kept flush-left so the prompt text sent to the model is unchanged.
    qa_system_prompt = """
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
Use the following pieces of retrieved context to answer the question at the end. \n
If you don't know the answer, just say that you don't know. \n
Keep the response professional, friendly, and as concise as possible. \n
Question: {question}
Context: {context}
Answer:
"""

    # The human turn uses the `{question}` placeholder rather than the raw user text: message
    # strings are parsed as templates, so curly braces typed by the user would otherwise be
    # interpreted as template variables and make the chain fail on a missing input.
    qa_prompt = ChatPromptTemplate([
        ("system", qa_system_prompt),
        ("human", "{question}"),
    ])

    rag_chain = qa_prompt | llm | StrOutputParser()
    response = rag_chain.invoke({"context": sources_str, "question": question})
    return {"messages": [AIMessage(content=response)], "user_input": question}
def add_recommendations_pipeline_(
    G: StateGraph,
    llm: LLM,
    reformulation_node_name: str = "reformulate",
    search_node_name: str = "search_agent"
) -> None:
    """Wire the recommendation-engine sub-graph into an existing execution graph (in-place).

    After the reformulation node, the intent of the query is detected. Queries whose intent is
    ``"rfp"`` or ``"funder"`` are routed through context determination and recommendation
    generation, which then ends the flow; every other query is routed to the search node.

    Parameters
    ----------
    G : StateGraph
        Execution graph, modified in-place
    llm : LLM
        Language model bound to the intent-detection node
    reformulation_node_name : str, optional
        Name of the node which reformulates input queries, by default "reformulate"
    search_node_name : str, optional
        Name of the node which executes document search + retrieval, by default "search_agent"
    """

    intent_node = "detect_intent_with_llm"
    context_node = "determine_context"
    recommend_node = "make_recommendation"
    recommendation_intents = ("rfp", "funder")

    # Recommendation nodes; only intent detection needs the LLM bound to it.
    G.add_node(intent_node, partial(detect_intent_with_llm, llm=llm))
    G.add_node(context_node, determine_context)
    G.add_node(recommend_node, make_recommendation)

    # Intent detection runs right after reformulation and decides which branch to follow.
    G.add_edge(reformulation_node_name, intent_node)
    G.add_conditional_edges(
        source=intent_node,
        path=lambda state: context_node if state["intent"] in recommendation_intents else search_node_name,
        path_map={
            context_node: context_node,
            search_node_name: search_node_name,
        },
    )

    # The recommendation branch runs straight through to END.
    G.add_edge(context_node, recommend_node)
    G.add_edge(recommend_node, END)
def build_compute_graph(
    llm: LLM,
    indices: List[DataIndices],
    enable_recommendations: bool = False,
    user_callback: Optional[Callable[[str], Any]] = None
) -> StateGraph:
    """Assemble the execution graph describing one interaction with the assistant.

    The flow starts at question reformulation, continues to the search agent (optionally passing
    through the recommendation sub-graph first), retrieves documents when the search agent
    requests the tool, generates an answer from the retrieved context, and finally checks for an
    organization name to optionally insert an organization link before ending.

    Parameters
    ----------
    llm : LLM
        Language model bound to the nodes which need one
    indices : List[DataIndices]
        Semantic index names to search over
    enable_recommendations : bool, optional
        Set to `True` to allow the flow to generate recommendations based on context, by default False
    user_callback : Optional[Callable[[str], Any]], optional
        Optional UI callback to inform the user of apps states, by default None

    Returns
    -------
    StateGraph
        Execution graph
    """

    reformulate_node = "reformulate"
    search_node = "search_agent"
    retrieve_node = "retrieve"
    generate_node = "generate_with_context"
    org_name_node = "has_org_name"
    org_link_node = "insert_org_link"

    # A single retriever tool is shared by the search agent and the tool-execution node.
    doc_retriever = retriever_tool(indices=indices, user_callback=user_callback)

    G = StateGraph(AgentState)

    # --- Nodes ---
    G.add_node(
        reformulate_node,
        partial(
            reformulate_question_using_history,
            llm=llm,
            focus_on_recommendations=enable_recommendations
        )
    )
    G.add_node(search_node, partial(search_agent, llm=llm, tools=[doc_retriever]))
    G.add_node(retrieve_node, ToolNode([doc_retriever]))
    G.add_node(
        generate_node,
        partial(generate_with_context, llm=llm, user_callback=user_callback)
    )
    G.add_node(org_name_node, partial(has_org_name, llm=llm, user_callback=user_callback))
    G.add_node(org_link_node, insert_org_link)

    # --- Edges ---
    if not enable_recommendations:
        # Without recommendations, reformulated queries go straight to search.
        G.add_edge(reformulate_node, search_node)
    else:
        add_recommendations_pipeline_(
            G,
            llm=llm,
            reformulation_node_name=reformulate_node,
            search_node_name=search_node
        )

    G.add_edge(START, reformulate_node)

    # The search agent either requests the retriever tool or is done searching.
    G.add_conditional_edges(
        source=search_node,
        path=tools_condition,
        path_map={
            "tools": retrieve_node,
            END: org_name_node,
        },
    )
    G.add_edge(retrieve_node, generate_node)
    G.add_edge(generate_node, org_name_node)

    # The organization-name node writes the name of the next step under the "next" key.
    G.add_conditional_edges(
        source=org_name_node,
        path=lambda state: state["next"],
        path_map={
            org_link_node: org_link_node,
            END: END,
        },
    )
    G.add_edge(org_link_node, END)
    return G