SwatGarg committed on
Commit
6ede6bb
1 Parent(s): 5e10bd7

Update q_learning_chatbot.py

Browse files
Files changed (1) hide show
  1. q_learning_chatbot.py +6 -13
q_learning_chatbot.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import numpy as np
3
  import pandas as pd
4
- from xgb_mental_health import MentalHealthClassifier
5
  import pickle
6
  import streamlit as st
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -11,12 +11,11 @@ import torch
11
 
12
 
13
  class QLearningChatbot:
14
- def __init__(self, states, actions, learning_rate=0.9, discount_factor=0.1):
15
  self.states = states
16
- self.actions = actions
17
  self.learning_rate = learning_rate
18
  self.discount_factor = discount_factor
19
- self.q_values = np.random.rand(len(states), len(actions))
20
  self.mood = "Neutral"
21
  self.mood_history = []
22
  self.mood_history_int = []
@@ -44,22 +43,16 @@ class QLearningChatbot:
44
  self.mood = highest_sentiment
45
  return highest_sentiment
46
 
47
- def get_action(self, current_state):
48
- current_state_index = self.states.index(current_state)
49
- # print(np.argmax(self.q_values[current_state_index, :]))
50
- return self.actions[np.argmax(self.q_values[current_state_index, :])]
51
-
52
- def update_q_values(self, current_state, action, reward, next_state):
53
  # print(f"state-reward: {current_state} - {reward} -- (b)")
54
  current_state_index = self.states.index(current_state)
55
- action_index = self.actions.index(action)
56
  next_state_index = self.states.index(next_state)
57
 
58
- current_q_value = self.q_values[current_state_index, action_index]
59
  max_next_q_value = np.max(self.q_values[next_state_index, :])
60
 
61
  new_q_value = current_q_value + self.learning_rate * (reward + self.discount_factor * max_next_q_value - current_q_value)
62
- self.q_values[current_state_index, action_index] = new_q_value
63
 
64
  def update_mood_history(self):
65
  st.session_state.entered_mood.append(self.mood)
 
1
  import os
2
  import numpy as np
3
  import pandas as pd
4
+
5
  import pickle
6
  import streamlit as st
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
11
 
12
 
13
  class QLearningChatbot:
14
+ def __init__(self, states, learning_rate=0.9, discount_factor=0.1):
15
  self.states = states
 
16
  self.learning_rate = learning_rate
17
  self.discount_factor = discount_factor
18
+ self.q_values = np.random.rand(len(states))
19
  self.mood = "Neutral"
20
  self.mood_history = []
21
  self.mood_history_int = []
 
43
  self.mood = highest_sentiment
44
  return highest_sentiment
45
 
46
+ def update_q_values(self, current_state, reward, next_state):
 
 
 
 
 
47
  # print(f"state-reward: {current_state} - {reward} -- (b)")
48
  current_state_index = self.states.index(current_state)
 
49
  next_state_index = self.states.index(next_state)
50
 
51
+ current_q_value = self.q_values[current_state_index]
52
  max_next_q_value = np.max(self.q_values[next_state_index, :])
53
 
54
  new_q_value = current_q_value + self.learning_rate * (reward + self.discount_factor * max_next_q_value - current_q_value)
55
+ self.q_values[current_state_index] = new_q_value
56
 
57
  def update_mood_history(self):
58
  st.session_state.entered_mood.append(self.mood)