File size: 3,563 Bytes
bd9870c
 
 
 
 
 
 
 
 
 
cb96e5d
bd9870c
 
 
 
 
 
fd117c1
 
 
 
 
 
bd9870c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c50568e
bd9870c
 
 
 
c50568e
 
bd9870c
c50568e
 
 
 
 
 
bd9870c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#---
#- Author: Jaelin Lee
#- Date: Mar 23, 2024
#- Description: Similarity search using BM25. Based on user input, retrieve most relevant info from knowledge base.
#- How it works: Tokenize the user input text using NLTK. Then score it against every entry of the selected knowledge base using BM25 (a TF-IDF-style ranking function). Rank the entries by score, highest first, and return the best-scoring entry that has not already been asked.
#---

from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
import streamlit as st

# Download NLTK data for tokenization.
# This is a module-level statement, so it runs every time the module is imported.
# NOTE(review): presumably this is the data `word_tokenize` needs; recent NLTK
# releases may expect "punkt_tab" instead — confirm against the installed version.
nltk.download('punkt')

class QuestionRetriever:
    """Retrieve the most relevant not-yet-asked question for a user query.

    One list of questions per category is loaded from text files when the
    instance is created. ``get_response`` ranks the questions of the requested
    category against the user's text with BM25 and returns the best match
    that has not been asked yet.
    """

    # Maps each category label accepted by ``get_response`` to the name of the
    # instance attribute holding that category's questions.
    _CATEGORY_ATTRIBUTES = {
        "depression": "depression_questions",
        "adhd": "adhd_questions",
        "anxiety": "anxiety_questions",
        "social isolation": "social_isolation_questions",
        "cyberbullying": "cyberbullying_questions",
        "social media addiction": "social_media_addiction_questions",
    }

    def __init__(self):
        """Load every category's question list from the ``data/`` directory."""
        self.depression_questions = self.load_questions_from_file("data/depression_questions.txt")
        self.adhd_questions = self.load_questions_from_file("data/adhd_questions.txt")
        self.anxiety_questions = self.load_questions_from_file("data/anxiety_questions.txt")
        self.social_isolation_questions = self.load_questions_from_file("data/social_isolation.txt")
        self.cyberbullying_questions = self.load_questions_from_file("data/cyberbullying.txt")
        self.social_media_addiction_questions = self.load_questions_from_file("data/socialmediaaddiction.txt")

    def load_questions_from_file(self, filename):
        """Read one question per line from *filename*.

        Args:
            filename: Path of a UTF-8 text file containing one question per line.

        Returns:
            A list of questions with surrounding whitespace removed. Lines that
            are empty after stripping are dropped so that an empty string can
            never be offered as a question or indexed as an empty document.

        Raises:
            OSError: If the file cannot be opened.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(filename, "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            return [question for question in stripped_lines if question]

    def get_response(self, user_query, predicted_mental_category):
        """Return the best-matching question that has not been asked yet.

        Args:
            user_query: Free text typed by the user.
            predicted_mental_category: One of the keys of
                ``_CATEGORY_ATTRIBUTES`` (e.g. ``"depression"``).

        Returns:
            The highest-scoring question of the category that is not in
            ``st.session_state.asked_questions``, or ``None`` when the
            category is unknown, has no questions, or every question has
            already been asked.
        """
        # Non-string labels (e.g. None) are treated like any unknown category.
        attribute_name = None
        if isinstance(predicted_mental_category, str):
            attribute_name = self._CATEGORY_ATTRIBUTES.get(predicted_mental_category)

        if attribute_name is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, attribute_name)
        if not knowledge_base:
            # Nothing to rank; BM25 cannot be built on an empty corpus.
            return None

        # Lowercase both sides so matching is case-insensitive.
        tokenized_docs = [word_tokenize(doc.lower()) for doc in knowledge_base]
        bm25 = BM25Okapi(tokenized_docs)
        tokenized_query = word_tokenize(user_query.lower())
        doc_scores = bm25.get_scores(tokenized_query)

        # Indices of the questions ordered from most to least relevant.
        ranked_indices = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)

        # Fetch the history once. Fall back to "nothing asked yet" when the
        # session key was never initialised (e.g. when run outside the app).
        already_asked = getattr(st.session_state, "asked_questions", ())

        for doc_index in ranked_indices:
            candidate = knowledge_base[doc_index]
            if candidate not in already_asked:
                return candidate

        # Every question of this category has already been asked.
        return None

if __name__ == "__main__":
    # Manual smoke test: read one line from the terminal and print the
    # question retrieved for a fixed category.
    retriever = QuestionRetriever()
    category = "cyberbullying"
    query = input("User: ")
    print("Chatbot:", retriever.get_response(query, category))