Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
from streamlit_chat import message | |
from streamlit_extras.colored_header import colored_header | |
from streamlit_extras.add_vertical_space import add_vertical_space | |
from streamlit_mic_recorder import speech_to_text | |
from model_pipeline import ModelPipeLine | |
from q_learning_chatbot import QLearningChatbot | |
from gtts import gTTS | |
from io import BytesIO | |
st.set_page_config(page_title="PeacePal")
# Sidebar banner image.
image_path = os.path.join('images', 'sidebar.jpg')
st.sidebar.image(image_path, use_column_width=True)
st.title('PeacePal 🌱')

# Streamlit re-runs this whole script on every user interaction, so anything
# constructed at module level is rebuilt each time. Build the RAG pipeline and
# its final chain once per browser session and reuse them from session state.
# Both are built before either is stored, so a failure in create_final_chain()
# does not leave a half-initialized session behind.
if "final_chain" not in st.session_state:
    pipeline = ModelPipeLine()
    chain = pipeline.create_final_chain()
    st.session_state.model_pipeline = pipeline
    st.session_state.final_chain = chain
mdl = st.session_state.model_pipeline
# Now you can access the retriever attribute of the ModelPipeLine instance
retriever = mdl.retriever
final_chain = st.session_state.final_chain

# Define sentiment states, ordered from most negative to most positive.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]

# Initialize the Q-learning chatbot once per session. It is updated in place
# after every message (update_mood_history(), update_q_values()) and read back
# via check_mood_trend() / q_values, so re-creating it on each rerun would
# discard whatever state it keeps between messages.
if "q_chatbot" not in st.session_state:
    st.session_state.q_chatbot = QLearningChatbot(states)
chatbot = st.session_state.q_chatbot
# Function to display Q-table
def display_q_table(q_values, states, actions=None):
    """Build a pandas DataFrame for showing the chatbot's Q-table in the UI.

    Args:
        q_values: the chatbot's Q-value table. It is only read when *actions*
            is given. In that case it is indexed as ``q_values[row][col]``,
            with one row per state and one column per action. This layout is
            assumed; TODO confirm against QLearningChatbot.q_values.
        states: state labels, used as the "State" column.
        actions: optional action labels. When given, one column per action is
            added from *q_values*. When omitted, only the "State" column is
            built, which matches the previous behavior.

    Returns:
        A DataFrame with a "State" column and, if *actions* is given, one
        column per action.
    """
    # pandas is only needed here and is not imported at module level.
    import pandas as pd

    table = {"State": list(states)}
    if actions is not None:
        for col, action in enumerate(actions):
            table[action] = [row[col] for row in q_values]
    return pd.DataFrame(table)
def text_to_speech(text, lang="en"):
    """Synthesize *text* to audio with gTTS and return it in an in-memory buffer.

    Args:
        text: the text to speak.
        lang: gTTS language code; defaults to English ("en").

    Returns:
        An ``io.BytesIO`` holding the audio, rewound to position 0 so it can
        be read directly (for example by ``st.audio``).

    Note: gTTS obtains the audio from Google Translate's text-to-speech
    service (per the gTTS documentation), so this requires network access.
    """
    tts = gTTS(text=text, lang=lang)
    audio = BytesIO()
    tts.write_to_fp(audio)
    # write_to_fp leaves the cursor at the end of the written data; rewind so
    # a consumer calling read() gets the audio rather than b"".
    audio.seek(0)
    return audio
def speech_recognition_callback():
    """Hand the latest speech transcription over to the main script.

    Used as the ``callback`` of the ``speech_to_text`` widget. Reads the
    transcription from ``st.session_state.my_stt_output``.

    - If it is ``None``, stores a retry prompt in ``p01_error_message`` and
      leaves ``speech_input`` untouched.
    - Otherwise, clears ``p01_error_message`` and stores the transcription
      in ``speech_input``.
    """
    transcript = st.session_state.my_stt_output
    if transcript is not None:
        # Got a transcription: drop any stale error, then pass the text on.
        st.session_state.p01_error_message = None
        st.session_state.speech_input = transcript
    else:
        st.session_state.p01_error_message = "Please record your response again."
