|
import streamlit as st |
|
import os |
|
import pandas as pd |
|
from streamlit_chat import message |
|
from streamlit_extras.colored_header import colored_header |
|
from streamlit_extras.add_vertical_space import add_vertical_space |
|
from streamlit_mic_recorder import speech_to_text |
|
from model_pipeline import ModelPipeLine |
|
from q_learning_chatbot import QLearningChatbot |
|
|
|
from gtts import gTTS |
|
from io import BytesIO |
|
st.set_page_config(page_title="PeacePal")

st.title('PeacePal 🌱')

# Streamlit re-executes this whole script on every widget interaction, so
# constructing the pipeline at module level rebuilt it (and its final chain)
# for every single message. Build it once per browser session instead.
# Per-session storage is used rather than st.cache_resource because the latter
# is shared by all users, and the pipeline may hold conversation state.
if "model_pipeline" not in st.session_state:
    _pipeline = ModelPipeLine()
    # Store both together so a failure in create_final_chain() cannot leave a
    # half-initialised entry behind for the next rerun.
    st.session_state["model_pipeline"] = (_pipeline, _pipeline.create_final_chain())

mdl, final_chain = st.session_state["model_pipeline"]

retriever = mdl.retriever
|
|
|
|
|
# Sentiment labels used as the Q-learning chatbot's states, ordered from most
# negative to most positive.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]

# Keep a single QLearningChatbot per session. Recreating it on every Streamlit
# rerun threw away whatever it accumulates through update_q_values() and
# update_mood_history() after each message, so learning and mood-trend
# detection never carried over between turns.
if "chatbot" not in st.session_state:
    st.session_state["chatbot"] = QLearningChatbot(states)
chatbot = st.session_state["chatbot"]
|
|
|
|
|
def display_q_table(q_values, states, actions=None):
    """Build a DataFrame showing each state alongside its Q-values.

    Args:
        q_values: Row-per-state Q-values (anything ``pandas.DataFrame`` accepts,
            e.g. a 2-D list/array of shape ``(len(states), n_actions)``), or
            ``None`` to show the state column only.
        states: State labels, one per row.
        actions: Optional column labels for the Q-value columns; when omitted
            the columns are named by position ("0", "1", ...).

    Returns:
        A DataFrame with a "State" column followed by one column per action.
        If ``q_values`` cannot be laid out one row per state, only the
        "State" column is returned (the previous behaviour).
    """
    table = pd.DataFrame({"State": list(states)})
    if q_values is None:
        return table

    try:
        values = pd.DataFrame(q_values)
    except (ValueError, TypeError):
        # NOTE(review): the layout of QLearningChatbot.q_values is not visible
        # here; an unexpected layout degrades to the state-only table rather
        # than breaking the sidebar.
        return table
    if len(values) != len(table):
        return table

    if actions is not None:
        values.columns = list(actions)
    # String column names keep the frame Arrow-serialisable for st.dataframe.
    values.columns = [str(column) for column in values.columns]
    values.index = table.index
    return pd.concat([table, values], axis=1)
|
|
|
def text_to_speech(text, lang="en"):
    """Synthesize *text* to MP3 audio with gTTS and return it in memory.

    Args:
        text: The text to speak. gTTS rejects empty text.
        lang: gTTS language code; defaults to English as before.

    Returns:
        A ``BytesIO`` containing the MP3 data, positioned at offset 0.
    """
    audio = BytesIO()
    gTTS(text=text, lang=lang).write_to_fp(audio)
    # write_to_fp leaves the cursor at the end of the buffer; rewind it so
    # callers that read() the stream get the audio instead of b"".
    audio.seek(0)
    return audio
|
|
|
|
|
def speech_recognition_callback():
    """Copy the recorder transcript into session state for the next run.

    Reads ``st.session_state.my_stt_output``. A ``None`` transcript records a
    retry message in ``p01_error_message``. Any other value clears that
    message and is stored in ``speech_input``.
    """
    transcript = st.session_state.my_stt_output
    if transcript is not None:
        st.session_state.p01_error_message = None
        st.session_state.speech_input = transcript
    else:
        st.session_state.p01_error_message = "Please record your response again."
|
|
|
|
|
def _init_session_state():
    """Seed each chat/mood key on first run; existing values are left untouched."""
    # The literal is rebuilt on every call, so each session gets fresh lists.
    defaults = {
        'generated': ["I'm your Mental health Assistant, How may I help you?"],
        'past': ['Hi!'],
        "entered_text": [],
        "entered_mood": [],
        "messages": [],
        "user_sentiment": "Neutral",
        "mood_trend": "Unchanged",
        "mood_trend_symbol": "",
    }
    for name, initial in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = initial


