Create q_learning_chatbot.py
q_learning_chatbot.py · ADDED · +82 -0
@@ -0,0 +1,82 @@
import os
import numpy as np
import pandas as pd
from xgb_mental_health import MentalHealthClassifier
import pickle
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.nn.functional import softmax
import torch


class QLearningChatbot:
    def __init__(self, states, actions, learning_rate=0.9, discount_factor=0.1):
        self.states = states
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        # Q-table initialized with random values: one row per state, one column per action
        self.q_values = np.random.rand(len(states), len(actions))
        self.mood = "Neutral"
        self.mood_history = []
        self.mood_history_int = []
        self.tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
        self.bert_sentiment_model_path = "bert_sentiment.pkl"
        # Use a locally pickled model if present, otherwise download the pretrained one
        self.bert_sentiment_model = self.load_model() if os.path.exists(self.bert_sentiment_model_path) else AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")

    def load_model(self):
        # Missing in the original file (its absence would raise AttributeError when the
        # pickle exists); assumed here to unpickle a model saved with pickle.dump
        with open(self.bert_sentiment_model_path, "rb") as f:
            return pickle.load(f)

    def detect_sentiment(self, input_text):
        # Encode the text
        encoded_input = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512)

        # Perform inference
        with torch.no_grad():
            output = self.bert_sentiment_model(**encoded_input)

        # Process the output (softmax to get probabilities)
        scores = softmax(output.logits, dim=1)

        # Map scores to sentiment labels and keep the most likely one as the current mood
        labels = ['Negative', 'Moderately Negative', 'Neutral', 'Moderately Positive', 'Positive']
        scores = scores.numpy().flatten()
        scores_dict = {label: score for label, score in zip(labels, scores)}
        highest_sentiment = max(scores_dict, key=scores_dict.get)
        self.mood = highest_sentiment
        return highest_sentiment

    def get_action(self, current_state):
        # Greedy policy: pick the action with the highest Q-value for the current state
        current_state_index = self.states.index(current_state)
        return self.actions[np.argmax(self.q_values[current_state_index, :])]

    def update_q_values(self, current_state, action, reward, next_state):
        current_state_index = self.states.index(current_state)
        action_index = self.actions.index(action)
        next_state_index = self.states.index(next_state)

        current_q_value = self.q_values[current_state_index, action_index]
        max_next_q_value = np.max(self.q_values[next_state_index, :])

        # Standard Q-learning update: Q(s,a) += lr * (r + gamma * max_a' Q(s',a') - Q(s,a))
        new_q_value = current_q_value + self.learning_rate * (reward + self.discount_factor * max_next_q_value - current_q_value)
        self.q_values[current_state_index, action_index] = new_q_value

    def update_mood_history(self):
        # Persist the detected mood in the Streamlit session so it survives reruns
        st.session_state.entered_mood.append(self.mood)
        self.mood_history = st.session_state.entered_mood
        return self.mood_history

    def check_mood_trend(self):
        mood_dict = {'Negative': 1, 'Moderately Negative': 2, 'Neutral': 3, 'Moderately Positive': 4, 'Positive': 5}

        if len(self.mood_history) >= 2:
            # Map moods to ordinal scores and compare the two most recent entries
            self.mood_history_int = [mood_dict.get(x) for x in self.mood_history]
            recent_moods = self.mood_history_int[-2:]
            if recent_moods[-1] > recent_moods[-2]:
                return 'increased'
            elif recent_moods[-1] < recent_moods[-2]:
                return 'decreased'
            else:
                return 'unchanged'
        else:
            return 'unchanged'
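
For context, a minimal usage sketch of the Q-learning loop this file implements. The state names, action names, and reward value below are hypothetical placeholders, not part of this Space; note also that constructing QLearningChatbot downloads the nlptown sentiment tokenizer and model from the Hugging Face Hub.

    # Hypothetical usage sketch of QLearningChatbot's tabular Q-learning loop.
    # states/actions/reward below are illustrative placeholders only.
    states = ['Negative', 'Moderately Negative', 'Neutral', 'Moderately Positive', 'Positive']
    actions = ['encourage', 'empathize', 'suggest_resources']

    bot = QLearningChatbot(states, actions)

    state = 'Neutral'
    action = bot.get_action(state)      # greedy pick from the randomly initialized Q-table
    next_state = 'Moderately Positive'  # e.g. the mood detected after the bot's reply
    reward = 1                          # positive reward since the mood improved
    bot.update_q_values(state, action, reward, next_state)

Since get_action is purely greedy over a randomly initialized table, early action choices are effectively arbitrary; only repeated update_q_values calls shape the Q-table toward actions that improve the detected mood.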