# Seed per-session memory the first time this browser session runs the script;
# existing values are never overwritten on later reruns.
#   generated         - assistant replies, seeded with a greeting
#   past              - user messages, paired index-for-index with `generated`
#   entered_text, entered_mood, messages - lists, empty on the first run
#   user_sentiment, mood_trend, mood_trend_symbol - shown in the sidebar
for _state_key, _initial_value in (
    ("generated", ["I'm your Mental health Assistant, How may I help you?"]),
    ("past", ['Hi!']),
    ("entered_text", []),
    ("entered_mood", []),
    ("messages", []),
    ("user_sentiment", "Neutral"),
    ("mood_trend", "Unchanged"),
    ("mood_trend_symbol", ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Page layout: a colored divider, then the transcript container, then the
# input container. Containers are placed in creation order, so the transcript
# sits above the input area even though it is filled later in the script.
colored_header(label='', description='', color_name='blue-30')
response_container = st.container()
input_container = st.container()
# Text-box input helper.
def get_text():
    """Render a single-line "You: " text box and return what it currently holds.

    The widget uses key ``"input"`` and starts out empty.
    """
    return st.text_input("You: ", "", key="input")
def generate_response(prompt):
    """Answer *prompt* through the conversational RAG pipeline.

    Calls ``mdl.call_conversational_rag`` with the module-level
    ``final_chain`` and returns the ``'answer'`` entry of its result.
    """
    rag_result = mdl.call_conversational_rag(prompt, final_chain)
    answer = rag_result['answer']
    return answer
def _reward_for(sentiment, trend):
    """Map a detected sentiment and mood trend to ``(reward, trend_symbol)``.

    - An "increased" trend always earns +1.
    - Otherwise, positive sentiments ("Positive", "Moderately Positive") are
      scored more leniently than the rest:
        - "unchanged" gives +0.8 instead of -0.2;
        - any other trend, treated as "decreased", gives -0.2 instead of -1.

    The symbol is an arrow suffix for the sidebar display ("" when unchanged).
    """
    if trend == "increased":
        return 1, " ⬆️"
    is_positive = sentiment in ("Positive", "Moderately Positive")
    if trend == "unchanged":
        return (0.8 if is_positive else -0.2), ""
    return (-0.2 if is_positive else -1), " ⬇️"


# Collect user input.
# A sidebar radio button chooses between typed and spoken input.
input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])
user_message = None
if input_mode == "Speech":
    # Render the recorder; speech_recognition_callback copies a finished
    # transcription into st.session_state.speech_input.
    speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback)
    if st.session_state.get("speech_input"):
        # Consume the transcription once so the same utterance is not
        # processed again on the next rerun.
        user_message = st.session_state.speech_input
        st.session_state.speech_input = None
else:
    user_message = st.chat_input("Type your message here:")

## Applying the user input box
with input_container:
    if user_message:
        st.session_state.entered_text.append(user_message)
        st.session_state.messages.append({"role": "user", "content": user_message})
        with st.chat_message("user"):
            st.write(user_message)
        with st.spinner("processing....."):
            response = generate_response(user_message)
            # `past` and `generated` grow in lockstep; the transcript pairs
            # them by index.
            st.session_state.past.append(user_message)
            st.session_state.generated.append(response)

            # Detect sentiment, then update the chatbot's mood history/trend.
            user_sentiment = chatbot.detect_sentiment(user_message)
            chatbot.update_mood_history()
            mood_trend = chatbot.check_mood_trend()

            reward, mood_trend_symbol = _reward_for(user_sentiment, mood_trend)
            print(
                f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
            )
            # NOTE(review): user_sentiment is passed as both the first and the
            # third argument (presumably state and next_state) — confirm this
            # against QLearningChatbot.update_q_values.
            chatbot.update_q_values(user_sentiment, reward, user_sentiment)

            # Publish this turn's results to session state. The sidebar reads
            # these keys, and nothing else in the script updates them after
            # they are seeded.
            st.session_state.user_sentiment = user_sentiment
            st.session_state.mood_trend = mood_trend
            st.session_state.mood_trend_symbol = mood_trend_symbol

            # Speak the reply.
            speech_fp = text_to_speech(response)
            st.audio(speech_fp, format='audio/mp3')
## Conditional display of AI generated responses as a function of user provided prompts
with response_container:
    # `past` and `generated` are appended in pairs, so entry i of each forms
    # one exchange (index 0 is the seeded "Hi!" / greeting pair). The original
    # used an undefined name `I` for the reply key, which raised NameError.
    for i, (user_text, bot_text) in enumerate(
        zip(st.session_state['past'], st.session_state['generated'])
    ):
        message(user_text, is_user=True, key=str(i) + '_user')
        message(bot_text, key=str(i))
# Sidebar panel explaining what the model is doing.
# `section_visible` was referenced here without ever being defined, which
# raised NameError on every run; show the panel expanded by default.
section_visible = True
with st.sidebar.expander("Behind the Scene", expanded=section_visible):
    st.subheader("What AI is doing:")
    # Use the values stored in session state.
    st.write(
        f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})"
    )
    # Display the Q-table. The previous call passed an undefined `actions`
    # name; only the names defined in this script are passed.
    st.dataframe(display_q_table(chatbot.q_values, states))
    st.write("-----------------------")
    st.write(
        "- Above q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
    )
    # `selected_retriever_option` was never defined; report the retriever
    # object actually taken from the model pipeline instead.
    st.write(f"- Question retrieved from: {type(retriever).__name__}")
    st.write(
        "- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, ADHD, Social Media Addiction, Social Isolation, and Cyberbullying which are most associated with FOMO related to excessive social media usage."
    )