File size: 11,172 Bytes
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69a01d3
bd9870c
 
 
69a01d3
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149d75a
bd9870c
 
 
 
 
8336395
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f33fe30
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
###
# - Author: Jaelin Lee, Abhishek Dutta
# - Date: Mar 23, 2024
# - Description: Streamlit UI for a mental health support chatbot using sentiment analysis, RL, BM25/ChromaDB, and an LLM.
# - Note:
#   - Updated the UI to show the predicted mental health condition in the "Behind the Scene" panel regardless of positive/negative sentiment.
###

from dotenv import load_dotenv, find_dotenv
import pandas as pd
import streamlit as st
from q_learning_chatbot import QLearningChatbot
from xgb_mental_health import MentalHealthClassifier
from bm25_retreive_question import QuestionRetriever as QuestionRetriever_bm25
from Chromadb_storage_JyotiNigam import QuestionRetriever as QuestionRetriever_chromaDB
from llm_response_generator import LLLResponseGenerator
import os
from llama_guard import moderate_chat

# Streamlit UI
st.title("MindfulMedia Mentor")

# Sentiment labels the Q-learning agent uses as its states, and the response
# tones it can choose between as its actions.
states = [
    "Negative",
    "Moderately Negative",
    "Neutral",
    "Moderately Positive",
    "Positive",
]
actions = ["encouragement", "empathy", "spiritual"]

# Initialize the Q-learning chatbot once per browser session.
# Streamlit re-executes this whole script on every user interaction, so a plain
# module-level `QLearningChatbot(states, actions)` would be rebuilt from scratch
# on each message. That would throw away the Q-values (and any mood history the
# chatbot keeps) which the "Behind the Scene" panel says are continuously
# updated. Storing the instance in st.session_state keeps it across reruns.
if "chatbot" not in st.session_state:
    st.session_state.chatbot = QLearningChatbot(states, actions)
chatbot = st.session_state.chatbot

# Initialize MentalHealthClassifier
data_path = os.path.join("data", "data.csv")
print(data_path)

# Model name handed to mental_classifier.initialize_tokenizer() when a message
# is processed.
tokenizer_model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
mental_classifier_model_path = "mental_health_model.pkl"
mental_classifier = MentalHealthClassifier(data_path, mental_classifier_model_path)


def display_q_table(q_values, states, actions):
    """Return the Q-table as a DataFrame for display.

    The first column, "State", lists ``states``. It is followed by one column
    per entry of ``actions``, in order.

    Args:
        q_values: 2-D array-like supporting ``[:, i]`` column slicing
            (presumably a NumPy array with one row per state and one column
            per action).
        states: Row labels for the "State" column.
        actions: Action names; column ``i`` of ``q_values`` is shown under
            ``actions[i]``.

    Returns:
        pandas.DataFrame with columns ``["State", *actions]``.
    """
    table = {"State": states}
    table.update({action_name: q_values[:, col] for col, action_name in enumerate(actions)})
    return pd.DataFrame(table)


