Spaces:
Sleeping
Sleeping
File size: 5,137 Bytes
893865b d712efb 893865b ea761af 893865b f0007ca ea761af 893865b ea761af 893865b ea761af 893865b ea761af 893865b ea761af 893865b ea761af 893865b d712efb 893865b ea761af d712efb d8ede0d ea761af d712efb ea761af d8ede0d d712efb d8ede0d ea761af 893865b d712efb d8a1ce0 893865b d8a1ce0 893865b d712efb 0b9fbf1 d712efb 0b9fbf1 893865b d712efb ea761af 893865b d712efb 893865b ea761af 893865b ea761af 893865b a57a271 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import re
from io import BytesIO
from typing import Literal

import gradio as gr
import matplotlib.pyplot as plt
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.types import Command
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
from langchain_anthropic import ChatAnthropic
# Load API Key
# ChatAnthropic picks ANTHROPIC_API_KEY up from the environment on its own, so
# nothing needs to be copied into os.environ. Re-assigning
# os.environ[...] = os.getenv(...) was a no-op when the key was set and raised
# an opaque TypeError (None is not a valid environ value) when it was not;
# fail with an explicit message instead.
if os.getenv("ANTHROPIC_API_KEY") is None:
    raise RuntimeError("ANTHROPIC_API_KEY environment variable is not set")
# LangGraph setup
llm = ChatAnthropic(model="claude-3-5-sonnet-latest")
def make_system_prompt(suffix: str) -> str:
    """Build the shared collaborator system prompt with a role-specific tail.

    Args:
        suffix: Role-specific instruction appended on its own line after the
            shared preamble (e.g. "You can only do research.").

    Returns:
        The full system prompt. It tells the agent to prefix its reply with
        "FINAL ANSWER" when done, which the graph nodes use to stop routing.
    """
    return (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        # No trailing space here: the next literal supplies the separator
        # (the original produced "tools  will" with a double space).
        " If you are unable to fully answer, that's OK, another assistant with different tools"
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        f"\n{suffix}"
    )
def _is_final_answer(content) -> bool:
    """Return True when an agent reply contains the "FINAL ANSWER" stop marker.

    Message content may be a plain string or a list of content blocks;
    stringifying covers both so the marker is still found in the list form.
    """
    return "FINAL ANSWER" in str(content)


def _run_agent_node(agent, state, *, name, next_node):
    """Invoke *agent* on *state* and route to *next_node*, or END when finished.

    The agent's last reply is re-wrapped as a HumanMessage tagged with *name*
    so the other agent receives it as input from a named collaborator.
    """
    result = agent.invoke(state)
    messages = result["messages"]
    last = messages[-1]
    goto = END if _is_final_answer(last.content) else next_node
    messages[-1] = HumanMessage(content=last.content, name=name)
    return Command(update={"messages": messages}, goto=goto)


# Build each ReAct agent once instead of re-creating it on every node call.
_research_agent = create_react_agent(
    llm,
    tools=[],
    state_modifier=make_system_prompt("You can only do research."),
)
_chart_agent = create_react_agent(
    llm,
    tools=[],
    state_modifier=make_system_prompt("You can only generate charts."),
)


# The Literal return annotations declare where each node may route
# ("__end__" is the value of langgraph's END); LangGraph reads them to know
# the node's possible destinations.
def research_node(state: MessagesState) -> Command[Literal["chart_generator", "__end__"]]:
    """Research agent step: hand off to the chart generator unless finished."""
    return _run_agent_node(
        _research_agent, state, name="researcher", next_node="chart_generator"
    )


def chart_node(state: MessagesState) -> Command[Literal["researcher", "__end__"]]:
    """Chart agent step: hand back to the researcher unless finished."""
    return _run_agent_node(
        _chart_agent, state, name="chart_generator", next_node="researcher"
    )


