# LangGraph + Gradio demo: a researcher / chart-generator agent team that
# answers a research prompt and renders any extracted year/value data as a bar chart.
import os
import re
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command
# --- Configuration -------------------------------------------------------
# Propagate the Anthropic API key into the environment for the SDK.
# The original wrote os.getenv(...) back unconditionally, which raises an
# opaque TypeError ("str expected, not NoneType") when the variable is
# unset; fail fast with an actionable message instead.
_api_key = os.getenv("ANTHROPIC_API_KEY")
if not _api_key:
    raise RuntimeError("ANTHROPIC_API_KEY environment variable is not set.")
os.environ["ANTHROPIC_API_KEY"] = _api_key

# Shared chat model used by both agents in the LangGraph workflow.
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
def make_system_prompt(suffix: str) -> str:
    """Return the shared multi-agent system prompt with *suffix* appended on a new line."""
    parts = [
        "You are a helpful AI assistant, collaborating with other assistants.",
        " Use the provided tools to progress towards answering the question.",
        " If you are unable to fully answer, that's OK, another assistant with different tools ",
        " will help where you left off. Execute what you can to make progress.",
        " If you or any of the other assistants have the final answer or deliverable,",
        " prefix your response with FINAL ANSWER so the team knows to stop.",
    ]
    return "".join(parts) + "\n" + suffix
def research_node(state: MessagesState) -> Command[str]:
    """Run the research agent, then hand off to the chart generator or finish.

    Routes to END when the agent's last message contains "FINAL ANSWER";
    otherwise routes to "chart_generator". The last message is re-wrapped as
    a named HumanMessage so downstream agents can attribute it.
    """
    researcher = create_react_agent(
        llm,
        tools=[],
        state_modifier=make_system_prompt("You can only do research."),
    )
    outcome = researcher.invoke(state)
    last_text = outcome["messages"][-1].content
    next_node = END if "FINAL ANSWER" in last_text else "chart_generator"
    outcome["messages"][-1] = HumanMessage(content=last_text, name="researcher")
    return Command(update={"messages": outcome["messages"]}, goto=next_node)
def chart_node(state: MessagesState) -> Command[str]:
    """Run the chart-generation agent, then hand back to research or finish.

    Routes to END when the agent's last message contains "FINAL ANSWER";
    otherwise routes back to "researcher". The last message is re-wrapped as
    a named HumanMessage so downstream agents can attribute it.
    """
    generator = create_react_agent(
        llm,
        tools=[],
        state_modifier=make_system_prompt("You can only generate charts."),
    )
    outcome = generator.invoke(state)
    last_text = outcome["messages"][-1].content
    next_node = END if "FINAL ANSWER" in last_text else "researcher"
    outcome["messages"][-1] = HumanMessage(content=last_text, name="chart_generator")
    return Command(update={"messages": outcome["messages"]}, goto=next_node)
# --- Workflow graph ------------------------------------------------------
# Both nodes route dynamically via Command(goto=...), so only the entry edge
# is declared statically. The original additionally added fixed edges
# researcher -> chart_generator -> END, which conflict with Command routing
# (e.g. chart_generator may goto "researcher" while the static edge
# simultaneously routes it to END).
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_edge(START, "researcher")
graph = workflow.compile()
def extract_chart_data(text):
    """Parse (year, value) pairs such as "2021: 22.1" out of free-form text.

    Args:
        text: arbitrary text, typically an LLM response.

    Returns:
        (years, values) as parallel lists — years are 4-digit strings,
        values are floats — or (None, None) when no usable pair is found.
    """
    print("🧪 Raw LLM Output to parse:\n", text)
    # A 19xx/20xx year, at most 10 non-digit chars, then an optional-$ number.
    matches = re.findall(r'(\b19\d{2}|\b20\d{2})[^\d]{1,10}(\$?\d+(\.\d+)?)', text)
    if not matches:
        print("❌ No year-value pairs found.")
        return None, None
    years = []
    values = []
    for year, raw_value, _decimal_part in matches:
        try:
            value = float(raw_value.replace('$', ''))
        except ValueError:
            continue
        years.append(year)
        values.append(value)
    print("✅ Extracted:", years, values)
    # Fix: the original `return years, values if ... else (None, None)` bound
    # the conditional to `values` alone (precedence bug), so the empty case
    # returned ([], (None, None)); parenthesize the whole tuple instead.
    return (years, values) if years and values else (None, None)
