jaelin215 commited on
Commit
00e4075
1 Parent(s): ba852a7

updated from dagshub

Browse files
Files changed (3) hide show
  1. llama_guard.py +99 -29
  2. requirements.txt +6 -2
  3. streamlit_app.py +344 -0
llama_guard.py CHANGED
@@ -1,44 +1,82 @@
 
 
 
 
1
  from dotenv import load_dotenv, find_dotenv
2
  import os
3
  import requests
4
 
 
 
 
5
  # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
6
  load_dotenv(find_dotenv())
7
  HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
8
 
9
- with open("data/llama_guard-unsafe_files.txt", "r") as file:
10
- unsafe_categories = file.read()
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
13
  def moderation_prompt_for_chat(chat):
14
- # For simplicity, we assume the chat dictionary correctly alternates "user" and "assistant" roles
15
- # Feel free to include error checking if needed
16
- conversation = [turn["content"] for turn in chat]
17
- is_agent = len(conversation) % 2 == 0
18
- role = "Agent" if len(conversation) % 2 == 0 else "User"
19
- prompt = f"""[INST] Task: Check if there is unsafe content in '{role}' messages in conversations according our safety policy with the below categories.
20
-
21
- <BEGIN UNSAFE CONTENT CATEGORIES>
22
- {unsafe_categories}
23
- <END UNSAFE CONTENT CATEGORIES>
24
-
25
- <BEGIN CONVERSATION>
26
-
27
- """
28
-
29
- # Alternate User/Agent turns, inserting 2 newlines between each
30
- for i, m in enumerate(conversation):
31
- role = "User" if i % 2 == 0 else "Agent"
32
- prompt += f"{role}: {m}\n\n"
33
- prompt += f"""<END CONVERSATION>
34
-
35
- Provide your safety assessment for {role} in the above conversation:
36
- - First line must read 'safe' or 'unsafe'.
37
- - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
38
  return prompt
39
 
40
 
41
  def query(payload):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
43
  bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
44
  headers = {
@@ -47,14 +85,13 @@ def query(payload):
47
  "Content-Type": "application/json",
48
  }
49
  response = requests.post(API_URL, headers=headers, json=payload)
50
-
51
  return response.json()
52
 
53
 
54
  def moderate_chat(chat):
55
  prompt = moderation_prompt_for_chat(chat)
56
 
57
- output = query(
58
  {
59
  "inputs": prompt,
60
  "parameters": {
@@ -66,4 +103,37 @@ def moderate_chat(chat):
66
  }
67
  )
68
 
69
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ # Docs:- https://huggingface.co/meta-llama/LlamaGuard-7b
5
  from dotenv import load_dotenv, find_dotenv
6
  import os
7
  import requests
8
 
9
+ model_id = "meta-llama/LlamaGuard-7b"
10
+
11
+
12
  # Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
13
  load_dotenv(find_dotenv())
14
  HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
15
 
16
+ # updated on March 24th
 
17
 
18
+ # reading unsafe categories file
19
+ try:
20
+ with open("data/processed/llama_guard-unsafe_categories.txt", "r") as file:
21
+ unsafe_categories = file.read()
22
+ except FileNotFoundError:
23
+ print("File not found. Please check the file path.")
24
+ unsafe_categories = (
25
+ "" # Or handle the missing file as appropriate for your application
26
+ )
27
+ except PermissionError:
28
+ print("Permission denied. You don't have the rights to read the file.")
29
+ unsafe_categories = "" # Or handle the lack of permissions as appropriate
30
+ except Exception as e: # Catch any other exceptions
31
+ print(f"An error occurred while reading the file: {e}")
32
+ unsafe_categories = "" # Or handle unexpected errors as appropriate
33
 
34
+
35
+ # added on March 24th
36
  def moderation_prompt_for_chat(chat):
37
+ prompt = f"""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
38
+ <BEGIN UNSAFE CONTENT CATEGORIES>
39
+ {unsafe_categories}
40
+ <END UNSAFE CONTENT CATEGORIES>
41
+
42
+ <BEGIN CONVERSATION>
43
+
44
+ User : {chat}
45
+
46
+ <END CONVERSATION>
47
+
48
+ Provide your safety assessment for 'User' in the above conversation:
49
+ - First line must read 'safe' or 'unsafe'.
50
+ - If unsafe, a second line must include a comma-separated list of violated categories.[/INST]"""
 
 
 
 
 
 
 
 
 
 
51
  return prompt
52
 
53
 
54
  def query(payload):
55
+ API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
56
+ bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
57
+ headers = {
58
+ "Accept": "application/json",
59
+ "Authorization": bearer_txt,
60
+ "Content-Type": "application/json",
61
+ }
62
+ try:
63
+ response = requests.post(API_URL, headers=headers, json=payload)
64
+ response.raise_for_status() # This will raise an exception for HTTP error responses
65
+ return response.json(), None
66
+ except requests.exceptions.HTTPError as http_err:
67
+ error_message = f"HTTP error occurred: {http_err}"
68
+ print(error_message)
69
+ except requests.exceptions.ConnectionError:
70
+ error_message = "Could not connect to the API endpoint."
71
+ print(error_message)
72
+ except Exception as err:
73
+ error_message = f"An error occurred: {err}"
74
+ print(error_message)
75
+
76
+ return None, error_message
77
+
78
+
79
+ def query1(payload):
80
  API_URL = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
81
  bearer_txt = f"Bearer {HUGGINGFACEHUB_API_TOKEN}"
82
  headers = {
 
85
  "Content-Type": "application/json",
86
  }
87
  response = requests.post(API_URL, headers=headers, json=payload)
 
88
  return response.json()
89
 
90
 
91
  def moderate_chat(chat):
92
  prompt = moderation_prompt_for_chat(chat)
93
 
94
+ output, error_msg = query(
95
  {
96
  "inputs": prompt,
97
  "parameters": {
 
103
  }
104
  )
105
 
106
+ print("Llamaguard prompt****", prompt)
107
+ print("Llamaguard output****", output)
108
+
109
+ return output, error_msg
110
+
111
+
112
+ # added on March 24th
113
+ def load_category_names_from_string(file_content):
114
+ """Load category codes and names from a string into a dictionary."""
115
+ category_names = {}
116
+ lines = file_content.split("\n")
117
+ for line in lines:
118
+ if line.startswith("O"):
119
+ parts = line.split(":")
120
+ if len(parts) == 2:
121
+ code = parts[0].strip()
122
+ name = parts[1].strip()
123
+ category_names[code] = name
124
+ return category_names
125
+
126
+
127
+ def get_category_name(input_str):
128
+ """Return the category name given a category code from an input string."""
129
+ # Load the category names from the file content
130
+ category_names = load_category_names_from_string(unsafe_categories)
131
+
132
+ # Extract the category code from the input string
133
+ category_code = input_str.split("\n")[1].strip()
134
+
135
+ # Find the full category name using the code
136
+ category_name = category_names.get(category_code, "Unknown Category")
137
+
138
+ # return f"{category_code} : {category_name}"
139
+ return f"{category_name}"
requirements.txt CHANGED
@@ -7,6 +7,7 @@ appnope==0.1.4
7
  asttokens==2.4.1
8
  async-timeout==4.0.3
9
  attrs==23.2.0
 
10
  black==24.2.0
11
  blinker==1.7.0
12
  cachetools==5.3.3
@@ -34,6 +35,7 @@ google-auth==2.28.2
34
  googleapis-common-protos==1.62.0
35
  grpcio==1.62.1
36
  grpcio-status==1.62.1
 
37
  h11==0.14.0
38
  httpcore==1.0.4
39
  httpx==0.27.0
@@ -113,9 +115,12 @@ seaborn==0.13.2
113
  six==1.16.0
114
  smmap==5.0.1
115
  sniffio==1.3.1
 
 
116
  SQLAlchemy==2.0.28
117
  stack-data==0.6.3
118
- streamlit==1.32.0
 
119
  sympy==1.12
120
  tenacity==8.2.3
121
  threadpoolctl==3.3.0
@@ -137,4 +142,3 @@ wcwidth==0.2.13
137
  xgboost==2.0.3
138
  yarl==1.9.4
139
  zipp==3.17.0
140
- beautifulsoup4
 
7
  asttokens==2.4.1
8
  async-timeout==4.0.3
9
  attrs==23.2.0
10
+ beautifulsoup4==4.12.3
11
  black==24.2.0
12
  blinker==1.7.0
13
  cachetools==5.3.3
 
35
  googleapis-common-protos==1.62.0
36
  grpcio==1.62.1
37
  grpcio-status==1.62.1
38
+ gTTS==2.5.1
39
  h11==0.14.0
40
  httpcore==1.0.4
41
  httpx==0.27.0
 
115
  six==1.16.0
116
  smmap==5.0.1
117
  sniffio==1.3.1
118
+ soupsieve==2.5
119
+ SpeechRecognition==3.10.1
120
  SQLAlchemy==2.0.28
121
  stack-data==0.6.3
122
+ streamlit==1.32.2
123
+ streamlit_mic_recorder==0.0.8
124
  sympy==1.12
125
  tenacity==8.2.3
126
  threadpoolctl==3.3.0
 
142
  xgboost==2.0.3
143
  yarl==1.9.4
144
  zipp==3.17.0
 
streamlit_app.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###
2
+ # - Author: Jaelin Lee, Abhishek Dutta
3
+ # - Date: Mar 23, 2024
4
+ # - Description: Streamlit UI for mental health support chatbot using sentiment analsys, RL, BM25/ChromaDB, and LLM.
5
+
6
+ # - Note:
7
+ # - Updated to UI to show predicted mental health condition in behind the scence regardless of the ositive/negative sentiment
8
+ ###
9
+
10
+ from dotenv import load_dotenv, find_dotenv
11
+ import pandas as pd
12
+ import streamlit as st
13
+ from q_learning_chatbot import QLearningChatbot
14
+ from xgb_mental_health import MentalHealthClassifier
15
+ from bm25_retreive_question import QuestionRetriever as QuestionRetriever_bm25
16
+ from Chromadb_storage_JyotiNigam import QuestionRetriever as QuestionRetriever_chromaDB
17
+ from llm_response_generator import LLLResponseGenerator
18
+ import os
19
+ from llama_guard import moderate_chat, get_category_name
20
+
21
+ from gtts import gTTS
22
+ from io import BytesIO
23
+ from streamlit_mic_recorder import speech_to_text
24
+
25
+ import re
26
+
27
+ # Streamlit UI
28
+ st.title("MindfulMedia Mentor")
29
+
30
+ # Define states and actions
31
+ states = [
32
+ "Negative",
33
+ "Moderately Negative",
34
+ "Neutral",
35
+ "Moderately Positive",
36
+ "Positive",
37
+ ]
38
+ actions = ["encouragement", "empathy", "spiritual"]
39
+
40
+ # Initialize Q-learning chatbot and mental health classifier
41
+ chatbot = QLearningChatbot(states, actions)
42
+
43
+ # Initialize MentalHealthClassifier
44
+ # data_path = "/Users/jaelinlee/Documents/projects/fomo/input/data.csv"
45
+ data_path = os.path.join("data", "processed", "data.csv")
46
+ print(data_path)
47
+
48
+ tokenizer_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
49
+ mental_classifier_model_path = "app/mental_health_model.pkl"
50
+ mental_classifier = MentalHealthClassifier(data_path, mental_classifier_model_path)
51
+
52
+
53
+ # Function to display Q-table
54
+ def display_q_table(q_values, states, actions):
55
+ q_table_dict = {"State": states}
56
+ for i, action in enumerate(actions):
57
+ q_table_dict[action] = q_values[:, i]
58
+
59
+ q_table_df = pd.DataFrame(q_table_dict)
60
+ return q_table_df
61
+
62
+
63
+ def text_to_speech(text):
64
+ # Use gTTS to convert text to speech
65
+ tts = gTTS(text=text, lang="en")
66
+ # Save the speech as bytes in memory
67
+ fp = BytesIO()
68
+ tts.write_to_fp(fp)
69
+ return fp
70
+
71
+
72
+ def speech_recognition_callback():
73
+ # Ensure that speech output is available
74
+ if st.session_state.my_stt_output is None:
75
+ st.session_state.p01_error_message = "Please record your response again."
76
+ return
77
+
78
+ # Clear any previous error messages
79
+ st.session_state.p01_error_message = None
80
+
81
+ # Store the speech output in the session state
82
+ st.session_state.speech_input = st.session_state.my_stt_output
83
+
84
+
85
+ def remove_html_tags(text):
86
+ clean_text = re.sub("<.*?>", "", text)
87
+ return clean_text
88
+
89
+
90
+ # Initialize memory
91
+ if "entered_text" not in st.session_state:
92
+ st.session_state.entered_text = []
93
+ if "entered_mood" not in st.session_state:
94
+ st.session_state.entered_mood = []
95
+ if "messages" not in st.session_state:
96
+ st.session_state.messages = []
97
+ if "user_sentiment" not in st.session_state:
98
+ st.session_state.user_sentiment = "Neutral"
99
+ if "mood_trend" not in st.session_state:
100
+ st.session_state.mood_trend = "Unchanged"
101
+ if "predicted_mental_category" not in st.session_state:
102
+ st.session_state.predicted_mental_category = ""
103
+ if "ai_tone" not in st.session_state:
104
+ st.session_state.ai_tone = "Empathy"
105
+ if "mood_trend_symbol" not in st.session_state:
106
+ st.session_state.mood_trend_symbol = ""
107
+ if "show_question" not in st.session_state:
108
+ st.session_state.show_question = False
109
+ if "asked_questions" not in st.session_state:
110
+ st.session_state.asked_questions = []
111
+ # Check if 'llama_guard_enabled' is already in session state, otherwise initialize it
112
+ if "llama_guard_enabled" not in st.session_state:
113
+ st.session_state["llama_guard_enabled"] = True # Default value to True
114
+
115
+ # Select Question Retriever
116
+ selected_retriever_option = st.sidebar.selectbox(
117
+ "Choose Question Retriever", ("BM25", "ChromaDB")
118
+ )
119
+ if selected_retriever_option == "BM25":
120
+ retriever = QuestionRetriever_bm25()
121
+ if selected_retriever_option == "ChromaDB":
122
+ retriever = QuestionRetriever_chromaDB()
123
+
124
+ for message in st.session_state.messages:
125
+ with st.chat_message(message.get("role")):
126
+ st.write(message.get("content"))
127
+
128
+ section_visible = True
129
+
130
+ # Collect user input
131
+ # Add a radio button to choose input mode
132
+ input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])
133
+ user_message = None
134
+ if input_mode == "Speech":
135
+ # Use the speech_to_text function to capture speech input
136
+ speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback)
137
+ # Check if speech input is available
138
+ if "speech_input" in st.session_state and st.session_state.speech_input:
139
+ # Display the speech input
140
+ # st.text(f"Speech Input: {st.session_state.speech_input}")
141
+
142
+ # Process the speech input as a query
143
+ user_message = st.session_state.speech_input
144
+ st.session_state.speech_input = None
145
+ else:
146
+ user_message = st.chat_input("Type your message here:")
147
+
148
+
149
+ # Modify the checkbox call to include a unique key parameter
150
+ llama_guard_enabled = st.sidebar.checkbox(
151
+ "Enable LlamaGuard",
152
+ value=st.session_state["llama_guard_enabled"],
153
+ key="llama_guard_toggle",
154
+ )
155
+
156
+
157
+ # Update the session state based on the checkbox interaction
158
+ st.session_state["llama_guard_enabled"] = llama_guard_enabled
159
+
160
+ # Take user input
161
+ if user_message:
162
+ st.session_state.entered_text.append(user_message)
163
+
164
+ st.session_state.messages.append({"role": "user", "content": user_message})
165
+ with st.chat_message("user"):
166
+ st.write(user_message)
167
+
168
+ is_safe = True
169
+ if st.session_state["llama_guard_enabled"]:
170
+ # guard_status = moderate_chat(user_prompt)
171
+ guard_status, error = moderate_chat(user_message)
172
+ if error:
173
+ st.error(f"Failed to retrieve data from Llama Guard: {error}")
174
+ else:
175
+ if "unsafe" in guard_status[0]["generated_text"]:
176
+ is_safe = False
177
+ # added on March 24th
178
+ unsafe_category_name = get_category_name(
179
+ guard_status[0]["generated_text"]
180
+ )
181
+
182
+ if is_safe == False:
183
+ response = f"I see you are asking something about {unsafe_category_name} Due to eithical and safety reasons, I can't provide the help you need. Please reach out to someone who can, like a family member, friend, or therapist. In urgent situations, contact emergency services or a crisis hotline. Remember, asking for help is brave, and you're not alone."
184
+ st.session_state.messages.append({"role": "ai", "content": response})
185
+ with st.chat_message("ai"):
186
+ st.markdown(response)
187
+ speech_fp = text_to_speech(response)
188
+ # Play the speech
189
+ st.audio(speech_fp, format="audio/mp3")
190
+ else:
191
+ # Detect mental condition
192
+ with st.spinner("Processing..."):
193
+ mental_classifier.initialize_tokenizer(tokenizer_model_name)
194
+ mental_classifier.preprocess_data()
195
+ predicted_mental_category = mental_classifier.predict_category(user_message)
196
+ print("Predicted mental health condition:", predicted_mental_category)
197
+
198
+ # Detect sentiment
199
+ user_sentiment = chatbot.detect_sentiment(user_message)
200
+
201
+ # Retrieve question
202
+ if user_sentiment in ["Negative", "Moderately Negative", "Neutral"]:
203
+ question = retriever.get_response(
204
+ user_message, predicted_mental_category
205
+ )
206
+ show_question = True
207
+ else:
208
+ show_question = False
209
+ question = ""
210
+ # predicted_mental_category = ""
211
+
212
+ # Update mood history / mood_trend
213
+ chatbot.update_mood_history()
214
+ mood_trend = chatbot.check_mood_trend()
215
+
216
+ # Define rewards
217
+ if user_sentiment in ["Positive", "Moderately Positive"]:
218
+ if mood_trend == "increased":
219
+ reward = +1
220
+ mood_trend_symbol = " ⬆️"
221
+ elif mood_trend == "unchanged":
222
+ reward = +0.8
223
+ mood_trend_symbol = ""
224
+ else: # decreased
225
+ reward = -0.2
226
+ mood_trend_symbol = " ⬇️"
227
+ else:
228
+ if mood_trend == "increased":
229
+ reward = +1
230
+ mood_trend_symbol = " ⬆️"
231
+ elif mood_trend == "unchanged":
232
+ reward = -0.2
233
+ mood_trend_symbol = ""
234
+ else: # decreased
235
+ reward = -1
236
+ mood_trend_symbol = " ⬇️"
237
+
238
+ print(
239
+ f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
240
+ )
241
+
242
+ # Update Q-values
243
+ chatbot.update_q_values(
244
+ user_sentiment, chatbot.actions[0], reward, user_sentiment
245
+ )
246
+
247
+ # Get recommended action based on the updated Q-values
248
+ ai_tone = chatbot.get_action(user_sentiment)
249
+ print(ai_tone)
250
+
251
+ print(st.session_state.messages)
252
+
253
+ # LLM Response Generator
254
+ load_dotenv(find_dotenv())
255
+
256
+ llm_model = LLLResponseGenerator()
257
+ temperature = 0.1
258
+ max_length = 128
259
+
260
+ # Collect all messages exchanged so far into a single text string
261
+ all_messages = "\n".join(
262
+ [message.get("content") for message in st.session_state.messages]
263
+ )
264
+
265
+ # Question asked to the user: {question}
266
+
267
+ template = """INSTRUCTIONS: {context}
268
+
269
+ Respond to the user with a tone of {ai_tone}.
270
+
271
+ Response by the user: {user_text}
272
+ Response;
273
+ """
274
+ context = f"You are a mental health supporting non-medical assistant. Provide some advice and ask a relevant question back to the user. {all_messages}"
275
+
276
+ llm_response = llm_model.llm_inference(
277
+ model_type="huggingface",
278
+ question=question,
279
+ prompt_template=template,
280
+ context=context,
281
+ ai_tone=ai_tone,
282
+ questionnaire=predicted_mental_category,
283
+ user_text=user_message,
284
+ temperature=temperature,
285
+ max_length=max_length,
286
+ )
287
+
288
+ llm_response = remove_html_tags(llm_response)
289
+
290
+ if show_question:
291
+ llm_reponse_with_quesiton = f"{llm_response}\n\n{question}"
292
+ else:
293
+ llm_reponse_with_quesiton = llm_response
294
+
295
+ # Append the user and AI responses to the chat history
296
+ st.session_state.messages.append(
297
+ {"role": "ai", "content": llm_reponse_with_quesiton}
298
+ )
299
+
300
+ with st.chat_message("ai"):
301
+ st.markdown(llm_reponse_with_quesiton)
302
+ # Convert the response to speech
303
+ speech_fp = text_to_speech(llm_reponse_with_quesiton)
304
+ # Play the speech
305
+ st.audio(speech_fp, format="audio/mp3")
306
+ # st.write(f"{llm_response}")
307
+ # if show_question:
308
+ # st.write(f"{question}")
309
+ # else:
310
+ # user doesn't feel negative.
311
+ # get question to ecourage even more positive behaviour
312
+
313
+ # Update data to memory
314
+ st.session_state.user_sentiment = user_sentiment
315
+ st.session_state.mood_trend = mood_trend
316
+ st.session_state.predicted_mental_category = predicted_mental_category
317
+ st.session_state.ai_tone = ai_tone
318
+ st.session_state.mood_trend_symbol = mood_trend_symbol
319
+ st.session_state.show_question = show_question
320
+
321
+ # Show/hide "Behind the Scene" section
322
+ # section_visible = st.sidebar.button('Show/Hide Behind the Scene')
323
+
324
+ with st.sidebar.expander("Behind the Scene", expanded=section_visible):
325
+ st.subheader("What AI is doing:")
326
+ # Use the values stored in session state
327
+ st.write(
328
+ f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})"
329
+ )
330
+ # if st.session_state.show_question:
331
+ st.write(
332
+ f"- Possible Mental Condition: {st.session_state.predicted_mental_category.capitalize()}"
333
+ )
334
+ st.write(f"- AI Tone: {st.session_state.ai_tone.capitalize()}")
335
+ st.write(f"- Question retrieved from: {selected_retriever_option}")
336
+ st.write(
337
+ f"- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, ADHD, Social Media Addiction, Social Isolation, and Cyberbullying which are most associated with FOMO related to excessive social media usage."
338
+ )
339
+ st.write(
340
+ f"- Below q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
341
+ )
342
+
343
+ # Display Q-table
344
+ st.dataframe(display_q_table(chatbot.q_values, states, actions))