from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing import Annotated
from typing_extensions import TypedDict
from langchain_core.tools import tool
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools import WikipediaQueryRun
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
import gradio as gr
import os
import uuid
from datetime import datetime

# Get API key from Hugging Face Spaces secrets
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    raise ValueError("Please set GOOGLE_API_KEY in your Hugging Face Spaces secrets")

os.environ['GOOGLE_API_KEY'] = GOOGLE_API_KEY

# Initialize Gemini Flash 2.0 Model
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")

# Define the State
class State(TypedDict):
    messages: Annotated[list, add_messages]

# Tool 1: Wikipedia
wiki_api_wrapper = WikipediaAPIWrapper(top_k_results=1)
wikipedia_tool = WikipediaQueryRun(api_wrapper=wiki_api_wrapper)

# Tool 2: Historical Events
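# Note: rather than calling an external API, this tool asks the same Gemini model to list the events.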
@tool
def historical_events(date_input: str) -> str:
    """Provide a list of important historical events for a given date."""
    try:
        res = llm.invoke(f"List important historical events that occurred on {date_input}.")
        return res.content
    except Exception as e:
        return f"Error: {str(e)}"

# Tool 3: Palindrome Checker
@tool
def palindrome_checker(text: str) -> str:
    """Check if a word or phrase is a palindrome."""
    cleaned = ''.join(c.lower() for c in text if c.isalnum())
    if cleaned == cleaned[::-1]:
        return f"'{text}' is a palindrome."
    else:
        return f"'{text}' is not a palindrome."

# Bind tools
tools = [wikipedia_tool, historical_events, palindrome_checker]
tool_node = ToolNode(tools=tools)

# Bind tools to the LLM
model_with_tools = llm.bind_tools(tools)

def should_continue(state: MessagesState):
    """Route to the tool node when the last AI message requests a tool call, otherwise end the turn."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"
    return END

def call_model(state: MessagesState):
    """Invoke the tool-enabled Gemini model on the accumulated messages."""
    messages = state["messages"]
    response = model_with_tools.invoke(messages)
    return {"messages": [response]}

# Build LangGraph
builder = StateGraph(State)
builder.add_node("chatbot", call_model)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chatbot")
builder.add_conditional_edges("chatbot", should_continue, {"tools": "tools", END: END})
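# After a tool runs, control returns to the chatbot node so the model can use the tool output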
builder.add_edge("tools", "chatbot")

# Add memory
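# MemorySaver is an in-memory checkpointer keyed by thread_id, so each session keeps its own history (lost on restart)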
memory = MemorySaver()
app = builder.compile(checkpointer=memory)

# Global conversation storage for each session
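# Maps session_id -> list of ("human" | "tool" | "ai", message) tuples used to render the transcript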
conversations = {}

def format_message_for_display(msg, msg_type="ai"):
    """Format a message for markdown display"""
    timestamp = datetime.now().strftime("%H:%M")
    
    if msg_type == "human":
        return f"**πŸ‘€ You** *({timestamp})*\n\n{msg}\n\n---\n"
    elif msg_type == "tool":
        tool_name = getattr(msg, 'name', 'Unknown Tool')
        return f"**πŸ”§ {tool_name}** *({timestamp})*\n```\n{msg.content}\n```\n"
    else:  # AI message
        return f"**πŸ€– Assistant** *({timestamp})*\n\n{msg.content}\n\n---\n"

def chatbot_conversation(message, history, session_id):
    """Main chatbot function that maintains conversation history"""
    
    # Generate session ID if not provided
    if not session_id:
        session_id = str(uuid.uuid4())
    
    # Get or create conversation config for this session
    config = {"configurable": {"thread_id": session_id}}
    
    # Initialize conversation history if new session
    if session_id not in conversations:
        conversations[session_id] = []
    
    # Add user message to display history
    conversations[session_id].append(("human", message))
    
    # Prepare input for LangGraph
    inputs = {"messages": [HumanMessage(content=message)]}
    
    try:
        # Invoke the app and get the complete response
        result = app.invoke(inputs, config)

        # With the MemorySaver checkpointer, result["messages"] holds the whole thread
        # history, so only display the messages produced after this turn's human message
        final_messages = result["messages"]
        last_human_idx = max(
            i for i, m in enumerate(final_messages) if isinstance(m, HumanMessage)
        )

        # Process the new messages, separating tool and AI responses
        for msg in final_messages[last_human_idx + 1:]:
            if msg.content:
                if hasattr(msg, 'name') and msg.name:
                    # Tool response
                    conversations[session_id].append(("tool", msg))
                else:
                    # AI response
                    conversations[session_id].append(("ai", msg))
        
    except Exception as e:
        # Show the error in the transcript as a regular assistant message
        error_msg = f"❌ Error: {str(e)}"
        conversations[session_id].append(("ai", AIMessage(content=error_msg)))
    
    # Format the entire conversation for display
    formatted_history = ""
    for msg_type, msg_content in conversations[session_id]:
        if msg_type == "human":
            formatted_history += format_message_for_display(msg_content, "human")
        elif msg_type == "tool":
            formatted_history += format_message_for_display(msg_content, "tool")
        else:  # ai
            formatted_history += format_message_for_display(msg_content, "ai")
    
    return formatted_history, session_id

def clear_conversation():
    """Reset the display and start a fresh session (new thread_id)"""
    return "", str(uuid.uuid4())

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="🚀 Gemini Flash 2.0 Chatbot") as demo:
    gr.Markdown("""
    # 🚀 Gemini Flash 2.0 + LangGraph Chatbot

    **LangGraph-powered conversational AI using Google's Gemini Flash 2.0**

    🔍 **Available Tools:**
    - 📚 **Wikipedia Search** - Get information from Wikipedia
    - 🔄 **Palindrome Checker** - Check if text is a palindrome
    - 📅 **Historical Events** - Find events that happened on specific dates

    💡 **Try asking:** *"Tell me about AI, then check if 'radar' is a palindrome"*
    """)
    
    with gr.Row():
        with gr.Column(scale=4):
            # Chat history display
            chat_history = gr.Markdown(
                value="🤖 **Assistant**: Hello! I'm your AI assistant powered by Gemini Flash 2.0. I can help you with Wikipedia searches, check palindromes, and find historical events. What would you like to know?\n\n---\n",
                label="💬 Conversation"
            )
            
            # Input area
            with gr.Row():
                message_input = gr.Textbox(
                    placeholder="Type your message here...",
                    label="Your message",
                    scale=4,
                    lines=2
                )
                send_btn = gr.Button("Send 🚀", scale=1, variant="primary")
            
            # Control buttons
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
                
        with gr.Column(scale=1):
            gr.Markdown("""
            ### 💡 Example Queries:
            - "What is machine learning?"
            - "Is 'level' a palindrome?"
            - "What happened on June 6, 1944?"
            - "Tell me about Python programming"
            - "Check if 'A man a plan a canal Panama' is a palindrome"
            """)
    
    # Hidden session ID state
    session_id = gr.State(value=str(uuid.uuid4()))
    
    # Event handlers
    def send_message(message, history, session_id):
        """Run one chat turn, then clear the input box"""
        if message.strip():
            # chatbot_conversation returns (formatted_history, session_id); append "" to clear the textbox
            return chatbot_conversation(message, history, session_id) + ("",)
        return history, session_id, message
    
    send_btn.click(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input]
    )
    
    message_input.submit(
        send_message,
        inputs=[message_input, chat_history, session_id],
        outputs=[chat_history, session_id, message_input]
    )
    
    clear_btn.click(
        clear_conversation,
        outputs=[chat_history, session_id]
    )

if __name__ == "__main__":
    demo.launch()