# Initialize per-session memory. Each key is seeded only the first time it is
# missing from st.session_state, so existing values survive later reruns.
_session_defaults = {
    "entered_text": [],
    "entered_mood": [],
    "messages": [],
    "user_sentiment": "Neutral",
    "mood_trend": "Unchanged",
    "predicted_mental_category": "",
    "ai_tone": "Empathy",
    "mood_trend_symbol": "",
    "show_question": False,
    "asked_questions": [],
    # LlamaGuard moderation starts disabled; the sidebar checkbox below can enable it.
    "llama_guard_enabled": False,
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Sidebar choice of the backend used to retrieve follow-up questions.
selected_retriever_option = st.sidebar.selectbox(
    "Choose Question Retriever", ("BM25", "ChromaDB")
)
if selected_retriever_option == "BM25":
    retriever = QuestionRetriever_bm25()
elif selected_retriever_option == "ChromaDB":
    retriever = QuestionRetriever_chromaDB()

# Replay the stored conversation so it stays visible on every rerun.
for past_message in st.session_state.messages:
    speaker = past_message.get("role")
    with st.chat_message(speaker):
        st.write(past_message.get("content"))

# Initial expanded state of the "Behind the Scene" sidebar expander.
section_visible = True


# Collect user input
user_message = st.chat_input("Type your message here:")


# LlamaGuard toggle. The explicit key keeps the widget identity stable, and the
# result is mirrored into session state so the moderation step can read it.
llama_guard_enabled = st.sidebar.checkbox(
    "Enable LlamaGuard",
    value=st.session_state["llama_guard_enabled"],
    key="llama_guard_toggle",
)
st.session_state["llama_guard_enabled"] = llama_guard_enabled

def _reward_for_mood(user_sentiment, mood_trend):
    """Map the detected sentiment and mood trend to an RL reward.

    Args:
        user_sentiment: Sentiment label returned by ``chatbot.detect_sentiment``.
        mood_trend: Value returned by ``chatbot.check_mood_trend()``. It is
            compared against "increased" and "unchanged"; any other value is
            treated as a decrease.

    Returns:
        A ``(reward, mood_trend_symbol)`` tuple:
        - A rising mood always earns +1.
        - An unchanged mood earns +0.8 for a positive user and -0.2 otherwise.
        - A falling mood costs -0.2 for a positive user and -1 otherwise.
    """
    is_positive = user_sentiment in ["Positive", "Moderately Positive"]
    if mood_trend == "increased":
        return +1, " ⬆️"
    if mood_trend == "unchanged":
        return (+0.8 if is_positive else -0.2), ""
    # decreased (or any unrecognized trend value)
    return (-0.2 if is_positive else -1), " ⬇️"


# Take user input
if user_message:
    st.session_state.entered_text.append(user_message)

    st.session_state.messages.append({"role": "user", "content": user_message})
    with st.chat_message("user"):
        st.write(user_message)

    # Optionally screen the raw user message with LlamaGuard before running the
    # rest of the pipeline on it.
    is_safe = True
    if st.session_state["llama_guard_enabled"]:
        chat = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": ""},
        ]
        guard_status = moderate_chat(chat)
        # The first generated text is scanned for the word "unsafe".
        if "unsafe" in guard_status[0]["generated_text"]:
            is_safe = False
        print("Guard status", guard_status)

    if not is_safe:
        # Refuse and point the user toward human help instead of generating advice.
        response = "Due to ethical and safety reasons, I can't provide the help you need. Please reach out to someone who can, like a family member, friend, or therapist. In urgent situations, contact emergency services or a crisis hotline. Remember, asking for help is brave, and you're not alone."
        st.session_state.messages.append({"role": "ai", "content": response})
        with st.chat_message("ai"):
            st.markdown(response)
    else:
        with st.spinner("Processing..."):
            # Detect mental condition.
            # NOTE(review): the tokenizer is re-initialized and the dataset
            # re-preprocessed on every message; consider doing this once per
            # session if MentalHealthClassifier supports it.
            mental_classifier.initialize_tokenizer(tokenizer_model_name)
            mental_classifier.preprocess_data()
            predicted_mental_category = mental_classifier.predict_category(user_message)
            print("Predicted mental health condition:", predicted_mental_category)

            # Detect sentiment
            user_sentiment = chatbot.detect_sentiment(user_message)

            # Only negative or neutral users get a retrieved follow-up question.
            if user_sentiment in ["Negative", "Moderately Negative", "Neutral"]:
                question = retriever.get_response(
                    user_message, predicted_mental_category
                )
                show_question = True
            else:
                show_question = False
                question = ""

            # Update mood history / mood_trend
            chatbot.update_mood_history()
            mood_trend = chatbot.check_mood_trend()

            reward, mood_trend_symbol = _reward_for_mood(user_sentiment, mood_trend)
            print(
                f"mood_trend - sentiment - reward: {mood_trend} - {user_sentiment} - 🛑{reward}🛑"
            )

            # Update Q-values.
            # NOTE(review): this always credits chatbot.actions[0] (presumably
            # "encouragement") and passes the current sentiment as both state
            # and next state. It does not credit the tone actually chosen on the
            # previous turn. Confirm the intended semantics of
            # QLearningChatbot.update_q_values.
            chatbot.update_q_values(
                user_sentiment, chatbot.actions[0], reward, user_sentiment
            )

            # Get recommended action based on the updated Q-values
            ai_tone = chatbot.get_action(user_sentiment)
            print(ai_tone)

            print(st.session_state.messages)

            # LLM Response Generator
            load_dotenv(find_dotenv())

            llm_model = LLLResponseGenerator()
            temperature = 0.5
            max_length = 128

            # The whole conversation so far (including the current user message)
            # is folded into the prompt context as one newline-joined string.
            all_messages = "\n".join(
                [message.get("content") for message in st.session_state.messages]
            )

            template = """INSTRUCTIONS: {context}
            
                Respond to the user with a tone of {ai_tone}. 
                
                Response by the user: {user_text}  
                Response;
                """
            context = f"You are a mental health supporting non-medical assistant. Provide some advice and ask a relevant question back to the user. {all_messages}"

            llm_response = llm_model.llm_inference(
                model_type="huggingface",
                question=question,
                prompt_template=template,
                context=context,
                ai_tone=ai_tone,
                questionnaire=predicted_mental_category,
                user_text=user_message,
                temperature=temperature,
                max_length=max_length,
            )

            # Append the retrieved question (if any) after the LLM's reply.
            ai_reply = f"{llm_response}\n\n{question}" if show_question else llm_response

            # Record the AI response in the chat history
            st.session_state.messages.append({"role": "ai", "content": ai_reply})

        with st.chat_message("ai"):
            st.markdown(ai_reply)

        # Persist this turn's analysis for the "Behind the Scene" panel.
        st.session_state.user_sentiment = user_sentiment
        st.session_state.mood_trend = mood_trend
        st.session_state.predicted_mental_category = predicted_mental_category
        st.session_state.ai_tone = ai_tone
        st.session_state.mood_trend_symbol = mood_trend_symbol
        st.session_state.show_question = show_question

    # "Behind the Scene" panel: reads the values stored in session state. After
    # a blocked (unsafe) message it shows the previous turn's values.
    with st.sidebar.expander("Behind the Scene", expanded=section_visible):
        st.subheader("What AI is doing:")
        st.write(
            f"- Detected User Tone: {st.session_state.user_sentiment} ({st.session_state.mood_trend.capitalize()}{st.session_state.mood_trend_symbol})"
        )
        st.write(
            f"- Possible Mental Condition: {st.session_state.predicted_mental_category.capitalize()}"
        )
        st.write(f"- AI Tone: {st.session_state.ai_tone.capitalize()}")
        st.write(f"- Question retrieved from: {selected_retriever_option}")
        st.write(
            "- If the user feels negative, moderately negative, or neutral, at the end of the AI response, it adds a mental health condition related question. The question is retrieved from DB. The categories of questions are limited to Depression, Anxiety, ADHD, Social Media Addiction, Social Isolation, and Cyberbullying which are most associated with FOMO related to excessive social media usage."
        )
        st.write(
            "- Below q-table is continuously updated after each interaction with the user. If the user's mood increases, AI gets a reward. Else, AI gets a punishment."
        )

        # Display Q-table
        st.dataframe(display_q_table(chatbot.q_values, states, actions))