#---
#- Author: Jaelin Lee, Jyoti Nigam
#- Date: Mar 16, 2024
#- Description: Stores question data in the Chroma DB vector database. Based on user input, retrieves the most relevant question from the knowledge base.
#- How it works: Loads the category's question file, splits it into chunks, embeds the chunks with a sentence-transformer model, and indexes them in Chroma. The user query is then matched against the index via embedding similarity search, and the closest stored question is returned.
#---
# Standard library
import os

# Third-party
import nltk
from nltk.tokenize import word_tokenize
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter

# Download the NLTK tokenizer models required for word tokenization.
nltk.download('punkt')

# NOTE(review): the original had a module-level `global db` statement, which
# is a no-op at module scope and was removed; the Chroma store is created
# per-query inside QuestionRetriever.get_response.
class QuestionRetriever:
    """Retrieve the knowledge-base question most similar to a user query.

    Loads a per-category question file from the ./data directory, indexes it
    into an in-memory Chroma vector store using sentence-transformer
    embeddings, and returns the stored question closest to the query.
    """

    # Maps each supported mental-health category to its question file.
    CATEGORY_FILES = {
        "depression": "depression_questions.txt",
        "adhd": "adhd_questions.txt",
        "anxiety": "anxiety_questions.txt",
    }

    def load_documents(self, file_name):
        """Load a question file from the ./data directory.

        Args:
            file_name: Name of a text file inside the "data" folder under
                the current working directory.

        Returns:
            The list of documents produced by ``TextLoader.load()``.
        """
        data_directory = os.path.join(os.getcwd(), "data")
        file_path = os.path.join(data_directory, file_name)
        loader = TextLoader(file_path)
        return loader.load()

    def store_data_in_vector_db(self, documents):
        """Split documents into line chunks and index them in Chroma.

        Args:
            documents: Documents returned by :meth:`load_documents`.

        Returns:
            An in-memory Chroma vector store over the document chunks.
        """
        text_splitter = CharacterTextSplitter(
            chunk_size=100, chunk_overlap=0, separator="\n"
        )
        docs = text_splitter.split_documents(documents)
        # Open-source sentence-transformer embedding function.
        embedding_function = SentenceTransformerEmbeddings(
            model_name="all-MiniLM-L6-v2"
        )
        return Chroma.from_documents(docs, embedding_function)

    def get_response(self, user_query, predicted_mental_category):
        """Return the stored question most similar to ``user_query``.

        Args:
            user_query: The user's input text.
            predicted_mental_category: One of "depression", "adhd", "anxiety".

        Returns:
            The most similar question string, or ``None`` when the category
            is unsupported or the search produced no results.
        """
        # Dispatch via a dict instead of an if/elif chain.
        file_name = self.CATEGORY_FILES.get(predicted_mental_category)
        if file_name is None:
            # BUGFIX: original message misspelled "depression" as "depresison".
            print("Sorry, allowed predicted_mental_category is "
                  "['depression', 'adhd', 'anxiety'].")
            return None
        documents = self.load_documents(file_name)
        db = self.store_data_in_vector_db(documents)
        docs = db.similarity_search(user_query)
        if not docs:
            # ROBUSTNESS: original indexed docs[0] unconditionally and would
            # raise IndexError on an empty result set.
            return None
        # Each chunk may span several lines; the first line is the question.
        most_similar_question = docs[0].page_content.split("\n")[0]
        # If the top hit merely echoes the query, fall back to the next hit
        # (guarded — original assumed docs[1] always existed).
        if user_query == most_similar_question and len(docs) > 1:
            most_similar_question = docs[1].page_content.split("\n")[0]
        print(most_similar_question)
        return most_similar_question
if __name__ == "__main__":
    # Manual smoke test: match one typed question against the depression set.
    retriever = QuestionRetriever()
    query = input("User: ")
    category = "depression"
    answer = retriever.get_response(query, category)
    print("Chatbot:", answer)