File size: 9,763 Bytes
a4c8f8f
 
 
 
1830956
8465f96
a4c8f8f
1830956
ee42ff5
8465f96
a4c8f8f
 
 
 
1830956
 
 
 
 
 
 
ab3e284
a4c8f8f
 
 
 
 
1830956
 
a4c8f8f
1bf564f
a4c8f8f
 
94ce3aa
 
 
 
 
 
 
 
 
 
a4c8f8f
 
 
1830956
a4c8f8f
 
 
 
 
 
5723ff3
 
 
 
 
 
 
 
a4c8f8f
 
1830956
a4c8f8f
1830956
a4c8f8f
1830956
 
2cbbd1d
 
 
 
 
 
 
 
 
 
 
 
 
 
1830956
 
 
 
 
 
 
 
 
 
 
 
 
 
2cbbd1d
 
 
a4c8f8f
1830956
a4c8f8f
 
 
 
 
1830956
 
 
a4c8f8f
2cbbd1d
1830956
 
 
 
 
a4c8f8f
1830956
 
a4c8f8f
1830956
2cbbd1d
1830956
 
 
 
 
 
 
2cbbd1d
1830956
 
 
 
 
 
 
 
 
 
 
2cbbd1d
1830956
 
 
 
 
 
 
 
 
 
 
 
 
 
2cbbd1d
1830956
a4c8f8f
1830956
 
 
a4c8f8f
 
1830956
 
 
2cbbd1d
8465f96
 
 
 
1830956
 
 
 
0d069a0
1830956
 
8465f96
0d069a0
8465f96
2cbbd1d
 
1830956
5723ff3
 
 
 
 
 
 
 
1830956
 
 
 
 
 
 
 
 
 
 
2cbbd1d
5723ff3
 
2cbbd1d
 
 
 
 
8465f96
2cbbd1d
a4c8f8f
1830956
2cbbd1d
1830956
2cbbd1d
 
1830956
 
 
a4c8f8f
2cbbd1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1830956
2cbbd1d
1830956
2cbbd1d
1830956
2cbbd1d
1830956
2cbbd1d
1830956
 
8465f96
1830956
 
2cbbd1d
1830956
a4c8f8f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import pandas as pd
import streamlit as st
from q_learning_chatbot import QLearningChatbot
from xgb_mental_health import MentalHealthClassifier
from bm25_retreive_question import QuestionRetriever as QuestionRetriever_bm25
from Chromadb_storage_JyotiNigam import QuestionRetriever as QuestionRetriever_chromaDB
from llm_response_generator import LLLResponseGenerator
import os
import re

# Streamlit UI: page heading
st.title("FOMO Fix - RL-based Mental Health Assistant")

# Sentiment states (one Q-table row each) and response tones (one Q-table column each)
states = ["Negative", "Moderately Negative", "Neutral", "Moderately Positive", "Positive"]
actions = ["encouragement", "empathy", "spiritual"]

# Q-learning chatbot built over the state/action space above.
# NOTE(review): this constructor runs on every execution of the script, so whether
# learned Q-values survive a Streamlit rerun depends on QLearningChatbot internals - confirm.
chatbot = QLearningChatbot(states, actions)

# Mental-health classifier: training data location, tokenizer name and pickle path
data_path = "data/data.csv"
tokenizer_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
mental_classifier_model_path = "mental_health_model.pkl"
mental_classifier = MentalHealthClassifier(data_path, mental_classifier_model_path)

if os.path.exists(mental_classifier_model_path):
    # A saved model exists: load it, then make sure the tokenizer is initialized as well
    mental_classifier.load_model()
    mental_classifier.initialize_tokenizer(tokenizer_model_name)
else:
    # No saved model yet: train one from the data file and persist it for later runs
    mental_classifier.initialize_tokenizer(tokenizer_model_name)
    features, labels = mental_classifier.preprocess_data()
    _y_test, _y_pred = mental_classifier.train_model(features, labels)
    mental_classifier.save_model()