# Create the LangGraph workflow.
# Routing between the agents is driven only by the Command(goto=...) each
# node returns. Static edges out of the nodes would fire *in addition* to
# goto, so the old researcher->chart_generator edge meant a researcher
# FINAL ANSWER could never end the run.
workflow = StateGraph(MessagesState)
workflow.add_node("researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_edge(START, "researcher")
graph = workflow.compile()
# A four-digit year (1900-2099), then 1-10 non-digit separator characters,
# then a number with an optional leading "$", optional thousands separators
# ("21,433") and an optional decimal part.
_YEAR_VALUE_RE = re.compile(
    r"(?P<year>\b(?:19|20)\d{2})"
    r"[^\d]{1,10}"
    r"(?P<value>\$?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?)"
)


def extract_chart_data(text):
    """Extract (year, value) pairs from free-form LLM output.

    Args:
        text: Text to scan, e.g. lines like "2019: 21.4". ``None`` is treated
            as empty text.

    Returns:
        ``(years, values)``: years as strings and values as floats, in order of
        appearance. Returns ``(None, None)`` when no pair is found.
    """
    print("🧪 Raw LLM Output to parse:\n", text)
    years = []
    values = []
    for match in _YEAR_VALUE_RE.finditer(text or ""):
        # The regex guarantees a numeric literal once "$" and "," are removed,
        # so float() cannot fail here.
        raw_value = match.group("value").replace("$", "").replace(",", "")
        years.append(match.group("year"))
        values.append(float(raw_value))
    if not years:
        print("❌ No year-value pairs found.")
        return None, None
    print("✅ Extracted:", years, values)
    return years, values
def generate_plot(years, values, *, title="Generated Chart", xlabel="Year", ylabel="Value"):
    """Render a bar chart of *values* against *years* as a PNG.

    Args:
        years: Bar labels for the x axis.
        values: Bar heights, one per entry in *years*.
        title, xlabel, ylabel: Chart title and axis labels (keyword-only).

    Returns:
        A BytesIO holding the PNG image, rewound to offset 0.
    """
    fig, ax = plt.subplots()
    buf = BytesIO()
    try:
        ax.bar(years, values)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request leaks a figure.
        plt.close(fig)
    buf.seek(0)
    return buf
def _content_to_text(content):
    """Return the text of a message's content.

    Content is either a string or a list of blocks (strings, or dicts such as
    ``{"type": "text", "text": ...}``); non-text blocks are skipped.
    """
    if isinstance(content, str):
        return content
    parts = []
    for part in content or []:
        if isinstance(part, str):
            parts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            parts.append(part.get("text", ""))
    return "".join(parts)


def run_langgraph(user_input):
    """Run the agent graph on *user_input* and return the last reply's text.

    Returns "No output generated" when no node produced a message.
    """
    print("📩 Input to LangGraph:", user_input)
    events = graph.stream(
        {"messages": [("user", user_input)]},
        {"recursion_limit": 150},
        # Each event is {node_name: update}; there is no top-level "messages"
        # key, which is why the old check never matched.
        stream_mode="updates",
    )
    final_message = None
    for event in events:
        print("🔹 Event:", event)
        for node_update in event.values():
            if not isinstance(node_update, dict):
                continue
            messages = node_update.get("messages")
            if not messages:
                continue
            for m in messages:
                print("🔸 Message:", m.content)
            final_message = _content_to_text(messages[-1].content)
    return final_message or "No output generated"
# Set to True to exercise chart parsing/plotting with canned text instead of
# calling the LLM agents.
STATIC_TEST = False

# Canned agent output used when STATIC_TEST is enabled.
_STATIC_TEST_OUTPUT = (
    "\n"
    "FINAL ANSWER:\n"
    "Here is the GDP of the USA:\n"
    "2019: 21.4\n"
    "2020: 20.9\n"
    "2021: 22.1\n"
    "2022: 23.0\n"
    "2023: 24.3\n"
)


def _chart_for(text):
    """Return a chart image array for the year/value pairs in *text*, or None."""
    years, values = extract_chart_data(text)
    if not (years and values):
        return None
    png = generate_plot(years, values)
    # gr.Image outputs accept numpy arrays, PIL images or file paths, but not
    # a BytesIO, so decode the PNG buffer into an RGBA array first.
    return plt.imread(png, format="png")


def process_input(user_input):
    """Gradio handler: run the agents on *user_input* and chart the result.

    Returns:
        ``(response_text, chart)``. ``chart`` is an image array, or None when
        no year/value pairs can be extracted from the response.
    """
    if STATIC_TEST:
        result_text = _STATIC_TEST_OUTPUT
    elif not (user_input or "").strip():
        # Don't spend an LLM round-trip on an empty request.
        return "Please enter a research task.", None
    else:
        result_text = run_langgraph(user_input)
    return result_text, _chart_for(result_text)
# Gradio interface
# A single free-text input feeds process_input; its (text, chart) return value
# fills the two output components below, in order.
_output_components = [
    gr.Textbox(label="Generated Response"),
    gr.Image(type="pil", label="Generated Chart"),
]
interface = gr.Interface(
    fn=process_input,
    inputs="text",
    outputs=_output_components,
    title="LangGraph Research Automation",
    description=(
        "Enter your research task (e.g., 'Get GDP data for the USA over the "
        "past 5 years and create a chart.')"
    ),
)

if __name__ == "__main__":
    # Launch options forwarded to Gradio as-is.
    _launch_kwargs = {"share": True, "ssr_mode": False}
    interface.launch(**_launch_kwargs)
|