File size: 4,346 Bytes
04b1d6c
 
4a8958c
 
 
 
 
 
 
 
 
04b1d6c
 
 
 
 
4a8958c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys
from typing import List, Literal
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.prebuilt import create_react_agent
from langgraph.graph import MessagesState, END
from langgraph.types import Command
from langgraph.graph import StateGraph, START
from IPython.display import Image, display
import re

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from agents.weather_agent import WeatherAgent
from agents.pdf_agent import PDFAgent


def split_questions(user_message: str) -> List[str]:
    """Break a compound request into its individual sub-questions.

    The text is cut wherever the whole words "and then", "and" or "then"
    appear (case-insensitive). Each fragment is stripped of surrounding
    whitespace, and fragments that end up empty are discarded.

    Args:
        user_message: The raw user input, possibly containing several questions.

    Returns:
        The non-empty, whitespace-trimmed fragments in their original order.
    """
    # "and then" is listed first so it is consumed as one separator.
    separator = re.compile(r'\band then\b|\band\b|\bthen\b', re.IGNORECASE)
    fragments: List[str] = []
    for chunk in separator.split(user_message):
        trimmed = chunk.strip()
        if trimmed:
            fragments.append(trimmed)
    return fragments

def classify_question(question: str) -> Literal["pdf_agent", "weather_agent"]:
    """Pick the agent that should answer a single sub-question.

    Routing is a plain keyword test. Any question containing the whole word
    "weather" (case-insensitive) goes to the weather agent, and everything
    else goes to the PDF agent.

    Args:
        question: One sub-question, e.g. as produced by ``split_questions``.

    Returns:
        ``"weather_agent"`` or ``"pdf_agent"``.
    """
    mentions_weather = re.search(r'\bweather\b', question, re.IGNORECASE) is not None
    return "weather_agent" if mentions_weather else "pdf_agent"

def pdf_agent_node(state: MessagesState) -> Command[Literal["weather_agent", END]]:
    """Answer the latest user question with the PDF agent and route onward.

    The question is the most recent ``HumanMessage`` in ``state["messages"]``
    that was NOT written by an agent node. Agent replies in this module are
    also stored as ``HumanMessage`` objects, tagged ``name="pdf_agent"`` or
    ``name="weather_agent"``. Without skipping them, this node would treat the
    weather agent's reply as the user's question whenever it runs after that
    agent inside the graph.

    Args:
        state: Graph state whose ``"messages"`` list holds the conversation.

    Returns:
        A ``Command`` with two parts:

        * ``update["messages"]``: the full incoming history plus this node's
          reply, a ``HumanMessage`` named ``"pdf_agent"``.
        * ``goto``: ``END`` if the reply contains "FINAL ANSWER", otherwise
          ``"weather_agent"``.

    Raises:
        ValueError: If the state contains no user-authored ``HumanMessage``.
    """
    agent_names = ("pdf_agent", "weather_agent")
    user_message = next(
        (
            message.content
            for message in reversed(state["messages"])
            if isinstance(message, HumanMessage) and message.name not in agent_names
        ),
        None,
    )
    if user_message is None:
        raise ValueError("No user message found in state.")

    # Validate the state first, then build the agent, so a bad state fails fast.
    pdf_agent = PDFAgent(pdf_path="Sharath_OnePage.pdf")
    result = pdf_agent.agent.invoke({"input": user_message})

    # The agent may return a dict (answer under "output" or "text") or any
    # other object; reduce it to plain text either way.
    if isinstance(result, dict):
        text_result = result.get("output") or result.get("text") or str(result)
    else:
        text_result = str(result)

    final_msg = HumanMessage(content=text_result, name="pdf_agent")
    goto = get_next_node(final_msg, "weather_agent")
    # The whole history is returned, not just the new reply, because the
    # script's manual driver reads cmd.update["messages"] as the complete
    # conversation.
    return Command(
        update={"messages": state["messages"] + [final_msg]},
        goto=goto,
    )

