# NOTE(review): removed 308 lines of pasted file-viewer residue (byte-size
# header, git object hashes, and a bare 1-154 line-number column) that were
# not part of the program and made the module syntactically invalid.
import gradio as gr
import os
import re
import matplotlib.pyplot as plt
from io import BytesIO
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
from langchain_anthropic import ChatAnthropic

# Load API key from the environment.
# BUG FIX: the original did `os.environ["ANTHROPIC_API_KEY"] = os.getenv(...)`,
# which raises TypeError at import time when the variable is unset (os.environ
# values must be strings, and os.getenv returns None). Only write it back when
# it is actually present; ChatAnthropic will report the missing key itself.
_api_key = os.getenv("ANTHROPIC_API_KEY")
if _api_key:
    os.environ["ANTHROPIC_API_KEY"] = _api_key

# LangGraph setup: one shared Anthropic chat model used by both agent nodes.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")

def make_system_prompt(suffix: str) -> str:
    """Return the shared multi-agent system prompt with *suffix* appended.

    The base text tells each agent to collaborate, hand partial work to the
    next assistant, and to prefix a finished deliverable with "FINAL ANSWER"
    so the team knows to stop.
    """
    base = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
    )
    return base + "\n" + suffix

def research_node(state: MessagesState) -> Command[str]:
    """Run the research agent over *state*.

    Routes to END when the reply contains "FINAL ANSWER", otherwise hands
    off to the "chart_generator" node.
    """
    researcher = create_react_agent(
        llm,
        tools=[],
        state_modifier=make_system_prompt("You can only do research.")
    )
    outcome = researcher.invoke(state)
    reply = outcome["messages"][-1]
    next_node = END if "FINAL ANSWER" in reply.content else "chart_generator"
    # Re-wrap the reply as a named HumanMessage so the peer agent treats it
    # as incoming input rather than its own prior output.
    outcome["messages"][-1] = HumanMessage(content=reply.content, name="researcher")
    return Command(update={"messages": outcome["messages"]}, goto=next_node)

def chart_node(state: MessagesState) -> Command[str]:
    """Run the chart-generation agent over *state*.

    Routes to END when the reply contains "FINAL ANSWER", otherwise hands
    back to the "researcher" node.
    """
    charter = create_react_agent(
        llm,
        tools=[],
        state_modifier=make_system_prompt("You can only generate charts.")
    )
    outcome = charter.invoke(state)
    reply = outcome["messages"][-1]
    next_node = END if "FINAL ANSWER" in reply.content else "researcher"
    # Re-wrap the reply as a named HumanMessage so the peer agent treats it
    # as incoming input rather than its own prior output.
    outcome["messages"][-1] = HumanMessage(content=reply.content, name="chart_generator")
    return Command(update={"messages": outcome["messages"]}, goto=next_node)

# Create the LangGraph workflow.
# BUG FIX: both nodes return Command(goto=...), which performs the routing
# dynamically. The original code additionally declared static edges
# researcher -> chart_generator and chart_generator -> END; per LangGraph,
# static edges fire in addition to Command routing, so the END edge could
# terminate the run even when chart_node asked to hand back to "researcher".
# Only the entry edge is static.
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_edge(START, "researcher")
graph = workflow.compile()

def extract_chart_data(text):
    """Parse (year, value) pairs such as "2021: 22.1" out of free-form text.

    Returns (years, values) as parallel lists of str and float, or
    (None, None) when no usable pair is found.
    """
    print("πŸ§ͺ Raw LLM Output to parse:\n", text)
    # A 19xx/20xx year, up to 10 non-digit chars, then an optional "$" and a
    # number with an optional decimal part.
    matches = re.findall(r'(\b19\d{2}|\b20\d{2})[^\d]{1,10}(\$?\d+(\.\d+)?)', text)
    if not matches:
        print("❌ No year-value pairs found.")
        return None, None

    years = []
    values = []
    for year, raw_value, _decimal_part in matches:
        try:
            value = float(raw_value.replace('$', ''))
        except ValueError:
            continue
        years.append(year)
        values.append(value)

    print("βœ… Extracted:", years, values)
    # BUG FIX: the original `return years, values if years and values else
    # (None, None)` parsed as `(years, (values if ... else (None, None)))`
    # due to conditional-expression precedence, so the fallback produced a
    # malformed tuple like ([], (None, None)).
    return (years, values) if years and values else (None, None)

def generate_plot(years, values):
    """Render a bar chart of *values* against *years* and return it as a
    PNG-encoded BytesIO buffer, seeked to position 0.

    NOTE(review): the Gradio output is declared gr.Image(type="pil"); whether
    a raw BytesIO is accepted there depends on the installed Gradio version —
    confirm, or open the buffer with PIL before returning.
    """
    fig, ax = plt.subplots()
    ax.bar(years, values)
    ax.set_title("Generated Chart")
    ax.set_xlabel("Year")
    ax.set_ylabel("Value")
    buf = BytesIO()
    fig.savefig(buf, format="png")
    # BUG FIX: close the figure; pyplot keeps every figure alive otherwise,
    # leaking memory on each request.
    plt.close(fig)
    buf.seek(0)
    return buf

def run_langgraph(user_input):
    """Stream the compiled graph on *user_input* and return the final message
    text, or "No output generated" when nothing was produced.

    BUG FIX: graph.stream() yields one dict per executed node, keyed by node
    name — e.g. {"researcher": {"messages": [...]}} — so the original check
    `"messages" in event` never matched and the function always returned
    "No output generated". Both shapes are now handled.
    """
    print("πŸ“© Input to LangGraph:", user_input)
    events = graph.stream(
        {"messages": [("user", user_input)]},
        {"recursion_limit": 150}
    )

    final_message = None
    for event in events:
        print("πŸ”Ή Event:", event)
        # Normalize to a list of node-state updates.
        updates = [event] if "messages" in event else list(event.values())
        for update in updates:
            if not isinstance(update, dict):
                continue
            messages = update.get("messages")
            if messages:
                for m in messages:
                    print("πŸ”Έ Message:", m.content)
                final_message = messages[-1].content

    return final_message or "No output generated"

def process_input(user_input):
    """Run the pipeline on *user_input* and return a (text, chart) pair for
    the Gradio interface; chart is None when no plottable data was found."""
    # πŸ” Toggle this to test graph generation without LLM
    STATIC_TEST = False

    if STATIC_TEST:
        result_text = """
        FINAL ANSWER:
        Here is the GDP of the USA:
        2019: 21.4
        2020: 20.9
        2021: 22.1
        2022: 23.0
        2023: 24.3
        """
    else:
        # Run actual LangGraph-based flow
        result_text = run_langgraph(user_input)

    years, values = extract_chart_data(result_text)
    chart = generate_plot(years, values) if years and values else None
    return result_text, chart

# Gradio interface: one text input in, (response text, chart image) out.
interface = gr.Interface(
    fn=process_input,
    inputs="text",
    outputs=[
        gr.Textbox(label="Generated Response"),
        # NOTE(review): type="pil" expects a PIL image, but generate_plot
        # returns a BytesIO buffer — confirm Gradio accepts this, or convert
        # the buffer with PIL before returning it.
        gr.Image(type="pil", label="Generated Chart")
    ],
    title="LangGraph Research Automation",
    description="Enter your research task (e.g., 'Get GDP data for the USA over the past 5 years and create a chart.')"
)

if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link;
    # ssr_mode=False disables server-side rendering.
    interface.launch(share=True, ssr_mode=False)