# Function to display Q-table
def display_q_table(q_values, states, actions):
    """Build a DataFrame view of the Q-table.

    Args:
        q_values: 2-D array indexed as ``q_values[:, i]``; column ``i`` holds
            the values for ``actions[i]``.
        states: State labels, placed in the ``"State"`` column.
        actions: Action names, each becoming one column of the result.

    Returns:
        A ``pandas.DataFrame`` with a ``"State"`` column followed by one
        column per action.
    """
    columns = {"State": states}
    columns.update(
        {action_name: q_values[:, col] for col, action_name in enumerate(actions)}
    )
    return pd.DataFrame(columns)

def remove_html_tags(text):
    """Strip markup noise from *text* and flatten it to a single tidy line.

    The following are deleted: ``<...>`` tags that open and close on the same
    line, ``"- "`` bullet markers, double-quote characters, and literal
    backslash-n sequences (the two characters ``\\`` and ``n``).

    Afterwards every run of whitespace (spaces, tabs, real newlines) is
    collapsed to one space and leading/trailing whitespace is dropped, so
    indented multi-line input no longer leaves long runs of spaces behind.

    Args:
        text: The raw string to clean.

    Returns:
        The cleaned, single-line string.
    """
    without_noise = re.sub(r'<.*?>|- |"|\\n', '', text)
    # split() with no separator breaks on any whitespace run and ignores the
    # ends, so joining with one space trims, flattens and de-indents in one go.
    return ' '.join(without_noise.split())

# Initialize memory: seed each session-state key only when it is not present yet,
# so values already stored in the session are left untouched.
session_defaults = {
    "entered_text": [],
    "entered_mood": [],
    "messages": [],
    "user_sentiment": "Neutral",
    "mood_trend": "Unchanged",
    "predicted_mental_category": "",
    "ai_tone": "Empathy",
    "mood_trend_symbol": "",
    "show_question": False,
    "asked_questions": [],
}
for state_key, initial_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Select Question Retriever: the sidebar choice picks which backend class is instantiated
selected_retriever_option = st.sidebar.selectbox(
    "Choose Question Retriever", ("BM25", "ChromaDB")
)
retriever_classes = {
    "BM25": QuestionRetriever_bm25,
    "ChromaDB": QuestionRetriever_chromaDB,
}
retriever_class = retriever_classes.get(selected_retriever_option)
if retriever_class is not None:
    retriever = retriever_class()

# Render every stored message inside the chat bubble of its role
for past_turn in st.session_state.messages:
    speaker = past_turn.get("role")
    with st.chat_message(speaker):
        st.write(past_turn.get("content"))

# Initial expanded state passed to the sidebar "Behind the Scene" expander
section_visible = False


# Collect user input
user_message = st.chat_input("Type your message here:")

