SRA25 committed on
Commit 6c98f1c · verified · 1 Parent(s): 926b19a

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +256 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,258 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st

- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
  import streamlit as st
+ import uuid
+ import hashlib
+ from typing import List, Optional, Dict, Any, TypedDict, Generic, TypeVar
+ from huggingface_hub import login
+ import logging
+ import time
+ import os
+ from dotenv import load_dotenv
+ from mydomain_agent import upload_documents, submit_feedback, get_conversations, get_conversation_history, chat_with_rag

+ load_dotenv()
+
+ # --- 2. Streamlit UI Components and State Management ---
+ st.set_page_config(page_title="Agentic WorkFlow", layout="wide")
+ st.title("💬 Domain-Aware AI Agent")
+ st.caption("Your expert assistant across HR, Finance, and Legal Compliance.")
+
+ # Initialize session state for conversations, messages, and the current session ID
+ if "conversations" not in st.session_state:
+     st.session_state.conversations = []
+ if "session_id" not in st.session_state:
+     st.session_state.session_id = str(uuid.uuid4())
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ if "retriever_ready" not in st.session_state:
+     st.session_state.retriever_ready = False
+ if "feedback_given" not in st.session_state:
+     st.session_state.feedback_given = {}
+ # New state variable to handle negative feedback comments
+ if "negative_feedback_for" not in st.session_state:
+     st.session_state.negative_feedback_for = None
+
+ # Initialize session state for storing uploaded file hashes
+ if 'uploaded_file_hashes' not in st.session_state:
+     st.session_state.uploaded_file_hashes = set()
+ if 'uploaded_files_info' not in st.session_state:
+     st.session_state.uploaded_files_info = []
+
+ def get_file_hash(file):
+     """Generates a unique hash for a file using its name, size, and content."""
+     hasher = hashlib.sha256()
+     # Read a small chunk of the file to ensure content-based uniqueness
+     # Combine with file name and size for a robust identifier
+     file_content = file.getvalue()
+     hasher.update(file.name.encode('utf-8'))
+     hasher.update(str(file.size).encode('utf-8'))
+     hasher.update(file_content[:1024])  # Use first 1KB of content
+     return hasher.hexdigest()
+ # --- 3. Helper Functions for Backend Communication ---
+ # def send_documents_to_backend(uploaded_files):
+ #     try:
+ #         for file in uploaded_files:
+ #             process_status = upload_documents(file)
+ #         return process_status
+ #     except Exception as e:
+ #         st.error(f"Error processing documents: {e}")
+ #         return None
+
+ def send_chat_message_to_backend(prompt: str, chat_history: List[Dict[str, Any]]):
+     """Sends a chat message to the backend and handles the response."""
+     if not prompt.strip():
+         # Return a (agent_name, response) pair so the caller's unpacking always works
+         return None, {"empty": "Invalid Question"}
+     history_for_api = [
+         {"role": msg.get("role"), "content": msg.get("content")}
+         for msg in chat_history
+     ]
+
+     payload = {
+         "user_question": str(prompt),
+         "session_id": st.session_state.session_id,
+         "chat_history": history_for_api,
+     }
+     print(f"Sending payload: {payload}")  # Debug print
+     try:
+         # The backend call itself sits inside the try block so failures are caught
+         agent_name, response_dict = chat_with_rag(payload)
+         return agent_name, response_dict
+     except Exception as e:
+         st.error("Error communicating with the backend")
+         print(f"Error communicating with the backend: {e}")
+         return None, None
+
+ def send_feedback_to_backend(telemetry_entry_id: str, feedback_score: int, feedback_text: Optional[str] = None):
+     """Sends feedback to the backend."""
+     payload = {
+         "session_id": st.session_state.session_id,
+         "telemetry_entry_id": telemetry_entry_id,
+         "feedback_score": feedback_score,
+         "feedback_text": feedback_text
+     }
+     try:
+         # response = requests.post(f"{API_URL}/feedback", json=payload)
+         response = submit_feedback(payload)
+         # response.raise_for_status()
+         st.toast("Feedback submitted! Thank you.")
+     except Exception as e:
+         st.error(f"Error submitting feedback: {e}")
+
+ def get_conversations_from_backend() -> list:
+     """Fetches a list of all conversations from the backend."""
+     try:
+         # response = requests.get(f"{API_URL}/conversations")
+         response = get_conversations()
+         # response.raise_for_status()
+         return response
+     except Exception as e:
+         st.sidebar.error(f"Error fetching conversations: {e}")
+         return []
+
+ def get_conversation_history_from_backend(session_id: str):
+     """Fetches the messages for a specific conversation ID."""
+     try:
+         # response = requests.get(f"{API_URL}/conversations/{session_id}")
+         response = get_conversation_history(session_id)
+         return response
+     except Exception as e:
+         st.error(f"Error loading conversation history: {e}")
+         return None
+
+ def handle_positive_feedback(telemetry_id):
+     """Handles positive feedback submission."""
+     send_feedback_to_backend(telemetry_id, 1)
+     st.session_state.feedback_given[telemetry_id] = True
+
+
+ def handle_negative_feedback_comment_submit(telemetry_id, comment_text):
+     """Handles the negative feedback comment submission."""
+     send_feedback_to_backend(telemetry_id, -1, comment_text)
+     st.session_state.feedback_given[telemetry_id] = True
+     st.session_state.negative_feedback_for = None
+
+
+ def refresh_conversations():
+     """Refreshes the conversation list in the sidebar."""
+     st.session_state.conversations = get_conversations_from_backend()
+
+ # --- 4. Sidebar for Document Upload and Conversation History ---
+ with st.sidebar:
+     st.header("Load Documents")
+     if st.button("Process Documents", key="process_docs_button"):
+         newmsg, status = upload_documents()
+         if status:
+             st.session_state.retriever_ready = True
+             # st.success(response_data.get("message", "Documents processed and knowledge base ready!"))
+             st.success(newmsg)
+             st.session_state.messages = []
+             refresh_conversations()  # SQL query needs to be added here
+         else:
+             st.session_state.retriever_ready = False
+             st.error(newmsg)
+     else:
+         st.warning("Please load documents.")
+
+     st.markdown("---")
+     st.header("Conversations")
+     if st.button("➕ New Chat", key="new_chat_button", use_container_width=True, type="primary"):
+         st.session_state.session_id = str(uuid.uuid4())
+         st.session_state.messages = []
+         st.session_state.feedback_given = {}
+         st.session_state.negative_feedback_for = None
+         refresh_conversations()
+         st.rerun()
+
+     refresh_conversations()
+
+     if st.session_state.conversations:
+         for conv in st.session_state.conversations:
+             if st.button(
+                 conv["title"],
+                 key=f"conv_{conv['session_id']}",
+                 use_container_width=True
+             ):
+                 if st.session_state.session_id != conv["session_id"]:
+                     st.session_state.session_id = conv["session_id"]
+                     history = get_conversation_history_from_backend(conv["session_id"])
+                     if history:  # only load when a non-empty history came back
+                         st.session_state.messages = history
+                         st.session_state.feedback_given = {msg.get("telemetry_id"): True for msg in history if msg.get("telemetry_id")}
+                     else:
+                         st.session_state.messages = []
+                     st.session_state.negative_feedback_for = None
+                     st.rerun()
+
+ # --- 5. Main Chat Interface ---
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+         # Display feedback buttons for assistant responses that have not received feedback yet
+         if message["role"] == "assistant" and message.get("telemetry_id") and not st.session_state.feedback_given.get(message["telemetry_id"], False):
+             col1, col2 = st.columns(2)
+             with col1:
+                 if st.button("👍", key=f"positive_{message['telemetry_id']}", on_click=handle_positive_feedback, args=(message['telemetry_id'],)):
+                     pass
+             with col2:
+                 if st.button("👎", key=f"negative_{message['telemetry_id']}"):
+                     st.session_state.negative_feedback_for = message['telemetry_id']
+                     st.rerun()
+
+             # --- NEW LOGIC FOR NEGATIVE FEEDBACK COMMENT ---
+             # Only render the comment input if this is the message the user clicked thumbs down on
+             if st.session_state.negative_feedback_for == message['telemetry_id']:
+                 with st.container():
+                     comment = st.text_area(
+                         "Please provide some details (optional):",
+                         key=f"feedback_text_{message['telemetry_id']}"
+                     )
+                     if st.button("Submit Comment", key=f"submit_feedback_button_{message['telemetry_id']}"):
+                         handle_negative_feedback_comment_submit(message['telemetry_id'], comment)
+
+ # Chat input for new questions
+ if st.session_state.retriever_ready:
+     if prompt := st.chat_input("Ask me anything about the uploaded documents..."):
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 agent_name, response_data = send_chat_message_to_backend(prompt, st.session_state.messages)
+                 if response_data:
+                     if response_data.get("is_restricted"):
+                         ai_response = response_data.get("ai_response", "Sorry, I couldn't generate a response.")
+                         reason = response_data.get("moderation_reason")
+                         st.markdown(ai_response)
+                         st.markdown(reason)
+                     elif response_data.get("empty"):
+                         st.markdown(response_data.get("empty"))
+
+                     ai_response = response_data.get("ai_response", "Sorry, I couldn't generate a response.")
+                     telemetry_id = response_data.get("telemetry_entry_id")
+
+                     st.markdown(ai_response)
+                     st.caption(agent_name)
+
+                     st.session_state.messages.append({
+                         "role": "assistant",
+                         "content": ai_response,
+                         "telemetry_id": telemetry_id
+                     })
+
+                     refresh_conversations()
+
+                     if telemetry_id:
+                         col1, col2 = st.columns(2)
+                         with col1:
+                             if st.button("👍", key=f"positive_{telemetry_id}", on_click=handle_positive_feedback, args=(telemetry_id,)):
+                                 pass
+                         with col2:
+                             if st.button("👎", key=f"negative_{telemetry_id}"):
+                                 st.session_state.negative_feedback_for = telemetry_id
+                                 st.rerun()
+                 else:
+                     st.markdown("An error occurred.")
+ else:
+     st.info("Please upload and process documents to start chatting.")
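
The app imports its backend from `mydomain_agent`, which is not part of this commit. Below is a minimal stub of that interface, inferred purely from the call sites in the diff above; the function signatures, return shapes, and placeholder bodies are assumptions for local testing, not the project's actual implementation.

# mydomain_agent.py -- illustrative stub only; signatures inferred from streamlit_app.py call sites.
import uuid
from typing import Any, Dict, List, Tuple

def upload_documents() -> Tuple[str, bool]:
    """Assumed contract: build the knowledge base and return (status message, success flag)."""
    return "Documents processed and knowledge base ready!", True

def chat_with_rag(payload: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Assumed contract: answer payload["user_question"] and return (agent name, response dict)."""
    return "HR Agent", {
        "ai_response": f"Echo: {payload['user_question']}",
        "telemetry_entry_id": str(uuid.uuid4()),
        "is_restricted": False,
        "moderation_reason": None,
    }

def submit_feedback(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Assumed contract: persist a feedback record for a telemetry entry."""
    return {"status": "ok"}

def get_conversations() -> List[Dict[str, str]]:
    """Assumed contract: return sidebar entries, each with a session_id and a title."""
    return [{"session_id": str(uuid.uuid4()), "title": "Example conversation"}]

def get_conversation_history(session_id: str) -> List[Dict[str, Any]]:
    """Assumed contract: return the stored messages for one session."""
    return [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!", "telemetry_id": str(uuid.uuid4())},
    ]

With a stub like this on the Python path, `src/streamlit_app.py` can be exercised end to end without the real retrieval backend.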