#---
#- Author: Jaelin Lee
#- Date: Mar 23, 2024
#- Description: Similarity search using BM25. Based on user input, retrieve
#-   the most relevant info from the knowledge base.
#- How it works: Tokenize the user input text using NLTK. Then score it
#-   against every item of the selected knowledge base using BM25. The items
#-   are ranked by score (highest first) and the best-ranked item that has not
#-   already been asked in this session is returned.
#---

from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
import streamlit as st

# Download NLTK data for tokenization
# NOTE(review): newer NLTK releases appear to load tokenizer data from the
# 'punkt_tab' resource instead -- confirm against the installed NLTK version.
nltk.download('punkt')


class QuestionRetriever:
    """Pick the follow-up question most similar to what the user just said.

    One list of questions is loaded per mental-health category when the
    object is created. ``get_response`` ranks the questions of one category
    against the user's text with BM25.
    """

    # Category label (as passed to ``get_response``) -> name of the instance
    # attribute holding that category's questions. The attribute is looked up
    # lazily so only the selected knowledge base is ever touched.
    _CATEGORY_ATTRIBUTES = {
        "depression": "depression_questions",
        "adhd": "adhd_questions",
        "anxiety": "anxiety_questions",
        "social isolation": "social_isolation_questions",
        "cyberbullying": "cyberbullying_questions",
        "social media addiction": "social_media_addiction_questions",
    }

    def __init__(self):
        """Load every category's question list from the ``data/`` directory."""
        self.depression_questions = self.load_questions_from_file("data/depression_questions.txt")
        self.adhd_questions = self.load_questions_from_file("data/adhd_questions.txt")
        self.anxiety_questions = self.load_questions_from_file("data/anxiety_questions.txt")
        self.social_isolation_questions = self.load_questions_from_file("data/social_isolation.txt")
        self.cyberbullying_questions = self.load_questions_from_file("data/cyberbullying.txt")
        self.social_media_addiction_questions = self.load_questions_from_file("data/socialmediaaddiction.txt")

    def load_questions_from_file(self, filename):
        """Read one question per line from *filename*.

        Args:
            filename: Path of a text file containing one question per line.

        Returns:
            A list of the file's lines with surrounding whitespace removed.
            Blank lines are dropped so an empty string can never be offered
            to the user as a question.

        Raises:
            OSError: If the file cannot be opened.
        """
        # Explicit encoding: the platform default differs between systems.
        with open(filename, "r", encoding="utf-8") as file:
            stripped_lines = (line.strip() for line in file)
            return [question for question in stripped_lines if question]

    def get_response(self, user_query, predicted_mental_category):
        """Return the best-matching question that has not been asked yet.

        Args:
            user_query: Free text typed by the user.
            predicted_mental_category: One of "depression", "adhd",
                "anxiety", "social isolation", "cyberbullying" or
                "social media addiction".

        Returns:
            The highest-scoring question of the selected category that is not
            in ``st.session_state["asked_questions"]``, or ``None`` when the
            category is unknown, its knowledge base is empty, or every
            question has already been asked.
        """
        attribute_name = None
        if isinstance(predicted_mental_category, str):
            attribute_name = self._CATEGORY_ATTRIBUTES.get(predicted_mental_category)

        if attribute_name is None:
            print("Sorry, I didn't understand that.")
            return None

        knowledge_base = getattr(self, attribute_name)
        if not knowledge_base:
            return None

        # Lowercase both corpus and query so matching is case-insensitive.
        tokenized_docs = [word_tokenize(doc.lower()) for doc in knowledge_base]
        bm25 = BM25Okapi(tokenized_docs)
        tokenized_query = word_tokenize(user_query.lower())
        doc_scores = bm25.get_scores(tokenized_query)

        # The session key may not exist yet (first turn, or when this module
        # is run as a plain script), so fall back to "nothing asked so far"
        # instead of raising.
        asked_questions = st.session_state.get("asked_questions") or ()

        # Indices of the documents, best score first (stable for ties).
        ranked_indices = sorted(
            range(len(doc_scores)),
            key=doc_scores.__getitem__,
            reverse=True,
        )
        for doc_index in ranked_indices:
            response = knowledge_base[doc_index]
            if response not in asked_questions:
                return response

        # Every question of this category has already been asked.
        return None


if __name__ == "__main__":
    predicted_mental_category = "cyberbullying"
    model = QuestionRetriever()
    user_input = input("User: ")
    response = model.get_response(user_input, predicted_mental_category)
    print("Chatbot:", response)