def weather_agent_node(state: MessagesState) -> Command[Literal["pdf_agent", END]]:
    """Answer the latest user question with the weather agent and route onward.

    The question is the most recent ``HumanMessage`` in ``state["messages"]``
    that was NOT written by an agent node. Agent replies in this module are
    also stored as ``HumanMessage`` objects, tagged ``name="pdf_agent"`` or
    ``name="weather_agent"``. Without skipping them, this node, which the graph
    runs right after ``pdf_agent``, would parse the PDF agent's answer as the
    location request.

    Args:
        state: Graph state whose ``"messages"`` list holds the conversation.

    Returns:
        A ``Command`` with two parts:

        * ``update["messages"]``: the full incoming history plus this node's
          reply, a ``HumanMessage`` named ``"weather_agent"``.
        * ``goto``: ``END`` if the reply contains "FINAL ANSWER", otherwise
          ``"pdf_agent"``.

    Raises:
        ValueError: If the state contains no user-authored ``HumanMessage``.
    """
    agent_names = ("pdf_agent", "weather_agent")
    user_message = next(
        (
            message.content
            for message in reversed(state["messages"])
            if isinstance(message, HumanMessage) and message.name not in agent_names
        ),
        None,
    )
    if user_message is None:
        raise ValueError("No user message found in state.")

    weather_agent = WeatherAgent()

    # The location is the run of word characters, spaces and commas after
    # "weather in", e.g. "weather in Mumbai?" -> "Mumbai". If that phrase is
    # absent, the whole question is passed to the agent unchanged.
    match = re.search(r"weather in ([\w\s,]+)", user_message, re.IGNORECASE)
    location = match.group(1).strip() if match else user_message
    result = weather_agent.ask(location)

    final_msg = HumanMessage(content=result, name="weather_agent")
    # NOTE(review): without "FINAL ANSWER" this routes back to "pdf_agent",
    # while build_graph also has a static weather_agent -> END edge. Confirm
    # the intended routing when invoking the compiled graph.
    goto = get_next_node(final_msg, "pdf_agent")
    # The whole history is returned, not just the new reply, because the
    # script's manual driver reads cmd.update["messages"] as the complete
    # conversation.
    return Command(
        update={"messages": state["messages"] + [final_msg]},
        goto=goto,
    )

def get_next_node(last_message: BaseMessage, goto: str):
    """Decide where the graph goes after an agent reply.

    Args:
        last_message: The reply just produced by an agent node.
        goto: The node to continue to when the conversation is not finished.

    Returns:
        ``END`` when the reply's content contains the marker "FINAL ANSWER",
        otherwise ``goto`` unchanged.
    """
    if "FINAL ANSWER" not in last_message.content:
        return goto
    return END

def build_graph():
    """Assemble and compile the two-agent LangGraph workflow.

    Registers ``pdf_agent`` and ``weather_agent`` on a ``MessagesState``
    graph and wires the static chain START -> pdf_agent -> weather_agent -> END.

    Returns:
        The compiled graph.
    """
    builder = StateGraph(MessagesState)

    nodes = (
        ("pdf_agent", pdf_agent_node),
        ("weather_agent", weather_agent_node),
    )
    for node_name, node_fn in nodes:
        builder.add_node(node_name, node_fn)

    # NOTE(review): weather_agent_node returns Command(goto="pdf_agent") unless
    # its reply contains "FINAL ANSWER". LangGraph follows Command gotos in
    # addition to static edges, so invoking this graph may cycle
    # pdf_agent -> weather_agent -> pdf_agent until the recursion limit.
    # Confirm the intended routing.
    edges = (
        (START, "pdf_agent"),
        ("pdf_agent", "weather_agent"),
        ("weather_agent", END),
    )
    for source, target in edges:
        builder.add_edge(source, target)

    return builder.compile()

if __name__ == "__main__":
    graph = build_graph()
    # The diagram is a convenience only. draw_mermaid_png() may need network
    # access (langchain_core's default renderer is a remote Mermaid service),
    # and display() targets notebooks. A rendering failure therefore must not
    # stop the question answering below.
    try:
        display(Image(graph.get_graph().draw_mermaid_png()))
    except Exception as exc:  # best-effort visualization: report and carry on
        print(f"Could not render graph diagram: {exc}", file=sys.stderr)

    # Full user input with multiple questions
    user_input = "What organizations has Sharath worked for and tell me the weather in Mumbai"

    # Split into sub-questions
    questions = split_questions(user_input)

    # Conversation history accumulated across sub-questions (user questions
    # and agent replies).
    messages = []

    # Route each sub-question to its agent by calling the node function
    # directly; the compiled graph above is only drawn, not invoked.
    for question in questions:
        agent_name = classify_question(question)
        state = {"messages": messages + [HumanMessage(content=question)]}
        if agent_name == "pdf_agent":
            cmd = pdf_agent_node(state)
        else:
            cmd = weather_agent_node(state)

        # The node's update holds the full history plus its reply.
        messages = cmd.update["messages"]

    # Print the whole conversation: user questions (no name, shown as "user")
    # and agent replies (named after the agent).
    for msg in messages:
        if not isinstance(msg, HumanMessage):
            continue
        print(f"{msg.name or 'user'}: {msg.content}")