# Handle one chat turn. Steps, in order:
#   1. store and echo the user's message,
#   2. classify the mental-health category and detect the sentiment,
#   3. retrieve a follow-up question for negative/neutral sentiment,
#   4. derive a reward from sentiment + mood trend and update the Q-table,
#   5. ask the LLM for a reply in the tone recommended by the Q-table,
#   6. show the reply and remember this turn's values in session state.
if user_message:
    st.session_state.entered_text.append(user_message)

    st.session_state.messages.append({"role": "user", "content": user_message})
    with st.chat_message("user"):
        st.write(user_message)

    with st.spinner("Processing..."):
        # Detect mental condition
        # NOTE(review): the tokenizer is re-initialized and the data re-preprocessed
        # for every message. Kept as-is because predict_category may depend on state
        # these calls set up - confirm before hoisting them out of the per-message path.
        mental_classifier.initialize_tokenizer(tokenizer_model_name)
        mental_classifier.preprocess_data()
        predicted_mental_category = mental_classifier.predict_category(user_message)
        # NOTE(review): the print calls in this block write user-derived data
        # (conversation text, predicted condition) to stdout.
        print("Predicted mental health condition:", predicted_mental_category)

        # Detect sentiment
        user_sentiment = chatbot.detect_sentiment(user_message)

        # Retrieve question: only negative-to-neutral sentiment gets a follow-up
        # question; otherwise the question and the predicted category are blanked.
        show_question = user_sentiment in ["Negative", "Moderately Negative", "Neutral"]
        if show_question:
            question = retriever.get_response(user_message, predicted_mental_category)
        else:
            question = ""
            predicted_mental_category = ""

        # Update mood history / mood_trend
        chatbot.update_mood_history()
        mood_trend = chatbot.check_mood_trend()

        # Define rewards. The arrow symbol depends only on the trend; the reward
        # also depends on whether the current sentiment is positive. Any trend other
        # than "increased"/"unchanged" is treated as decreased.
        feels_positive = user_sentiment in ["Positive", "Moderately Positive"]
        if mood_trend == "increased":
            reward = +1
            mood_trend_symbol = " ⬆️"
        elif mood_trend == "unchanged":
            reward = +0.8 if feels_positive else -0.2
            mood_trend_symbol = ""
        else:  # decreased
            reward = -0.2 if feels_positive else -1
            mood_trend_symbol = " ⬇️"

        print(
            f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
        )

        # Update Q-values
        # NOTE(review): this always credits chatbot.actions[0] and passes
        # user_sentiment as both the first and the last argument, regardless of the
        # tone used for the previous reply - confirm this is intended.
        chatbot.update_q_values(
            user_sentiment, chatbot.actions[0], reward, user_sentiment
        )

        # Get recommended action based on the updated Q-values
        ai_tone = chatbot.get_action(user_sentiment)
        print(ai_tone)

        print(st.session_state.messages)

        # LLM Response Generator
        llm_model = LLLResponseGenerator()
        temperature = 0.5
        max_length = 128

        # Join the two stored messages that precede the current user message
        # (the current one is the last entry, which the [-3:-1] slice excludes).
        all_messages = "\n".join(
            message.get("content") for message in st.session_state.messages[-3:-1]
        )

        template = """INSTRUCTIONS: {context}
            
                Respond to the user with a tone of {ai_tone}. 
                
                Response by the user: {user_text}  
                Response;
                """
        context = f"You are a mental health supporting non-medical assistant. Provide brief advice. DO NOT ASK ANY QUESTION. DO NOT REPEAT YOURSELF. {all_messages}"

        llm_response = llm_model.llm_inference(
            model_type="huggingface",
            question=question,
            prompt_template=template,
            context=context,
            ai_tone=ai_tone,
            questionnaire=predicted_mental_category,
            user_text=user_message,
            temperature=temperature,
            max_length=max_length,
        )

        llm_response = remove_html_tags(llm_response)

        # The retrieved question, when there is one, is appended after the LLM reply
        if show_question:
            llm_response_with_question = f"{llm_response}\n\n{question}"
        else:
            llm_response_with_question = llm_response

        # Append the AI response to the chat history
        st.session_state.messages.append(
            {"role": "ai", "content": llm_response_with_question}
        )

    with st.chat_message("ai"):
        st.markdown(llm_response_with_question)

    # Update data to memory (read back by the sidebar "Behind the Scene" panel)
    st.session_state.user_sentiment = user_sentiment
    st.session_state.mood_trend = mood_trend
    st.session_state.predicted_mental_category = predicted_mental_category
    st.session_state.ai_tone = ai_tone
    st.session_state.mood_trend_symbol = mood_trend_symbol
    st.session_state.show_question = show_question

# "Behind the Scene" sidebar panel: shows the values kept in session state and the Q-table
with st.sidebar.expander('Behind the Scene', expanded=section_visible):
    st.subheader("What AI is doing:")

    # Use the values stored in session state
    session = st.session_state
    detected_tone = session.user_sentiment
    trend_label = f"{session.mood_trend.capitalize()}{session.mood_trend_symbol}"
    st.write(f"- Detected User Tone: {detected_tone} ({trend_label})")

    # The predicted condition is only listed when show_question is set
    if session.show_question:
        condition_label = session.predicted_mental_category.capitalize()
        st.write(f"- Possible Mental Condition: {condition_label}")

    st.write(f"- AI Tone: {session.ai_tone.capitalize()}")
    st.write(f"- Question retrieved from: {selected_retriever_option}")
    st.write(
        "- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, and ADHD which are most associated with FOMO related to excessive social media usage."
    )
    st.write(
        "- Below q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
    )

    # Display Q-table
    q_table_df = display_q_table(chatbot.q_values, states, actions)
    st.dataframe(q_table_df)