pratikshahp commited on
Commit
0458e3c
·
verified ·
1 Parent(s): 4ead768

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Annotated
3
+ from typing_extensions import TypedDict
4
+ from langgraph.graph import StateGraph, START, END
5
+ from langgraph.graph.message import add_messages
6
+ from langchain_huggingface import HuggingFaceEndpoint
7
+ from dotenv import load_dotenv
8
+ import logging
9
+ import gradio as gr
10
+
11
# Initialize logging
logging.basicConfig(level=logging.INFO)

# Load environment variables (.env file, if present) so HF_TOKEN is available.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Fail fast with a clear message instead of crashing later with
# "AttributeError: 'NoneType' object has no attribute 'strip'" when the
# variable is missing.
if not HF_TOKEN:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; "
        "add it to your environment or .env file."
    )

# Initialize the Hugging Face inference endpoint used as the chat LLM.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),  # strip stray whitespace/newlines
    temperature=0.7,     # moderate sampling randomness
    max_new_tokens=200,  # cap on generated tokens per reply
)
25
+
26
# Define the state structure shared by all graph nodes.
class State(TypedDict):
    # Conversation history. The `add_messages` reducer makes node returns of
    # {"messages": [...]} append to this list rather than replace it.
    messages: Annotated[list, add_messages]
29
+
30
# Create a state graph builder keyed on the State schema above.
graph_builder = StateGraph(State)
32
+
33
# Define the chatbot node: one LLM round-trip per graph step.
def chatbot(state: State):
    """Invoke the LLM with the accumulated conversation messages.

    Parameters
    ----------
    state : State
        Graph state; ``state["messages"]`` holds the conversation so far.

    Returns
    -------
    dict
        ``{"messages": [...]}`` — merged into state via ``add_messages``.
        On failure the list carries an ``"Error: ..."`` string instead of
        letting the exception abort the graph run.
    """
    try:
        # Lazy %-style args avoid formatting cost when INFO is disabled.
        logging.info("Input Messages: %s", state["messages"])
        response = llm.invoke(state["messages"])
        logging.info("LLM Response: %s", response)
        return {"messages": [response]}
    except Exception as e:
        # logging.exception records the full traceback, not just the message.
        logging.exception("Error: %s", str(e))
        return {"messages": [f"Error: {str(e)}"]}
43
+
44
# Add nodes and edges to the state graph:
# START -> chatbot -> END (a single-node linear flow).
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("chatbot", END)

# Compile the state graph into a runnable object.
graph = graph_builder.compile()
51
+
52
# Run the graph on one user message and surface the assistant's final reply.
def stream_graph_updates(user_input: str):
    """Run the compiled graph on *user_input* and return the last reply.

    Every streamed event is inspected and only the content of the most
    recent message is kept, so the final assistant answer wins.
    """
    reply = ""
    initial_state = {"messages": [("user", user_input)]}
    for event in graph.stream(initial_state):
        for node_output in event.values():
            latest_message = node_output["messages"][-1]
            reply = latest_message.content
    return reply
62
+
63
# Gradio entry point: route one user message through the state graph.
def gradio_chatbot(user_message: str):
    """Process *user_message* through the graph and return the reply.

    Any failure is logged and reported back to the UI as an
    ``"Error: ..."`` string rather than raising.
    """
    try:
        reply = stream_graph_updates(user_message)
    except Exception as exc:
        logging.error(f"Error in Gradio chatbot: {str(exc)}")
        return f"Error: {str(exc)}"
    return reply
73
+
74
# Create the Gradio UI: one textbox in, one textbox out, backed by
# gradio_chatbot above.
interface = gr.Interface(
    fn=gradio_chatbot,
    inputs=gr.Textbox(placeholder="Enter your message", label="Your Message"),
    outputs=gr.Textbox(label="Assistant's Reply"),
    title="Chatbot",
    description="Interactive chatbot using a state graph and Hugging Face Endpoint."
)
82
+
83
# Launch the web UI only when run as a script (not on import).
if __name__ == "__main__":
    interface.launch()