Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import pandas as pd | |
from streamlit_chat import message | |
from streamlit_extras.colored_header import colored_header | |
from streamlit_extras.add_vertical_space import add_vertical_space | |
from streamlit_mic_recorder import speech_to_text | |
from model_pipelineV2 import ModelPipeLine | |
from q_learning_chatbot import QLearningChatbot | |
from gtts import gTTS | |
from io import BytesIO | |
st.set_page_config(page_title="PeacePal") | |
#image to the sidebar | |
#image_path = os.path.join('images', 'sidebar.jpg') | |
#st.sidebar.image(image_path, use_column_width=True) | |
st.title('PeacePal 🌱') | |
mdl = ModelPipeLine() | |
# Now you can access the retriever attribute of the ModelPipeLine instance | |
# retriever = mdl.retriever | |
final_chain = mdl.create_final_chain() | |
# Define states and actions | |
states = [ | |
"Negative", | |
"Moderately Negative", | |
"Neutral", | |
"Moderately Positive", | |
"Positive", | |
] | |
# Initialize Q-learning chatbot and mental health classifier | |
chatbot = QLearningChatbot(states) | |
# Function to display Q-table | |
def display_q_table(q_values, states): | |
q_table_dict = {"State": states} | |
q_table_df = pd.DataFrame(q_table_dict) | |
return q_table_df | |
def text_to_speech(text): | |
# Use gTTS to convert text to speech | |
tts = gTTS(text=text, lang="en") | |
# Save the speech as bytes in memory | |
fp = BytesIO() | |
tts.write_to_fp(fp) | |
return fp | |
def speech_recognition_callback(): | |
# Ensure that speech output is available | |
if st.session_state.my_stt_output is None: | |
st.session_state.p01_error_message = "Please record your response again." | |
return | |
# Clear any previous error messages | |
st.session_state.p01_error_message = None | |
# Store the speech output in the session state | |
st.session_state.speech_input = st.session_state.my_stt_output | |
## generated stores AI generated responses | |
if 'generated' not in st.session_state: | |
st.session_state['generated'] = ["I'm your Mental health Assistant, How may I help you?"] | |
## past stores User's questions | |
if 'past' not in st.session_state: | |
st.session_state['past'] = ['Hi!'] | |
# Initialize memory | |
if "entered_text" not in st.session_state: | |
st.session_state.entered_text = [] | |
if "entered_mood" not in st.session_state: | |
st.session_state.entered_mood = [] | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "user_sentiment" not in st.session_state: | |
st.session_state.user_sentiment = "Neutral" | |
if "mood_trend" not in st.session_state: | |
st.session_state.mood_trend = "Unchanged" | |
if "mood_trend_symbol" not in st.session_state: | |
st.session_state.mood_trend_symbol = "" | |
# Layout of input/response containers | |
colored_header(label='', description='', color_name='blue-30') | |
response_container = st.container() | |
input_container = st.container() | |
# User input | |
## Function for taking user provided prompt as input | |
def get_text(): | |
input_text = st.text_input("You: ", "", key="input") | |
return input_text | |
def generate_response(prompt): | |
response = mdl.call_conversational_rag(prompt,final_chain) | |
return response['answer'] | |
# Collect user input | |
# Add a radio button to choose input mode | |
input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"]) | |
user_message = None | |
if input_mode == "Speech": | |
# Use the speech_to_text function to capture speech input | |
speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback) | |
# Check if speech input is available | |
if "speech_input" in st.session_state and st.session_state.speech_input: | |
# Display the speech input | |
# st.text(f"Speech Input: {st.session_state.speech_input}") | |
# Process the speech input as a query | |
user_message = st.session_state.speech_input | |
st.session_state.speech_input = None | |
else: | |
user_message = st.chat_input("Type your message here:") | |
## Applying the user input box | |
with input_container: | |
if user_message: | |
st.session_state.entered_text.append(user_message) | |
st.session_state.messages.append({"role": "user", "content": user_message}) | |
# Display the user's message | |
with st.chat_message("user"): | |
st.write(user_message) | |
# Process the user's message and generate a response | |
with st.spinner("Processing..."): | |
response = generate_response(user_message) | |
st.session_state.past.append(user_message) | |
st.session_state.messages.append({"role": "ai", "content": response}) | |
# Detect sentiment | |
user_sentiment = chatbot.detect_sentiment(user_message) | |
# Update mood history / mood_trend | |
chatbot.update_mood_history() | |
mood_trend = chatbot.check_mood_trend() | |
# Define rewards | |
if user_sentiment in ["Positive", "Moderately Positive"]: | |
if mood_trend == "increased": | |
reward = +1 | |
mood_trend_symbol = " ⬆️" | |
elif mood_trend == "unchanged": | |
reward = +0.8 | |
mood_trend_symbol = "" | |
else: # decreased | |
reward = -0.2 | |
mood_trend_symbol = " ⬇️" | |
else: | |
if mood_trend == "increased": | |
reward = +1 | |
mood_trend_symbol = " ⬆️" | |
elif mood_trend == "unchanged": | |
reward = -0.2 | |
mood_trend_symbol = "" | |
else: # decreased | |
reward = -1 | |
mood_trend_symbol = " ⬇️" | |
print(f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑") | |
# Update Q-values | |
chatbot.update_q_values(user_sentiment, reward, user_sentiment) | |
# Display the AI's response | |
with st.chat_message("ai"): | |
st.markdown(response) | |
st.session_state.user_sentiment = user_sentiment | |
st.session_state.mood_trend = mood_trend | |
st.session_state.mood_trend_symbol = mood_trend_symbol | |
# Convert the response to speech | |
speech_fp = text_to_speech(response) | |
# Play the speech | |
st.audio(speech_fp, format='audio/mp3') | |
with st.sidebar.expander("Sentiment Analysis"): | |
# Use the values stored in session state | |
st.write( | |
f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})" | |
) | |
# Display Q-table | |
st.dataframe(display_q_table(chatbot.q_values, states)) |