Spaces:
Sleeping
Sleeping
import os | |
import pandas as pd | |
import streamlit as st | |
from q_learning_chatbot import QLearningChatbot | |
from xgb_mental_health import MentalHealthClassifier | |
from bm25_retreive_question import QuestionRetriever | |
from llm_response_generator import LLLResponseGenerator | |
# Streamlit UI
st.title("FOMO Fix - RL-based Mental Health Assistant")

# Define states and actions
states = ['Negative', 'Moderately Negative', 'Neutral', 'Moderately Positive', 'Positive']
actions = ['encouragement', 'empathy']

# Streamlit re-executes this whole script on every user interaction. A chatbot
# created at module level would therefore be rebuilt on each rerun, discarding
# its Q-values and mood history. Keeping it in session_state lets learning
# accumulate across the messages of one session.
if 'chatbot' not in st.session_state:
    st.session_state.chatbot = QLearningChatbot(states, actions)
chatbot = st.session_state.chatbot

# Initialize MentalHealthClassifier once per session rather than on every rerun
# (its constructor takes a data file and a .pkl model path, so rebuilding it
# each time is presumably expensive).
data_path = "data.csv"
tokenizer_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
mental_classifier_model_path = 'mental_health_model.pkl'
if 'mental_classifier' not in st.session_state:
    st.session_state.mental_classifier = MentalHealthClassifier(
        data_path, mental_classifier_model_path
    )
mental_classifier = st.session_state.mental_classifier
# Function to display Q-table
def display_q_table(q_values, states, actions):
    """Build a DataFrame view of the Q-table for display.

    Args:
        q_values: 2-D array-like supporting ``[:, j]`` column slicing
            (presumably a NumPy array), where column ``j`` holds the values
            of ``actions[j]`` and rows line up with ``states``.
        states: Row labels, placed in a leading ``'State'`` column.
        actions: Action names, one column per action.

    Returns:
        pandas.DataFrame with columns ``'State'`` followed by each action.
    """
    columns = {'State': states}
    columns.update({name: q_values[:, col] for col, name in enumerate(actions)})
    return pd.DataFrame(columns)
# Initialize memory: st.session_state survives Streamlit reruns, so each list is
# created only on the first run of a session and reused afterwards.
for _memory_key in ('entered_text', 'entered_mood'):
    if _memory_key not in st.session_state:
        st.session_state[_memory_key] = []

# Collect user input
user_message = st.text_input("Type your message here:")
def _compute_reward(sentiment, mood_trend):
    """Return the Q-learning reward for the user's latest message.

    Args:
        sentiment: Sentiment label for the latest message (presumably one of
            ``states``); "Positive" and "Moderately Positive" count as positive.
        mood_trend: Value from ``chatbot.check_mood_trend()``; only the string
            "increased" is treated as an improvement, anything else is not.

    Returns:
        +0.8 / -0.3 when the user is already (moderately) positive, otherwise
        +1 / -1, depending on whether the mood trend increased.
    """
    if sentiment in ["Positive", "Moderately Positive"]:
        return +0.8 if mood_trend == "increased" else -0.3
    return +1 if mood_trend == "increased" else -1


# LLM prompt settings; the {placeholders} are presumably filled in by
# LLLResponseGenerator.llm_inference from its keyword arguments below.
temperature = 0.1
max_length = 128
template = """INSTRUCTIONS: {context}
Respond to the user with a tone of {ai_tone}.
Question asked to the user: {question}
Response by the user: {user_text}
Response:
"""
context = "You are a mental health supporting non-medical assistant. Provide some advice and ask a relevant question back to the user."

# Take user input
if user_message:
    st.session_state.entered_text.append(user_message)

    # Detect mental condition. Tokenizer setup and data preprocessing only need
    # to run once per classifier object. Remember which object was prepared (by
    # identity, holding a reference so the check cannot be fooled by id reuse),
    # so a newly constructed classifier is always prepared but a reused one is not.
    if ('prepared_classifier' not in st.session_state
            or st.session_state.prepared_classifier is not mental_classifier):
        mental_classifier.initialize_tokenizer(tokenizer_model_name)
        mental_classifier.preprocess_data()
        st.session_state.prepared_classifier = mental_classifier
    predicted_mental_category = mental_classifier.predict_category(user_message)
    print("Predicted mental health condition:", predicted_mental_category)

    # Retrieve a follow-up question for the predicted category
    retriever = QuestionRetriever()
    question = retriever.get_response(user_message, predicted_mental_category)

    # Detect sentiment, record it, then update mood history / mood trend
    user_sentiment = chatbot.detect_sentiment(user_message)
    st.session_state.entered_mood.append(user_sentiment)
    chatbot.update_mood_history()
    mood_trend = chatbot.check_mood_trend()

    reward = _compute_reward(user_sentiment, mood_trend)
    print(f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - {reward} -- (a)")

    # Update Q-values. The reward observed now is the user's reaction to the
    # tone used on the PREVIOUS turn, so credit that (state, action) pair, with
    # the current sentiment as the next state. There is nothing to credit on
    # the first turn. Assumes update_q_values(state, action, reward, next_state),
    # matching the argument order this script has always used.
    if 'last_state' in st.session_state and 'last_action' in st.session_state:
        chatbot.update_q_values(
            st.session_state.last_state,
            st.session_state.last_action,
            reward,
            user_sentiment,
        )

    # Choose the tone for this reply and remember it for the next update
    ai_tone = chatbot.get_action(user_sentiment)
    st.session_state.last_state = user_sentiment
    st.session_state.last_action = ai_tone
    print(ai_tone)

    # LLM Response Generator
    # NOTE(review): this token is not passed to the generator below; presumably
    # LLLResponseGenerator reads it from the environment itself — confirm.
    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
    llm_model = LLLResponseGenerator()
    llm_response = llm_model.llm_inference(
        model_type="huggingface",
        question=question,
        prompt_template=template,
        context=context,
        ai_tone=ai_tone,
        questionnaire=predicted_mental_category,
        user_text=user_message,
        temperature=temperature,
        max_length=max_length,
    )

    st.write(f"{llm_response}")
    st.write(f"{question}")

    st.subheader("Behind the Scenes - What AI is doing:")
    st.write(f"- User Tone: {user_sentiment}, Possibly {predicted_mental_category.capitalize()}")
    st.write(f"- AI Tone: {ai_tone.capitalize()}")

    # Display Q-table
    st.dataframe(display_q_table(chatbot.q_values, states, actions))