import os
import operator
import functools
from typing import Annotated, Sequence, TypedDict, Union, Optional
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langgraph.graph import StateGraph, END
from application.agents.scraper_agent import scraper_agent
from application.agents.extractor_agent import extractor_agent
from application.utils.logger import get_logger
# Load environment variables from a local .env file (if present) and create
# the module-wide logger.
load_dotenv()
logger = get_logger()
# Fail fast when the OpenAI key is missing: every downstream LLM call would
# otherwise fail later with a far less obvious error.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
# Worker agents the supervisor can delegate to, plus the sentinel value that
# ends the conversation.
MEMBERS = ["Scraper", "Extractor"]
OPTIONS = ["FINISH"] + MEMBERS

# System prompt driving the supervisor's routing decision. {members} is
# filled in later via ChatPromptTemplate.partial().
SUPERVISOR_SYSTEM_PROMPT = (
    "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
    "Given the user's request and the previous messages, determine what to do next:\n"
    "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
    "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
    "- If the task is complete, choose 'FINISH'.\n"
    "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
    "Each worker will perform its task and report back.\n"
    "When you respond directly, make sure your message is friendly and helpful."
)

# JSON-schema properties for the routing tool, built separately for clarity.
_NEXT_PROPERTY = {
    "title": "Next Worker",
    "anyOf": [{"enum": OPTIONS}],
    "description": "Choose next worker if needed.",
}
_RESPONSE_PROPERTY = {
    "title": "Supervisor Response",
    "type": "string",
    "description": "Respond directly if no worker action is needed.",
}

# Tool definition the supervisor LLM is forced to call: it either names the
# next worker ("next") or answers the user directly ("response"). Both fields
# are optional, hence the empty "required" list.
FUNCTION_DEF = {
    "name": "route_or_respond",
    "description": "Select the next role OR respond directly.",
    "parameters": {
        "title": "RouteOrRespondSchema",
        "type": "object",
        "properties": {
            "next": _NEXT_PROPERTY,
            "response": _RESPONSE_PROPERTY,
        },
        "required": [],
    },
}
class AgentState(TypedDict):
    """Shared graph state passed between the supervisor and worker nodes."""

    # Conversation history. The operator.add reducer makes LangGraph APPEND
    # whatever a node returns under this key rather than replace it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the worker to run next, or "FINISH" (set by the supervisor).
    next: Optional[str]
    # Direct reply from the supervisor when no worker action is needed.
    response: Optional[str]
# Single shared chat model; temperature 0 keeps routing decisions deterministic.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
    """Run one worker agent and wrap its output as a named message.

    Args:
        state: Current graph state, passed straight through to the agent.
        agent: The worker runnable to invoke (scraper/extractor agent).
        name: Worker name attached to the resulting message so the
            supervisor can tell who produced it.

    Returns:
        A partial state update: the agent's ``result["output"]`` text as a
        single HumanMessage to be appended to the conversation.

    Raises:
        Exception: re-raises whatever the agent raised, after logging it.
    """
    logger.info(f"Agent {name} invoked.")
    try:
        # Keep the try body minimal: only the call that can actually fail.
        result = agent.invoke(state)
    except Exception as e:
        # logger.exception already records the full traceback; str() on the
        # exception is redundant inside an f-string.
        logger.exception(f"Agent {name} failed with error: {e}")
        raise
    logger.info(f"Agent {name} completed successfully.")
    return {"messages": [HumanMessage(content=result["output"], name=name)]}
# Supervisor prompt: system instructions, the full conversation so far, then a
# final system reminder of the legal options. `options` and `members` are
# baked in with .partial() so invoke() only needs to supply the messages.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
        ),
    ]
).partial(options=str(OPTIONS), members=", ".join(MEMBERS))
# Supervisor decision chain: prompt -> model constrained to call the
# "route_or_respond" tool -> parser extracting that tool call's arguments
# (a dict that may contain "next" and/or "response").
# tool_choice forces the model to always emit the tool call, so the parser
# never sees a plain-text reply. (The older bind_functions/
# JsonOutputFunctionsParser variant was dead commented-out code and has been
# removed; bind_tools is the current API.)
supervisor_chain = (
    prompt
    | llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
    | JsonOutputKeyToolsParser(key_name="route_or_respond")
)
def supervisor_node(state: AgentState) -> AgentState:
    """Ask the supervisor chain what to do next.

    Returns a partial state update with the routing decision: ``next``
    (worker name or "FINISH") and/or ``response`` (a direct reply).

    NOTE: ``messages`` is intentionally NOT returned. The state's
    ``operator.add`` reducer appends whatever a node returns under that key,
    so echoing ``state["messages"]`` back (as the previous version did)
    duplicated the entire history on every supervisor turn.

    Raises:
        ValueError: if the chain produced neither a route nor a response.
    """
    logger.info("Supervisor invoked.")
    output = supervisor_chain.invoke(state)
    logger.info(f"Supervisor output: {output}")
    # JsonOutputKeyToolsParser may return a list of tool calls; use the first.
    if isinstance(output, list) and len(output) > 0:
        output = output[0]
    next_step = output.get("next")
    response = output.get("response")
    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")
    return {
        "next": next_step,
        "response": response,
    }
# Build the graph: two worker nodes, the supervisor, and a terminal node that
# emits the supervisor's direct reply as an AIMessage.
# (A commented-out duplicate registration of "supervisor" was removed as dead
# code.)
workflow = StateGraph(AgentState)
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)
workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})
# Every worker reports back to the supervisor after finishing.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")
def router(state: AgentState):
    """Pick the supervisor's outgoing edge.

    A non-empty direct reply takes priority and routes to the
    "supervisor_response" node; otherwise route to whichever worker (or
    "FINISH") the supervisor selected in ``next``.
    """
    has_direct_reply = bool(state.get("response"))
    return "supervisor_response" if has_direct_reply else state.get("next")
# Map each possible router outcome to a graph destination.
conditional_map = {member: member for member in MEMBERS}
conditional_map["FINISH"] = END
conditional_map["supervisor_response"] = "supervisor_response"
workflow.add_conditional_edges("supervisor", router, conditional_map)
# A direct supervisor reply ends the run. Without this edge the
# "supervisor_response" node is a dead end with no outgoing transition.
workflow.add_edge("supervisor_response", END)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
# # === Example Run ===
# if __name__ == "__main__":
# logger.info("Starting the graph execution...")
# initial_message = HumanMessage(content="Can you get zalando pdf link")
# input_state = {"messages": [initial_message]}
# for step in graph.stream(input_state):
# if "__end__" not in step:
# logger.info(f"Graph Step Output: {step}")
# print(step)
# print("----")
#     logger.info("Graph execution completed.")