# Chatbot app: LangGraph state graph + Hugging Face Inference endpoint,
# served through a Gradio interface.
import logging
import os
from typing import Annotated

import gradio as gr
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

# Basic logging so graph inputs/outputs show up on the console.
logging.basicConfig(level=logging.INFO)

# Load environment variables from a local .env file (supplies HF_TOKEN).
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # Fail fast with a clear message instead of the opaque AttributeError
    # that HF_TOKEN.strip() would raise below when the variable is unset.
    raise RuntimeError("HF_TOKEN environment variable is not set")

# Remote Mistral-7B-Instruct endpoint on the Hugging Face Inference API.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),
    temperature=0.7,      # moderately creative sampling
    max_new_tokens=200,   # cap on reply length
)
# Define the state structure
class State(TypedDict):
    """Graph state: a single `messages` channel holding the chat history."""

    # `add_messages` is a reducer: updates returned by nodes are merged
    # (appended) into this list rather than replacing it.
    messages: Annotated[list, add_messages]
# Create a state graph builder parameterized by the State schema above.
graph_builder = StateGraph(State)
# Define the chatbot function
def chatbot(state: State):
    """Graph node: send the accumulated messages to the LLM.

    Returns a state-update dict (``{"messages": [...]}``). On failure the
    error text is returned as the message, so the graph and UI still get a
    reply instead of crashing.
    """
    try:
        # Lazy %-style args: the message list is only formatted if the
        # record is actually emitted.
        logging.info("Input Messages: %s", state["messages"])
        response = llm.invoke(state["messages"])
        logging.info("LLM Response: %s", response)
        return {"messages": [response]}
    except Exception as e:
        # logging.exception records the full traceback, not just the text.
        logging.exception("Error while invoking LLM")
        return {"messages": [f"Error: {str(e)}"]}
# Add nodes and edges to the state graph
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")  # entry point
graph_builder.add_edge("chatbot", END)    # single turn: stop after one reply
# Compile the state graph
graph = graph_builder.compile()
# Function to stream updates from the graph
def stream_graph_updates(user_input: str):
    """
    Stream updates from the graph for one user message and return the
    assistant's final reply as a string.
    """
    assistant_reply = ""
    for event in graph.stream({"messages": [("user", user_input)]}):
        for value in event.values():
            last = value["messages"][-1]
            if isinstance(last, dict):
                # Dict-shaped message: extract its 'content' field.
                assistant_reply = last.get("content", "")
            elif isinstance(last, str):
                # Plain string (HuggingFaceEndpoint returns raw text).
                assistant_reply = last
            else:
                # LangChain message objects (AIMessage, etc.) carry .content.
                # Previously these fell through both branches and the reply
                # stayed "", so surface their content (or str()) instead.
                assistant_reply = getattr(last, "content", str(last))
    return assistant_reply
# Gradio chatbot function using the streaming updates
def gradio_chatbot(user_message: str):
    """Bridge between the Gradio textbox and the state graph.

    Runs the user's message through the graph and returns only the
    assistant's reply; any failure is reported as an error string.
    """
    try:
        reply = stream_graph_updates(user_message)
    except Exception as e:
        logging.error(f"Error in Gradio chatbot: {str(e)}")
        return f"Error: {str(e)}"
    return reply
# Create Gradio interface: one textbox in, one textbox out.
interface = gr.Interface(
    fn=gradio_chatbot,
    inputs=gr.Textbox(placeholder="Enter your message", label="Your Message"),
    outputs=gr.Textbox(label="Assistant's Reply"),
    title="Chatbot",
    description="Interactive chatbot using a state graph and Hugging Face Endpoint."
)
if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link in addition
    # to the local server.
    interface.launch(share=True)