Spaces:
Sleeping
Sleeping
from langgraph.graph import StateGraph, START, END, MessagesState | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolNode, tools_condition | |
from typing import Annotated | |
from typing_extensions import TypedDict | |
from langchain_core.tools import tool | |
from langchain_community.utilities import WikipediaAPIWrapper | |
from langchain_community.tools import WikipediaQueryRun | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
import gradio as gr | |
import os | |
import uuid | |
from datetime import datetime | |
# Read the Gemini API key from the environment (Hugging Face Spaces exposes
# repository secrets as environment variables) and fail fast when it is absent.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("Please set GOOGLE_API_KEY in your Hugging Face Spaces secrets")
# Write the validated key back so the Google client picks it up from the
# environment.
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

# Shared Gemini Flash 2.0 chat model, used both by the agent node and by the
# historical-events tool.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")
# Define the State
class State(TypedDict):
    """Graph state passed between the chatbot and tools nodes."""

    # Conversation history. The ``add_messages`` reducer merges the messages a
    # node returns into this list instead of replacing the list.
    messages: Annotated[list, add_messages]
# Tool 1: Wikipedia
# API wrapper configured to return only the top search result per query.
wiki_api_wrapper = WikipediaAPIWrapper(top_k_results=1)
# LangChain tool that exposes the wrapper to the model.
wikipedia_tool = WikipediaQueryRun(api_wrapper=wiki_api_wrapper)
# Tool 2: Historical Events
def historical_events(date_input: str) -> str:
    """Provide a list of important historical events for a given date."""
    # NOTE: LangChain uses this function's docstring as the tool description
    # shown to the model, so the docstring above is intentionally terse.
    #
    # Args: date_input - free-form date text supplied by the model.
    # Returns: the model's answer text, or an "Error: ..." string. This is a
    # best-effort tool, so failures come back as text the agent can read
    # instead of exceptions.
    date_input = (date_input or "").strip()
    if not date_input:
        # Skip a pointless model call with an empty date in the prompt.
        return "Error: Please provide a date."
    try:
        res = llm.invoke(f"List important historical events that occurred on {date_input}.")
    except Exception as e:
        return f"Error: {e}"
    return res.content
# Tool 3: Palindrome Checker
def palindrome_checker(text: str) -> str:
    """Check if a word or phrase is a palindrome."""
    # NOTE: the docstring above doubles as the tool description the model sees.
    #
    # Only letters and digits are compared, case-insensitively (casefold), so
    # phrases such as "A man a plan a canal Panama" are recognised.
    cleaned = "".join(ch for ch in text.casefold() if ch.isalnum())
    if not cleaned:
        # Empty or punctuation-only input has nothing to compare; do not
        # report it as a palindrome just because "" == "".
        return f"'{text}' contains no letters or digits to check."
    verdict = "is" if cleaned == cleaned[::-1] else "is not"
    return f"'{text}' {verdict} a palindrome."
# Bind tools
# One shared list: ToolNode executes these tools and bind_tools advertises the
# same set to the model, so the two cannot drift apart. The two plain
# functions are presumably converted to LangChain tools by ToolNode /
# bind_tools (they are not decorated with @tool here).
tools = [wikipedia_tool, historical_events, palindrome_checker]
tool_node = ToolNode(tools=tools)
# Bind tools to the LLM
model_with_tools = llm.bind_tools(tools)
def should_continue(state: MessagesState):
    """Decide where the graph goes after the chatbot node.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``"tools"`` when the latest message requests tool calls, otherwise
        ``END`` to finish the turn.
    """
    last_message = state["messages"][-1]
    # Only AI messages carry ``tool_calls``; any other message type ends the
    # turn instead of raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return END
def call_model(state: MessagesState):
    """Run the tool-enabled Gemini model over the conversation so far.

    Returns a partial state update holding the model's single reply; the
    graph's ``add_messages`` reducer appends it to the history.
    """
    reply = model_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
# Build LangGraph
# Flow: START -> chatbot; chatbot -> tools while the model requests tool calls,
# otherwise END; tools -> chatbot so the model can read the tool results.
builder = StateGraph(State)
builder.add_node("chatbot", call_model)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chatbot")
# Maps should_continue's return values ("tools" / END) to destination nodes.
builder.add_conditional_edges("chatbot", should_continue, {"tools": "tools", END: END})
builder.add_edge("tools", "chatbot")
# Add memory
# MemorySaver keeps per-thread checkpoints in process memory, so a thread's
# messages persist across invocations with the same thread_id until restart.
memory = MemorySaver()
app = builder.compile(checkpointer=memory)
# Global conversation storage for each session
# Display transcripts for the UI, keyed by session id.
conversations = {}
def format_message_for_display(msg, msg_type="ai", timestamp=None):
    """Format a message for markdown display.

    Args:
        msg: A plain string or a message object exposing ``.content`` (tool
            messages may also expose ``.name``).
        msg_type: ``"human"``, ``"tool"``; any other value renders as an AI
            message.
        timestamp: Optional ``HH:MM`` label. When omitted, the current local
            time is used, so callers that store rendered output keep the time
            the message was created.

    Returns:
        A Markdown snippet for the conversation transcript.
    """
    if timestamp is None:
        timestamp = datetime.now().strftime("%H:%M")
    # Accept raw strings and message objects alike for every message type.
    content = getattr(msg, "content", msg)
    if msg_type == "human":
        return f"**π€ You** *({timestamp})*\n\n{content}\n\n---\n"
    elif msg_type == "tool":
        # ``name`` may exist but be None; fall back to a readable label.
        tool_name = getattr(msg, 'name', None) or 'Unknown Tool'
        return f"**π§ {tool_name}** *({timestamp})*\n```\n{content}\n```\n"
    else:  # AI message
        return f"**π€ Assistant** *({timestamp})*\n\n{content}\n\n---\n"
