SwatGarg committed on
Commit
0581b98
1 Parent(s): 95a63de

Create q_learning_chatbot.py

Files changed (1)
  1. q_learning_chatbot.py +82 -0
q_learning_chatbot.py ADDED
@@ -0,0 +1,82 @@
import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from xgb_mental_health import MentalHealthClassifier


class QLearningChatbot:
    """Tabular Q-learning agent that picks chatbot actions based on the user's detected mood."""

    def __init__(self, states, actions, learning_rate=0.9, discount_factor=0.1):
        self.states = states
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        # Q-table with one row per state and one column per action, randomly initialised.
        self.q_values = np.random.rand(len(states), len(actions))
        self.mood = "Neutral"
        self.mood_history = []
        self.mood_history_int = []
        self.tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
        self.bert_sentiment_model_path = "bert_sentiment.pkl"
        # Reuse a locally pickled sentiment model if present; otherwise download the pretrained one.
        self.bert_sentiment_model = (
            self.load_model()
            if os.path.exists(self.bert_sentiment_model_path)
            else AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
        )

    def load_model(self):
        # Load the pickled sentiment model from disk.
        with open(self.bert_sentiment_model_path, "rb") as f:
            return pickle.load(f)

    def detect_sentiment(self, input_text):
        # Encode the text
        encoded_input = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512)

        # Perform inference
        with torch.no_grad():
            output = self.bert_sentiment_model(**encoded_input)

        # Process the output (softmax to get probabilities)
        scores = softmax(output.logits, dim=1)

        # Map scores to sentiment labels
        labels = ['Negative', 'Moderately Negative', 'Neutral', 'Moderately Positive', 'Positive']
        scores = scores.numpy().flatten()
        scores_dict = {label: score for label, score in zip(labels, scores)}
        highest_sentiment = max(scores_dict, key=scores_dict.get)
        self.mood = highest_sentiment
        return highest_sentiment

    def get_action(self, current_state):
        # Greedy policy: choose the action with the highest Q-value for the current state.
        current_state_index = self.states.index(current_state)
        return self.actions[np.argmax(self.q_values[current_state_index, :])]

    def update_q_values(self, current_state, action, reward, next_state):
        current_state_index = self.states.index(current_state)
        action_index = self.actions.index(action)
        next_state_index = self.states.index(next_state)

        current_q_value = self.q_values[current_state_index, action_index]
        max_next_q_value = np.max(self.q_values[next_state_index, :])

        # Q-learning update: Q(s, a) += lr * (reward + gamma * max_a' Q(s', a') - Q(s, a))
        new_q_value = current_q_value + self.learning_rate * (reward + self.discount_factor * max_next_q_value - current_q_value)
        self.q_values[current_state_index, action_index] = new_q_value

    def update_mood_history(self):
        # Persist the detected mood in the Streamlit session state so it survives reruns.
        st.session_state.entered_mood.append(self.mood)
        self.mood_history = st.session_state.entered_mood
        return self.mood_history

    def check_mood_trend(self):
        # Compare the two most recent moods on a 1-5 scale to see which way the user is trending.
        mood_dict = {'Negative': 1, 'Moderately Negative': 2, 'Neutral': 3, 'Moderately Positive': 4, 'Positive': 5}

        if len(self.mood_history) >= 2:
            self.mood_history_int = [mood_dict.get(x) for x in self.mood_history]
            recent_moods = self.mood_history_int[-2:]
            if recent_moods[-1] > recent_moods[-2]:
                return 'increased'
            elif recent_moods[-1] < recent_moods[-2]:
                return 'decreased'
            else:
                return 'unchanged'
        else:
            return 'unchanged'
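For context, update_q_values applies the standard tabular Q-learning rule, Q(s, a) <- Q(s, a) + alpha * (r + gamma * max Q(s', a') - Q(s, a)), with learning_rate as alpha and discount_factor as gamma. Below is a minimal sketch of how the class could be exercised outside the Streamlit app, assuming the file's own imports (e.g. xgb_mental_health) resolve; the action names and the reward value are illustrative assumptions, not part of this commit.

# Minimal usage sketch. The states reuse the five sentiment labels produced by
# detect_sentiment; the action set and the reward value are hypothetical.
from q_learning_chatbot import QLearningChatbot

states = ['Negative', 'Moderately Negative', 'Neutral', 'Moderately Positive', 'Positive']
actions = ['offer_support', 'suggest_resources', 'ask_follow_up']  # hypothetical actions

bot = QLearningChatbot(states, actions)

# Detect the user's mood and pick the greedy action for that state.
mood = bot.detect_sentiment("I've been feeling a bit better since yesterday.")
action = bot.get_action(mood)
print(f"mood={mood}, action={action}")

# After the next user message, reward the chosen action and update the Q-table.
next_mood = bot.detect_sentiment("Thanks, that suggestion actually helped.")
bot.update_q_values(mood, action, reward=1, next_state=next_mood)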