MindfulMedia_Mentor / bm25_retreive_question.py
jaelin215's picture
Added streamlit
cb96e5d verified
raw
history blame
3.56 kB
#---
#- Author: Jaelin Lee
#- Date: Mar 23, 2024
#- Description: Similarity search using BM25. Based on user input, retrieve most relevant info from knowledge base.
#- How it works: Tokenize the user input text using NLTK. Then, score every item in the knowledge base of the predicted category with BM25 (a TF-IDF-style ranking function). Rank the items by score and return the highest-scoring item that has not already been asked (i.e. is not in `st.session_state.asked_questions`).
#---
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
import streamlit as st
def _ensure_punkt_data():
    """Make sure the Punkt tokenizer data used by `word_tokenize` is installed.

    NLTK 3.9+ loads the "punkt_tab" resource instead of the pickled "punkt"
    models, so both are checked. A resource that is already on the NLTK data
    path is not downloaded again, so importing this module needs no network
    access once the data is present.
    """
    for package in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{package}")
        except LookupError:
            nltk.download(package)


# Download NLTK data for tokenization (only what is missing).
_ensure_punkt_data()
class QuestionRetriever:
    """Pick the next follow-up question for a mental-health category using BM25.

    Each supported category has its own knowledge base: a text file under
    ``data/`` holding one question per line. ``get_response`` ranks that
    category's questions against the user's message and returns the
    highest-ranked one that has not been asked yet.
    """

    # Predicted category label -> name of the attribute holding its questions.
    _CATEGORY_ATTRIBUTES = {
        "depression": "depression_questions",
        "adhd": "adhd_questions",
        "anxiety": "anxiety_questions",
        "social isolation": "social_isolation_questions",
        "cyberbullying": "cyberbullying_questions",
        "social media addiction": "social_media_addiction_questions",
    }

    def __init__(self):
        """Load every category's questions from the ``data/`` directory.

        The paths are relative, so they resolve against the current working
        directory. Raises ``OSError`` (e.g. ``FileNotFoundError``) if a file
        cannot be opened.
        """
        self.depression_questions = self.load_questions_from_file("data/depression_questions.txt")
        self.adhd_questions = self.load_questions_from_file("data/adhd_questions.txt")
        self.anxiety_questions = self.load_questions_from_file("data/anxiety_questions.txt")
        self.social_isolation_questions = self.load_questions_from_file("data/social_isolation.txt")
        self.cyberbullying_questions = self.load_questions_from_file("data/cyberbullying.txt")
        self.social_media_addiction_questions = self.load_questions_from_file("data/socialmediaaddiction.txt")

    def load_questions_from_file(self, filename):
        """Return the non-blank lines of *filename*, stripped of surrounding whitespace.

        Blank (or whitespace-only) lines are skipped so they can never be
        served as an empty "question" or counted as zero-length documents
        when BM25 computes its average document length.
        """
        with open(filename, "r", encoding="utf-8") as file:
            return [line.strip() for line in file if line.strip()]

    def get_response(self, user_query, predicted_mental_category, asked_questions=None):
        """Return the most relevant question that has not been asked yet, or ``None``.

        Args:
            user_query: The user's message.
            predicted_mental_category: One of the keys of ``_CATEGORY_ATTRIBUTES``
                (e.g. ``"depression"`` or ``"social media addiction"``).
            asked_questions: Questions to skip. When omitted, the value of
                ``st.session_state["asked_questions"]`` is used, or nothing is
                skipped if that key has not been initialised (e.g. when the
                module runs outside ``streamlit run``).

        Returns:
            The highest-scoring question of the category that is not in
            *asked_questions* (equal scores keep file order), or ``None`` when
            the category is unknown, its knowledge base is empty, or every
            question has already been asked.
        """
        attribute = self._CATEGORY_ATTRIBUTES.get(predicted_mental_category)
        if attribute is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, attribute)
        if not knowledge_base:
            # An empty knowledge base has nothing to rank.
            return None

        if asked_questions is None:
            # `.get` instead of attribute access: the key may not exist yet.
            asked_questions = st.session_state.get("asked_questions", ())

        # Lowercase both the documents and the query so matching ignores case.
        tokenized_docs = [word_tokenize(doc.lower()) for doc in knowledge_base]
        bm25 = BM25Okapi(tokenized_docs)
        doc_scores = bm25.get_scores(word_tokenize(user_query.lower()))

        # Best score first; sorted() is stable (also with reverse=True), so
        # questions with equal scores keep their order from the file.
        ranked_indices = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)
        for index in ranked_indices:
            candidate = knowledge_base[index]
            if candidate not in asked_questions:
                return candidate
        # Every question of this category has already been asked.
        return None
if __name__ == "__main__":
    # Manual command-line check: read one line of user input, rank the
    # "cyberbullying" questions against it and print the top-ranked one.
    category = "cyberbullying"
    retriever = QuestionRetriever()
    query = input("User: ")
    reply = retriever.get_response(query, category)
    print("Chatbot:", reply)