def generate_plot(years, values):
    """Render a bar chart of *values* keyed by *years* as an in-memory PNG.

    Args:
        years: x-axis labels (strings from extract_chart_data).
        values: bar heights (floats), parallel to *years*.

    Returns:
        A BytesIO positioned at offset 0 containing the PNG bytes.
        NOTE(review): the downstream gr.Image(type="pil") component expects a
        PIL image — confirm it accepts a raw PNG buffer at the call site.
    """
    fig, ax = plt.subplots()
    ax.bar(years, values)
    ax.set_title("Generated Chart")
    ax.set_xlabel("Year")
    ax.set_ylabel("Value")
    buf = BytesIO()
    try:
        # Save this figure explicitly rather than plt's implicit "current" figure.
        fig.savefig(buf, format="png")
    finally:
        # Fix: the original never closed the figure, leaking one per request
        # (pyplot keeps a reference to every open figure).
        plt.close(fig)
    buf.seek(0)
    return buf
def run_langgraph(user_input):
    """Stream the compiled workflow for *user_input* and return the final text.

    Returns the content of the last message produced by any node, or the
    fallback string "No output generated" when nothing was emitted.
    """
    print("📩 Input to LangGraph:", user_input)
    events = graph.stream(
        {"messages": [("user", user_input)]},
        {"recursion_limit": 150},
    )
    final_message = None
    for event in events:
        print("🔹 Event:", event)
        # Fix: graph.stream() yields {node_name: state_update} mappings, so the
        # original's top-level `"messages" in event` check never matched and the
        # function always returned the fallback. Unwrap each node's update.
        for update in event.values():
            if isinstance(update, dict) and update.get("messages"):
                for m in update["messages"]:
                    print("🔸 Message:", m.content)
                final_message = update["messages"][-1].content
    return final_message or "No output generated"
def process_input(user_input):
    """Gradio handler: answer *user_input* and chart any year/value data found.

    Returns:
        (response_text, chart) — chart is a PNG buffer from generate_plot,
        or None when no chartable data was extracted.
    """
    # 🔁 Toggle this to test chart extraction/plotting without calling the LLM.
    STATIC_TEST = False
    if STATIC_TEST:
        dummy_output = """
        FINAL ANSWER:
        Here is the GDP of the USA:
        2019: 21.4
        2020: 20.9
        2021: 22.1
        2022: 23.0
        2023: 24.3
        """
        return _text_with_chart(dummy_output)
    # Run actual LangGraph-based flow.
    return _text_with_chart(run_langgraph(user_input))


def _text_with_chart(text):
    """Pair *text* with a chart of its extracted data (or None).

    Extracted from process_input, where the identical extract -> plot -> pair
    logic was duplicated in both the static-test and live branches.
    """
    years, values = extract_chart_data(text)
    if years and values:
        return text, generate_plot(years, values)
    return text, None
# --- Gradio UI -----------------------------------------------------------
# Text in, (text, image) out; process_input drives the LangGraph workflow.
interface = gr.Interface(
    fn=process_input,
    inputs="text",
    outputs=[
        gr.Textbox(label="Generated Response"),
        gr.Image(type="pil", label="Generated Chart"),
    ],
    title="LangGraph Research Automation",
    description=(
        "Enter your research task (e.g., 'Get GDP data for the USA over "
        "the past 5 years and create a chart.')"
    ),
)

if __name__ == "__main__":
    interface.launch(share=True, ssr_mode=False)