import os
import gradio as gr
import json
from typing import Annotated
from typing_extensions import TypedDict
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import ToolMessage
from dotenv import load_dotenv
import logging
# Initialize logging
logging.basicConfig(level=logging.INFO)
# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("HF_TOKEN environment variable is not set")
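# A typical .env for this Space might look like this (placeholder values, not real keys):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   TAVILY_API_KEY=tvly-xxxxxxxxxxxx
# TavilySearchResults reads TAVILY_API_KEY from the environment.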
# Initialize the HuggingFace model
llm = HuggingFaceEndpoint(
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
huggingfacehub_api_token=HF_TOKEN.strip(),
temperature=0.7,
max_new_tokens=200
)
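# Note: HuggingFaceEndpoint is a text-completion LLM, so a call such as
# llm.invoke("Hello") returns a plain string rather than a chat message object.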
# Initialize Tavily Search tool
tool = TavilySearchResults(max_results=2)
tools = [tool]
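# tool.invoke({"query": "..."}) returns a list of result dicts; each result
# typically includes "url" and "content" keys, roughly:
#   [{"url": "https://example.com", "content": "..."}, ...]
# (shape as documented for TavilySearchResults; treat this as a sketch)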
# Define the state structure
class State(TypedDict):
    messages: Annotated[list, add_messages]
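# add_messages is a reducer: when a node returns {"messages": [new_msg]},
# LangGraph appends new_msg to the existing list instead of replacing it.
# For example, two sequential updates of one message each yield a two-message history.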
# Create a state graph builder
graph_builder = StateGraph(State)
# Define the chatbot function
def chatbot(state: State):
    try:
        # Get the last message and normalize it to a plain string query
        input_message = state["messages"][-1] if state["messages"] else ""
        if isinstance(input_message, str):
            query = input_message  # Already a plain string
        elif hasattr(input_message, "content") and isinstance(input_message.content, str):
            query = input_message.content  # Extract content from a message object (e.g. HumanMessage)
        else:
            raise ValueError("Input message is not in the correct format")
        logging.info(f"Input Message: {query}")
        # Invoke the LLM for a response (HuggingFaceEndpoint returns a plain string)
        response = llm.invoke(query)
        logging.info(f"LLM Response: {response}")
        # Run Tavily Search on the same query
        search_results = tool.invoke({"query": query})
        # Extract URLs from the search results
        urls = [result.get("url", "No URL found") for result in search_results]
        # Fold the source URLs into the assistant message content so the
        # message coerces cleanly and the UI can display the sources
        content = f"{response}\n\nSources: {', '.join(urls)}" if urls else response
        result_with_url = {"role": "assistant", "content": content}
        # Return only the new message: the add_messages reducer appends it to
        # the existing history (returning the full list would duplicate messages)
        return {"messages": [result_with_url]}
    except Exception as e:
        logging.error(f"Error: {str(e)}")
        return {"messages": [{"role": "assistant", "content": f"Error: {str(e)}"}]}
# Define a node that executes the tools requested by the model
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, inputs: dict):
        if messages := inputs.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")
        outputs = []
        for tool_call in message.tool_calls:
            # Look up the requested tool by name and run it with the given args
            tool_result = self.tools_by_name[tool_call["name"]].invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=json.dumps(tool_result),
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}
# Add tool node to the graph
tool_node = BasicToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)
# Define the conditional routing function
def route_tools(state: State):
    """
    Route to the ToolNode if the last message has tool calls.
    Otherwise, route to the end.
    """
    if isinstance(state, list):
        ai_message = state[-1]
    elif messages := state.get("messages", []):
        ai_message = messages[-1]
    else:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return END
# Add nodes and conditional edges to the state graph
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_conditional_edges(
    "chatbot",
    route_tools,
    {"tools": "tools", END: END},
)
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge(START, "chatbot")
graph = graph_builder.compile()
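# Optional smoke test (hypothetical query; run manually outside the Gradio app):
# result = graph.invoke({"messages": [{"role": "user", "content": "What is LangGraph?"}]})
# print(result["messages"][-1])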
# Gradio interface
def chat_interface(input_text, state):
    # Initialize state on first use
    if state is None:
        state = {"messages": []}
    # Append the user input as a message dict so it coerces to a HumanMessage
    state["messages"].append({"role": "user", "content": input_text})
    # Run the state through the compiled graph
    updated_state = graph.invoke(state)
    # The last message may be a message object or a dict; show its text content
    last_message = updated_state["messages"][-1]
    reply = last_message.content if hasattr(last_message, "content") else str(last_message)
    return reply, updated_state
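# Example call (hypothetical input), independent of the Gradio UI:
# reply, new_state = chat_interface("Tell me about Gradio", None)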
# Create Gradio app
with gr.Blocks() as demo:
    gr.Markdown("### Chatbot with Tavily Search Integration")
    chat_state = gr.State({"messages": []})
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
            submit_button = gr.Button("Submit")
        with gr.Column():
            chatbot_output = gr.Textbox(label="Chatbot Response", interactive=False, lines=4)
    submit_button.click(chat_interface, inputs=[user_input, chat_state], outputs=[chatbot_output, chat_state])
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()