_init_session_state()
|
|
|
|
|
|
|
# Decorative divider, then two placeholders created in display order:
# one for responses and, below it, one for the input/turn output.
colored_header(label='', description='', color_name='blue-30')
response_container, input_container = st.container(), st.container()
|
|
|
|
|
|
|
def get_text():
    """Render a text box labelled "You: " (state key "input") and return its value."""
    return st.text_input("You: ", "", key="input")
|
|
|
def generate_response(prompt):
    """Run *prompt* through the conversational RAG chain and return its "answer" field."""
    return mdl.call_conversational_rag(prompt, final_chain)["answer"]
|
|
|
|
|
|
|
# Let the user choose between typing and speaking; whichever widget is active
# yields this run's message (or None when nothing new was submitted).
input_mode = st.sidebar.radio("Select input mode:", ["Text", "Speech"])

user_message = None

if input_mode == "Speech":
    # The recorder is keyed "my_stt"; speech_recognition_callback reads its
    # "my_stt_output" value and moves it into "speech_input".
    speech_input = speech_to_text(key="my_stt", callback=speech_recognition_callback)

    # The callback sets p01_error_message when a recording produced no
    # transcript. It was previously never shown, so the user got no feedback.
    error_message = st.session_state.get("p01_error_message")
    if error_message:
        st.error(error_message)

    if st.session_state.get("speech_input"):
        # Consume the transcript once so later reruns don't resubmit it.
        user_message = st.session_state.speech_input
        st.session_state.speech_input = None
else:
    user_message = st.chat_input("Type your message here:")
|
|
|
|
|
def _reward_for_turn(sentiment, trend):
    """Map a detected sentiment and mood trend to ``(reward, trend_symbol)``.

    The symbol depends only on the trend. The reward additionally depends on
    whether the sentiment is "Positive"/"Moderately Positive". Any trend other
    than "increased" or "unchanged" is treated as a decrease.
    """
    positive = sentiment in ("Positive", "Moderately Positive")
    if trend == "increased":
        return 1, " ⬆️"
    if trend == "unchanged":
        return (0.8 if positive else -0.2), ""
    return (-0.2 if positive else -1), " ⬇️"


with input_container:
    if user_message:
        st.session_state.entered_text.append(user_message)
        st.session_state.messages.append({"role": "user", "content": user_message})

        with st.chat_message("user"):
            st.write(user_message)

        with st.spinner("Processing..."):
            response = generate_response(user_message)
            # NOTE(review): `past` grows every turn, but nothing here appends to
            # `generated`, and the history renderer iterates `generated`.
            # Confirm whether the response should also be recorded there.
            st.session_state.past.append(user_message)
            st.session_state.messages.append({"role": "ai", "content": response})

            user_sentiment = chatbot.detect_sentiment(user_message)
            chatbot.update_mood_history()
            mood_trend = chatbot.check_mood_trend()

            reward, mood_trend_symbol = _reward_for_turn(user_sentiment, mood_trend)
            print(f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑")

            # NOTE(review): the current sentiment is passed as both the first and
            # last argument; confirm this matches update_q_values' intended
            # (state, reward, next_state) contract.
            chatbot.update_q_values(user_sentiment, reward, user_sentiment)

        with st.chat_message("ai"):
            st.markdown(response)

        st.session_state.user_sentiment = user_sentiment
        st.session_state.mood_trend = mood_trend
        st.session_state.mood_trend_symbol = mood_trend_symbol

        # Audio is best-effort: gTTS rejects empty text and needs network
        # access, and a failure there should not discard the text reply.
        if response:
            try:
                speech_fp = text_to_speech(response)
            except Exception as err:
                st.warning(f"Audio playback unavailable: {err}")
            else:
                st.audio(speech_fp, format='audio/mp3')
|
|
|
|
|
# Replay the stored conversation: each bot reply is preceded by the user
# message at the same index in `past`.
bot_replies = st.session_state['generated']
if bot_replies:
    past_inputs = st.session_state['past']
    for turn, reply in enumerate(bot_replies):
        message(past_inputs[turn], is_user=True, key=f"{turn}_user")
        message(reply, key=f"{turn}_ai")
|
|
|
# Sidebar panel: last detected tone and trend, followed by the Q-table.
with st.sidebar.expander("Sentiment Analysis"):
    session = st.session_state
    tone_line = (
        f"- Detected User Tone: {session.user_sentiment} "
        f"({session.mood_trend.capitalize()}{session.mood_trend_symbol})"
    )
    st.write(tone_line)
    st.dataframe(display_q_table(chatbot.q_values, states))