# PeacePal / app.py
# SwatGarg's picture
# Update app.py
# 7ec21ac verified
# raw
# history blame
# 6.84 kB
import streamlit as st
import os
import pandas as pd
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header
from streamlit_extras.add_vertical_space import add_vertical_space
from streamlit_mic_recorder import speech_to_text
from model_pipeline import ModelPipeLine
from q_learning_chatbot import QLearningChatbot
from gtts import gTTS
from io import BytesIO
# Must be the first Streamlit command of the script.
st.set_page_config(page_title="PeacePal")
# image to the sidebar
# image_path = os.path.join('images', 'sidebar.jpg')
# st.sidebar.image(image_path, use_column_width=True)
st.title('PeacePal 🌱')
# Streamlit re-executes this entire script on every user interaction, so
# constructing the pipeline at module level rebuilt it (and its chain) on every
# message. Build it once per browser session and reuse it on later reruns.
if "model_pipeline" not in st.session_state:
    st.session_state.model_pipeline = ModelPipeLine()
    st.session_state.final_chain = st.session_state.model_pipeline.create_final_chain()
mdl = st.session_state.model_pipeline
# The pipeline exposes its retriever as an attribute.
retriever = mdl.retriever
final_chain = st.session_state.final_chain
# Sentiment states, ordered from most negative to most positive.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]
# Create the Q-learning chatbot once per session. Streamlit reruns this script
# on every interaction; recreating the chatbot here discarded its mood history
# and learned Q-values after every single message.
if "chatbot" not in st.session_state:
    st.session_state.chatbot = QLearningChatbot(states)
chatbot = st.session_state.chatbot
def display_q_table(q_values, states):
    """Build a DataFrame listing each sentiment state next to its Q-value(s).

    Args:
        q_values: Q-values to show. Assumed to be aligned row-for-row with
            ``states``: 1-D (one value per state), or 2-D with one column per
            action. ``None`` yields a table of state names only.
        states: Ordered state labels used for the ``State`` column.

    Returns:
        A ``pandas.DataFrame`` with a ``State`` column followed by a single
        ``Q-Value`` column (1-D input) or ``Q-Value 0..k`` columns (2-D input).
        If ``q_values`` cannot be turned into a table with one row per state,
        only the ``State`` column is returned.
    """
    q_table_df = pd.DataFrame({"State": list(states)})
    if q_values is None:
        return q_table_df
    try:
        values_df = pd.DataFrame(q_values)
    except ValueError:
        # Not tabular (e.g. a scalar or >2-D array): show the state names alone.
        return q_table_df
    if len(values_df) != len(q_table_df):
        # Row counts disagree, so values cannot be attributed to states safely.
        return q_table_df
    if values_df.shape[1] == 1:
        values_df.columns = ["Q-Value"]
    else:
        values_df.columns = [f"Q-Value {i}" for i in range(values_df.shape[1])]
    # reset_index guarantees positional alignment with the State column.
    return pd.concat([q_table_df, values_df.reset_index(drop=True)], axis=1)
def text_to_speech(text, lang="en"):
    """Synthesize *text* to speech with gTTS and return it as an in-memory buffer.

    Args:
        text: The text to speak.
        lang: gTTS language code; defaults to English ("en").

    Returns:
        An ``io.BytesIO`` holding the generated audio, rewound to position 0
        so callers can read or stream it directly.
    """
    tts = gTTS(text=text, lang=lang)
    fp = BytesIO()
    tts.write_to_fp(fp)
    # write_to_fp leaves the cursor at the end of the written data; rewind so a
    # consumer that does not seek first (e.g. fp.read()) still gets the audio.
    fp.seek(0)
    return fp
def speech_recognition_callback():
    """Move the recorder's transcript into ``st.session_state.speech_input``.

    Reads ``st.session_state.my_stt_output``. When it is present, clears
    ``p01_error_message`` and stores the transcript as ``speech_input``.
    When it is ``None``, sets ``p01_error_message`` asking the user to record
    again and leaves ``speech_input`` untouched.
    """
    transcript = st.session_state.my_stt_output
    if transcript is not None:
        st.session_state.p01_error_message = None
        st.session_state.speech_input = transcript
    else:
        st.session_state.p01_error_message = "Please record your response again."
