|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from rank_bm25 import BM25Okapi |
|
import nltk |
|
from nltk.tokenize import word_tokenize |
|
|
|
|
|
# word_tokenize needs NLTK's Punkt sentence model. Older NLTK releases load the
# "punkt" resource; newer ones (3.8.2+) load "punkt_tab" instead. Download only
# what is missing, so importing this module does not contact the NLTK server
# once both resources are installed locally.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource)
|
|
|
class QuestionRetriever:
    """Select the stored question that best matches a user's message.

    One question bank per mental-health category is read from a text file
    (one question per line) when the retriever is created. ``get_response``
    ranks the requested category's bank against the user's message with
    BM25 (``BM25Okapi``) and returns the top-scoring question.
    """

    # Normalized category label -> name of the attribute holding its bank.
    _CATEGORY_ATTRS = {
        "depression": "depression_questions",
        "adhd": "adhd_questions",
        "anxiety": "anxiety_questions",
        "social isolation": "social_isolation_questions",
        "cyberbullying": "cyberbullying_questions",
        "social media addiction": "social_media_addiction_questions",
    }

    def __init__(self, data_dir="data"):
        """Load every category's question bank from ``data_dir``.

        Args:
            data_dir: Directory containing the question files. Defaults to
                ``"data"``, which is resolved relative to the current working
                directory.

        Raises:
            OSError: If any question file cannot be opened
                (e.g. ``FileNotFoundError``).
        """

        def load(file_name):
            return self.load_questions_from_file(f"{data_dir}/{file_name}")

        self.depression_questions = load("depression_questions.txt")
        self.adhd_questions = load("adhd_questions.txt")
        self.anxiety_questions = load("anxiety_questions.txt")
        self.social_isolation_questions = load("social_isolation.txt")
        self.cyberbullying_questions = load("cyberbullying.txt")
        self.social_media_addiction_questions = load("socialmediaaddiction.txt")

    def load_questions_from_file(self, filename):
        """Read one question per line from ``filename``.

        Surrounding whitespace is stripped and blank lines are skipped, so an
        empty string can never be returned as a "question".

        Args:
            filename: Path to a UTF-8 text file (a leading BOM is tolerated).

        Returns:
            list[str]: The non-empty, stripped lines, in file order.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # "utf-8-sig" decodes plain UTF-8 and also drops a BOM that some
        # editors prepend, which str.strip() would otherwise leave in place.
        with open(filename, "r", encoding="utf-8-sig") as file:
            return [line.strip() for line in file if line.strip()]

    def get_response(self, user_query, predicted_mental_category):
        """Return the stored question that best matches ``user_query``.

        Args:
            user_query: Free-text message from the user.
            predicted_mental_category: One of "depression", "adhd", "anxiety",
                "social isolation", "cyberbullying" or
                "social media addiction". Matching ignores case and
                surrounding whitespace.

        Returns:
            The highest-scoring question from that category's bank, or None
            when the category is not recognized (a notice is printed) or its
            bank is empty. Ties go to the question that appears first.
        """
        attr_name = None
        if isinstance(predicted_mental_category, str):
            # Tolerate casing/whitespace differences in the category label.
            attr_name = self._CATEGORY_ATTRS.get(
                predicted_mental_category.strip().lower()
            )
        if attr_name is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, attr_name)
        if not knowledge_base:
            return None

        bm25 = self._get_bm25_index(attr_name, knowledge_base)
        doc_scores = bm25.get_scores(word_tokenize(user_query.lower()))
        # argmax returns the first index among equal maxima, so ties resolve
        # to the earliest question in the bank.
        return knowledge_base[int(doc_scores.argmax())]

    def _get_bm25_index(self, attr_name, knowledge_base):
        """Return a BM25 index over ``knowledge_base``, built once per bank.

        The bank is tokenized and indexed on first use. The cached index is
        rebuilt only if the bank's contents differ from those it was built on
        (e.g. the attribute was reassigned after construction).
        """
        # setdefault means the cache needs no set-up in __init__.
        cache = self.__dict__.setdefault("_bm25_cache", {})
        cached = cache.get(attr_name)
        if cached is None or cached[0] != knowledge_base:
            tokenized_docs = [word_tokenize(doc.lower()) for doc in knowledge_base]
            cached = (list(knowledge_base), BM25Okapi(tokenized_docs))
            cache[attr_name] = cached
        return cached[1]
|
|
|
if __name__ == "__main__":
    # Interactive demo: answer a single typed message from the cyberbullying
    # question bank. The category is hard-coded here; presumably a classifier
    # supplies it in real use.
    retriever = QuestionRetriever()
    reply = retriever.get_response(input("User: "), "cyberbullying")
    print("Chatbot:", reply)
|
|
|
|