SwatGarg commited on
Commit
0ca1b78
1 Parent(s): 6ede6bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -30
app.py CHANGED
@@ -9,12 +9,16 @@ from q_learning_chatbot import QLearningChatbot
9
 
10
  from gtts import gTTS
11
  from io import BytesIO
 
 
 
 
 
 
12
 
13
  mdl = ModelPipeLine()
14
  final_chain = mdl.create_final_chain()
15
 
16
- st.set_page_config(page_title="PeacePal")
17
-
18
  # Define states and actions
19
  states = [
20
  "Negative",
@@ -24,15 +28,35 @@ states = [
24
  "Positive",
25
  ]
26
 
27
- # Add logo to the sidebar
28
- #logo_path = os.path.join('images', 'logo.jpeg')
29
- #st.sidebar.image(logo_path, use_column_width=True)
30
 
31
- # Add image to the sidebar
32
- image_path = os.path.join('images', 'sidebar.jpg')
33
- st.sidebar.image(image_path, use_column_width=True)
 
 
 
 
 
 
 
 
 
 
34
 
35
- st.title('PeacePal 🌱')
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  ## generated stores AI generated responses
38
  if 'generated' not in st.session_state:
@@ -41,6 +65,24 @@ if 'generated' not in st.session_state:
41
  if 'past' not in st.session_state:
42
  st.session_state['past'] = ['Hi!']
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # Layout of input/response containers
45
 
46
  colored_header(label='', description='', color_name='blue-30')
@@ -56,28 +98,7 @@ def get_text():
56
  def generate_response(prompt):
57
  response = mdl.call_conversational_rag(prompt,final_chain)
58
  return response['answer']
59
-
60
- def text_to_speech(text):
61
- # Use gTTS to convert text to speech
62
- tts = gTTS(text=text, lang='en')
63
- # Save the speech as bytes in memory
64
- fp = BytesIO()
65
- tts.write_to_fp(fp)
66
- return fp
67
-
68
- def speech_recognition_callback():
69
- # Ensure that speech output is available
70
- if st.session_state.my_stt_output is None:
71
- st.session_state.p01_error_message = "Please record your response again."
72
- return
73
 
74
- # Clear any previous error messages
75
- st.session_state.p01_error_message = None
76
-
77
- # Store the speech output in the session state
78
- st.session_state.speech_input = st.session_state.my_stt_output
79
-
80
-
81
  ## Applying the user input box
82
  with input_container:
83
  # Add a radio button to choose input mode
@@ -101,11 +122,60 @@ with input_container:
101
  response = generate_response(query)
102
  st.session_state.past.append(query)
103
  st.session_state.generated.append(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  # Convert the response to speech
106
  speech_fp = text_to_speech(response)
107
  # Play the speech
108
  st.audio(speech_fp, format='audio/mp3')
 
109
  else:
110
  # Add a text input field for query
111
  query = st.text_input("Query: ", key="input")
@@ -116,7 +186,59 @@ with input_container:
116
  response = generate_response(query)
117
  st.session_state.past.append(query)
118
  st.session_state.generated.append(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # Convert the response to speech
121
  speech_fp = text_to_speech(response)
122
  # Play the speech
 
9
 
10
  from gtts import gTTS
11
  from io import BytesIO
12
+ st.set_page_config(page_title="PeacePal")
13
+ #image to the sidebar
14
+ image_path = os.path.join('images', 'sidebar.jpg')
15
+ st.sidebar.image(image_path, use_column_width=True)
16
+
17
+ st.title('PeacePal 🌱')
18
 
19
  mdl = ModelPipeLine()
20
  final_chain = mdl.create_final_chain()
21
 
 
 
22
  # Define states and actions
23
  states = [
24
  "Negative",
 
28
  "Positive",
29
  ]
30
 
31
+ # Initialize Q-learning chatbot and mental health classifier
32
+ chatbot = QLearningChatbot(states)
 
33
 
34
+ # Function to display Q-table
35
+ def display_q_table(q_values, states):
36
+ q_table_dict = {"State": states}
37
+ q_table_df = pd.DataFrame(q_table_dict)
38
+ return q_table_df
39
+
40
+ def text_to_speech(text):
41
+ # Use gTTS to convert text to speech
42
+ tts = gTTS(text=text, lang="en")
43
+ # Save the speech as bytes in memory
44
+ fp = BytesIO()
45
+ tts.write_to_fp(fp)
46
+ return fp
47
 
48
+
49
+ def speech_recognition_callback():
50
+ # Ensure that speech output is available
51
+ if st.session_state.my_stt_output is None:
52
+ st.session_state.p01_error_message = "Please record your response again."
53
+ return
54
+
55
+ # Clear any previous error messages
56
+ st.session_state.p01_error_message = None
57
+
58
+ # Store the speech output in the session state
59
+ st.session_state.speech_input = st.session_state.my_stt_output
60
 
61
  ## generated stores AI generated responses
62
  if 'generated' not in st.session_state:
 
65
  if 'past' not in st.session_state:
66
  st.session_state['past'] = ['Hi!']
67
 
68
+ # Initialize memory
69
+ if "entered_text" not in st.session_state:
70
+ st.session_state.entered_text = []
71
+ if "entered_mood" not in st.session_state:
72
+ st.session_state.entered_mood = []
73
+ if "messages" not in st.session_state:
74
+ st.session_state.messages = []
75
+ if "user_sentiment" not in st.session_state:
76
+ st.session_state.user_sentiment = "Neutral"
77
+ if "mood_trend" not in st.session_state:
78
+ st.session_state.mood_trend = "Unchanged"
79
+ if "mood_trend_symbol" not in st.session_state:
80
+ st.session_state.mood_trend_symbol = ""
81
+ if "show_question" not in st.session_state:
82
+ st.session_state.show_question = False
83
+ if "asked_questions" not in st.session_state:
84
+ st.session_state.asked_questions = []
85
+
86
  # Layout of input/response containers
87
 
88
  colored_header(label='', description='', color_name='blue-30')
 
98
  def generate_response(prompt):
99
  response = mdl.call_conversational_rag(prompt,final_chain)
100
  return response['answer']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
 
 
 
 
 
102
  ## Applying the user input box
103
  with input_container:
104
  # Add a radio button to choose input mode
 
122
  response = generate_response(query)
123
  st.session_state.past.append(query)
124
  st.session_state.generated.append(response)
125
+ # Detect sentiment
126
+ user_sentiment = chatbot.detect_sentiment(user_message)
127
+
128
+ # Retrieve question
129
+ if user_sentiment in ["Negative", "Moderately Negative", "Neutral"]:
130
+ question = retriever.get_response(
131
+ user_message, predicted_mental_category
132
+ )
133
+ st.session_state.asked_questions.append(question)
134
+ show_question = True
135
+ else:
136
+ show_question = False
137
+ question = ""
138
 
139
+ # Update mood history / mood_trend
140
+ chatbot.update_mood_history()
141
+ mood_trend = chatbot.check_mood_trend()
142
+
143
+ # Define rewards
144
+ if user_sentiment in ["Positive", "Moderately Positive"]:
145
+ if mood_trend == "increased":
146
+ reward = +1
147
+ mood_trend_symbol = " ⬆️"
148
+ elif mood_trend == "unchanged":
149
+ reward = +0.8
150
+ mood_trend_symbol = ""
151
+ else: # decreased
152
+ reward = -0.2
153
+ mood_trend_symbol = " ⬇️"
154
+ else:
155
+ if mood_trend == "increased":
156
+ reward = +1
157
+ mood_trend_symbol = " ⬆️"
158
+ elif mood_trend == "unchanged":
159
+ reward = -0.2
160
+ mood_trend_symbol = ""
161
+ else: # decreased
162
+ reward = -1
163
+ mood_trend_symbol = " ⬇️"
164
+
165
+ print(
166
+ f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
167
+ )
168
+
169
+ # Update Q-values
170
+ chatbot.update_q_values(
171
+ user_sentiment, reward, user_sentiment
172
+ )
173
+
174
  # Convert the response to speech
175
  speech_fp = text_to_speech(response)
176
  # Play the speech
177
  st.audio(speech_fp, format='audio/mp3')
178
+
179
  else:
180
  # Add a text input field for query
181
  query = st.text_input("Query: ", key="input")
 
186
  response = generate_response(query)
187
  st.session_state.past.append(query)
188
  st.session_state.generated.append(response)
189
+ # Detect sentiment
190
+ user_sentiment = chatbot.detect_sentiment(user_message)
191
+
192
+ # Retrieve question
193
+ if user_sentiment in ["Negative", "Moderately Negative", "Neutral"]:
194
+ question = retriever.get_response(
195
+ user_message, predicted_mental_category
196
+ )
197
+ st.session_state.asked_questions.append(question)
198
+ show_question = True
199
+ else:
200
+ show_question = False
201
+ question = ""
202
+ # Convert the response to speech
203
+ speech_fp = text_to_speech(response)
204
+ # Play the speech
205
+ st.audio(speech_fp, format='audio/mp3')
206
 
207
+ # Update mood history / mood_trend
208
+ chatbot.update_mood_history()
209
+ mood_trend = chatbot.check_mood_trend()
210
+
211
+ # Define rewards
212
+ if user_sentiment in ["Positive", "Moderately Positive"]:
213
+ if mood_trend == "increased":
214
+ reward = +1
215
+ mood_trend_symbol = " ⬆️"
216
+ elif mood_trend == "unchanged":
217
+ reward = +0.8
218
+ mood_trend_symbol = ""
219
+ else: # decreased
220
+ reward = -0.2
221
+ mood_trend_symbol = " ⬇️"
222
+ else:
223
+ if mood_trend == "increased":
224
+ reward = +1
225
+ mood_trend_symbol = " ⬆️"
226
+ elif mood_trend == "unchanged":
227
+ reward = -0.2
228
+ mood_trend_symbol = ""
229
+ else: # decreased
230
+ reward = -1
231
+ mood_trend_symbol = " ⬇️"
232
+
233
+ print(
234
+ f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
235
+ )
236
+
237
+ # Update Q-values
238
+ chatbot.update_q_values(
239
+ user_sentiment, reward, user_sentiment
240
+ )
241
+
242
  # Convert the response to speech
243
  speech_fp = text_to_speech(response)
244
  # Play the speech