def _init_session_state():
    """Seed ``st.session_state`` with a default for every key not yet set.

    Existing values are left alone, so this is safe to call on every rerun.
    The defaults dict is rebuilt per call, so each new key gets a fresh list.
    """
    defaults = {
        # AI-generated responses, seeded with the assistant's greeting.
        "generated": ["I'm your Mental health Assistant, How may I help you?"],
        # The user's messages, seeded with an opening line.
        "past": ['Hi!'],
        # Conversation memory.
        "entered_text": [],
        "entered_mood": [],
        "messages": [],
        # Latest sentiment readout shown in the sidebar.
        "user_sentiment": "Neutral",
        "mood_trend": "Unchanged",
        "mood_trend_symbol": "",
    }
    for state_key, default_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value


_init_session_state()
# Page layout: a blue divider, then two containers created in order, so the
# response area sits above the input area.
colored_header(label='', description='', color_name='blue-30')
response_container, input_container = st.container(), st.container()
# User input
def get_text():
    """Render a "You: " text box (session key "input") and return its value.

    NOTE(review): appears unused in this file; input is taken via
    st.chat_input / speech_to_text instead.
    """
    return st.text_input("You: ", "", key="input")
def generate_response(prompt):
    """Run *prompt* through the conversational RAG chain and return its 'answer'."""
    return mdl.call_conversational_rag(prompt, final_chain)['answer']
# Collect user input: the sidebar radio chooses typed or spoken input.
input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])
user_message = None
if input_mode == "Speech":
    # The recorder hands its transcript to speech_recognition_callback, which
    # stores it in st.session_state.speech_input (or sets p01_error_message).
    speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback)
    # The callback's error message was previously stored but never displayed,
    # so a failed recording gave the user no feedback. Surface it here.
    error_message = st.session_state.get("p01_error_message")
    if error_message:
        st.warning(error_message)
    # Consume the transcript exactly once so it is not re-processed on the
    # next rerun.
    if st.session_state.get("speech_input"):
        user_message = st.session_state.speech_input
        st.session_state.speech_input = None
else:
    user_message = st.chat_input("Type your message here:")
def _reward_for(user_sentiment, mood_trend):
    """Return ``(reward, trend_symbol)`` for a detected sentiment and mood trend.

    An "increased" trend always earns +1. For an "unchanged" trend, positive
    sentiments earn +0.8 and others -0.2. Any other trend is treated as
    "decreased": positive sentiments get -0.2 and others -1.
    """
    if mood_trend == "increased":
        return +1, " ⬆️"
    is_positive = user_sentiment in ["Positive", "Moderately Positive"]
    if mood_trend == "unchanged":
        return (+0.8 if is_positive else -0.2), ""
    return (-0.2 if is_positive else -1), " ⬇️"


# Streamlit reruns this script on every interaction, so replay the stored
# conversation first. Previously st.session_state.messages was recorded but
# never rendered, and earlier turns vanished after each new message.
with response_container:
    for past_message in st.session_state.messages:
        with st.chat_message(past_message["role"]):
            st.markdown(past_message["content"])

# Handle the newly submitted message, if any.
with input_container:
    if user_message:
        st.session_state.entered_text.append(user_message)
        st.session_state.messages.append({"role": "user", "content": user_message})
        with st.chat_message("user"):
            st.write(user_message)
        with st.spinner("Processing..."):
            response = generate_response(user_message)
        st.session_state.past.append(user_message)
        # Keep `generated` (AI responses) in step with `past` (user messages).
        st.session_state.generated.append(response)
        st.session_state.messages.append({"role": "ai", "content": response})
        # Detect sentiment, then update the chatbot's mood history and trend.
        user_sentiment = chatbot.detect_sentiment(user_message)
        chatbot.update_mood_history()
        mood_trend = chatbot.check_mood_trend()
        reward, mood_trend_symbol = _reward_for(user_sentiment, mood_trend)
        print(f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑")
        # NOTE(review): the same state is passed as both the current and the
        # next state; confirm this matches QLearningChatbot.update_q_values.
        chatbot.update_q_values(user_sentiment, reward, user_sentiment)
        with st.chat_message("ai"):
            st.markdown(response)
        # Persist the readout for the sidebar across reruns.
        st.session_state.user_sentiment = user_sentiment
        st.session_state.mood_trend = mood_trend
        st.session_state.mood_trend_symbol = mood_trend_symbol
        # Speak the response.
        speech_fp = text_to_speech(response)
        st.audio(speech_fp, format='audio/mp3')
# Sidebar panel: latest detected tone, mood trend, and the Q-table.
with st.sidebar.expander("Sentiment Analysis"):
    tone = st.session_state.user_sentiment
    trend = st.session_state.mood_trend.capitalize()
    arrow = st.session_state.mood_trend_symbol
    st.write(f"- Detected User Tone: {tone} ({trend}{arrow})")
    q_table = display_q_table(chatbot.q_values, states)
    st.dataframe(q_table)