# File size: 2,886 Bytes
# bd9870c
#---
#- Author: Jaelin Lee, Jyoti Nigam
#- Date: Mar 16, 2024
#- Description: Stores question data in a Chroma vector database and, based on user input, retrieves the most relevant question from the knowledge base.
#- How it works: Split the question file for the predicted category into line-based chunks, embed them with the `all-MiniLM-L6-v2` sentence-transformer, and load them into Chroma. Run a similarity search with the user input and return the first line of the closest chunk, skipping a result that exactly repeats the input. (NLTK tokenization with BM25/TF-IDF scoring and `argmax()` selection is not used in this version.)
#---
import nltk
from nltk.tokenize import word_tokenize
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
# Download the NLTK "punkt" tokenizer data when this module is imported.
# NOTE(review): this may need network access on first run, and nothing in this
# module appears to call `word_tokenize`; confirm the download is still needed.
nltk.download('punkt')
import os
class QuestionRetriever:
    """Retrieve the stored question most similar to a user's input.

    Each supported category has a question file under ``<cwd>/data``. The file
    is split into line-based chunks, embedded with the ``all-MiniLM-L6-v2``
    sentence-transformer, and loaded into a Chroma vector store that is then
    queried by similarity search.
    """

    # Question file (inside the ``data`` directory) for each supported category.
    CATEGORY_FILES = {
        "depression": "depression_questions.txt",
        "adhd": "adhd_questions.txt",
        "anxiety": "anxiety_questions.txt",
    }

    def load_documents(self, file_name):
        """Load ``file_name`` from the ``data`` directory of the working directory.

        Args:
            file_name: Name of a text file inside ``<cwd>/data``.

        Returns:
            The documents produced by LangChain's ``TextLoader`` for that file.
        """
        # Resolved against the current working directory, not this module's path.
        file_path = os.path.join(os.getcwd(), "data", file_name)
        return TextLoader(file_path).load()

    def store_data_in_vector_db(self, documents, collection_name=None):
        """Split ``documents`` into chunks and load them into a Chroma vector store.

        Args:
            documents: LangChain documents to index.
            collection_name: Optional Chroma collection name. When given, chunk
                ids are derived from it so re-indexing the same file reuses the
                same ids instead of minting new random ones. When omitted, Chroma
                is called without ``collection_name``/``ids``, exactly as before.

        Returns:
            The populated ``Chroma`` vector store.
        """
        # Lines are merged into chunks of up to ~100 characters with no overlap,
        # so one chunk can hold several questions.
        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0, separator="\n")
        docs = text_splitter.split_documents(documents)
        # Open-source sentence-transformer embedding model.
        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        if collection_name is None:
            return Chroma.from_documents(docs, embedding_function)
        ids = [f"{collection_name}-{index}" for index in range(len(docs))]
        return Chroma.from_documents(
            docs, embedding_function, ids=ids, collection_name=collection_name
        )

    def _get_db(self, category):
        """Return the vector store for ``category``, building it on first use."""
        # Cached per instance. Rebuilding reloads the embedding model and
        # re-embeds the whole file on every query.
        # NOTE(review): Chroma in-memory clients may share one store per process,
        # so repeated builds into the default collection could duplicate or mix
        # entries. Confirm this against the installed chromadb version.
        cache = getattr(self, "_db_cache", None)
        if cache is None:
            cache = self._db_cache = {}
        if category not in cache:
            documents = self.load_documents(self.CATEGORY_FILES[category])
            # A separate collection per category keeps the question banks apart.
            cache[category] = self.store_data_in_vector_db(
                documents, collection_name=f"{category}_questions"
            )
        return cache[category]

    def get_response(self, user_query, predicted_mental_category):
        """Return the stored question most similar to ``user_query``.

        Args:
            user_query: The user's input text.
            predicted_mental_category: One of ``"depression"``, ``"adhd"`` or
                ``"anxiety"``; selects which question file is searched.

        Returns:
            The first line of the best-matching chunk. If that line is exactly
            ``user_query``, the first line of the next hit that differs from it
            is returned instead, falling back to the top hit when every hit
            echoes the query. Returns ``None`` for an unsupported category or
            when the search yields no results.
        """
        if predicted_mental_category not in self.CATEGORY_FILES:
            print("Sorry, allowed predicted_mental_category is ['depression', 'adhd', 'anxiety'].")
            return None

        db = self._get_db(predicted_mental_category)
        docs = db.similarity_search(user_query)
        if not docs:
            return None

        # A chunk may span several lines; only its first line is used.
        candidates = [doc.page_content.split("\n")[0] for doc in docs]
        # Avoid answering with the user's own question repeated back to them.
        most_similar_question = next(
            (question for question in candidates if question != user_query),
            candidates[0],
        )
        print(most_similar_question)
        return most_similar_question
if __name__ == "__main__":
    # Interactive demo: read one line from stdin and answer it from the
    # depression question bank.
    model = QuestionRetriever()
    user_input = input("User: ")
    predicted_mental_condition = "depression"
    response = model.get_response(user_input, predicted_mental_condition)
    print("Chatbot:", response)