def _render_turn_messages(messages):
    """Render the AI/tool messages produced during the latest turn as Markdown.

    Because the graph is compiled with a checkpointer and its state uses the
    ``add_messages`` reducer, ``app.invoke`` returns the thread's entire
    history. Only the messages after the most recent ``HumanMessage`` belong
    to the current turn; everything earlier is already in the transcript.
    """
    last_human = max(
        (i for i, m in enumerate(messages) if isinstance(m, HumanMessage)),
        default=-1,
    )
    rendered = []
    for msg in messages[last_human + 1:]:
        if not msg.content:
            # e.g. AI messages that only carry tool calls and no text
            continue
        kind = "tool" if getattr(msg, "type", None) == "tool" else "ai"
        rendered.append(format_message_for_display(msg, kind))
    return rendered


def chatbot_conversation(message, history, session_id):
    """Run one chat turn through the LangGraph app and return the transcript.

    Args:
        message: The user's text for this turn.
        history: The currently displayed transcript. Unused; the per-session
            transcript kept in ``conversations`` is the source of truth.
        session_id: Session identifier, also used as the LangGraph
            ``thread_id``. A new UUID is generated when it is empty/None.

    Returns:
        ``(formatted_history, session_id)``: the session's full Markdown
        transcript and the (possibly newly generated) session id.
    """
    if not session_id:
        session_id = str(uuid.uuid4())
    # thread_id scopes the checkpointer's memory to this session.
    config = {"configurable": {"thread_id": session_id}}
    # Entries are rendered when they are added, so every message keeps the
    # timestamp of when it was produced instead of the time of re-rendering.
    transcript = conversations.setdefault(session_id, [])
    transcript.append(format_message_for_display(message, "human"))
    try:
        result = app.invoke({"messages": [HumanMessage(content=message)]}, config)
        new_entries = _render_turn_messages(result["messages"])
    except Exception as e:
        # UI boundary: show the failure in the chat instead of breaking the
        # Gradio callback.
        new_entries = [format_message_for_display(AIMessage(content=f"β Error: {e}"), "ai")]
    transcript.extend(new_entries)
    return "".join(transcript), session_id
def clear_conversation():
    """Reset the chat display and start a brand-new session.

    Returns:
        An empty transcript string and a freshly generated session id; the
        new id becomes a new LangGraph thread on the next message.

    NOTE: the previous session's entry in ``conversations`` is left in place.
    """
    fresh_session = str(uuid.uuid4())
    return "", fresh_session
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="π Gemini Flash 2.0 Chatbot") as demo:
    gr.Markdown("""
# π Gemini Flash 2.0 + LangGraph Chatbot
**LangGraph-powered conversational AI using Google's Gemini Flash 2.0**
π **Available Tools:**
- π **Wikipedia Search** - Get information from Wikipedia
- π **Palindrome Checker** - Check if text is a palindrome
- π **Historical Events** - Find events that happened on specific dates
π‘ **Try asking:** *"Tell me about AI, then check if 'radar' is a palindrome"*
""")
    with gr.Row():
        with gr.Column(scale=4):
            # Chat history display
            chat_history = gr.Markdown(
                value="π€ **Assistant**: Hello! I'm your AI assistant powered by Gemini Flash 2.0. I can help you with Wikipedia searches, check palindromes, and find historical events. What would you like to know?\n\n---\n",
                label="π¬ Conversation"
            )
            # Input area
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your message",
                    scale=4,
                    lines=2
                )
                send_btn = gr.Button("Send π", scale=1, variant="primary")
            # Control buttons
            with gr.Row():
                clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("""
### π‘ Example Queries:
- "What is machine learning?"
- "Is 'level' a palindrome?"
- "What happened on June 6, 1944?"
- "Tell me about Python programming"
- "Check if 'A man a plan a canal Panama' is a palindrome"
""")
    # Hidden per-visitor session id. Gradio deep-copies a State's initial
    # value into every new session, so a UUID generated here (once, at import)
    # would be shared by all visitors, and with it one LangGraph thread and one
    # transcript. Start empty instead: chatbot_conversation() assigns a fresh
    # UUID when it receives a falsy session id.
    session_id = gr.State(value=None)

    # Event handlers
    def send_message(message, history, session_id):
        """Send a non-blank message and clear the input box; ignore blank input."""
        if message.strip():
            return chatbot_conversation(message, history, session_id) + ("",)
        return history, session_id, message

    send_btn.click(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input]
    )
    message_input.submit(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input]
    )
    clear_btn.click(
        clear_conversation,
        outputs=[chat_history, session_id]
    )

if __name__ == "__main__":
    demo.launch()