#---
#- Author: Jaelin Lee, Jyoti Nigam
#- Date: Mar 16, 2024
#- Description: Stores mental-health screening questions in the Chroma vector database and, given a user input, retrieves the most relevant question from the knowledge base.
#- How it works: Load the question file for the predicted mental-health category and split it into chunks (one question per line). Embed each chunk with the open-source SentenceTransformer model "all-MiniLM-L6-v2" and store the embeddings in Chroma. Run a similarity search against the user query and return the first line of the closest chunk, falling back to the next result if the top hit is an exact echo of the query.
#---

import os

from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter

class QuestionRetriever:

    def load_documents(self, file_name):
        # Resolve the question file relative to the "data" directory.
        current_directory = os.getcwd()
        data_directory = os.path.join(current_directory, "data")
        file_path = os.path.join(data_directory, file_name)
        loader = TextLoader(file_path)
        documents = loader.load()
        return documents
    
    def store_data_in_vector_db(self, documents):
        # Split the documents into one chunk per line (each line is a question).
        text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0, separator="\n")
        docs = text_splitter.split_documents(documents)
        # Create the open-source embedding function.
        embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
        # Load the chunks into Chroma.
        db = Chroma.from_documents(docs, embedding_function)
        return db

    def get_response(self, user_query, predicted_mental_category):
        if predicted_mental_category == "depression":
            documents = self.load_documents("depression_questions.txt")
        elif predicted_mental_category == "adhd":
            documents = self.load_documents("adhd_questions.txt")
        elif predicted_mental_category == "anxiety":
            documents = self.load_documents("anxiety_questions.txt")
        else:
            print("Sorry, allowed predicted_mental_category values are ['depression', 'adhd', 'anxiety'].")
            return None

        db = self.store_data_in_vector_db(documents)
        docs = db.similarity_search(user_query)

        # Extract the first line (the question) of the most similar chunk.
        most_similar_question = docs[0].page_content.split("\n")[0]
        # If the top hit is an exact echo of the query, fall back to the next result.
        if user_query == most_similar_question and len(docs) > 1:
            most_similar_question = docs[1].page_content.split("\n")[0]

        print(most_similar_question)
        return most_similar_question

if __name__ == "__main__":
    model = QuestionRetriever()
    user_input = input("User: ")

    predicted_mental_condition = "depression"
    response = model.get_response(user_input, predicted_mental_condition)
    print("Chatbot:", response)
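
The exact-match fallback in `get_response` can be exercised without a vector store. A minimal sketch, assuming a hypothetical helper `pick_question` that takes chunk texts ordered by similarity, mirroring what Chroma's `similarity_search` returns:

```python
def pick_question(user_query, chunks):
    """Return the first line of the most similar chunk.

    `chunks` is a list of page-content strings ordered by similarity.
    If the top hit is an exact echo of the query, fall back to the
    next result, as get_response does.
    """
    if not chunks:
        return None
    question = chunks[0].split("\n")[0]
    if question == user_query and len(chunks) > 1:
        question = chunks[1].split("\n")[0]
    return question


# Top hit echoes the query, so the second result is used instead.
print(pick_question(
    "Do you feel sad?",
    ["Do you feel sad?\nfollow-up", "Do you often feel hopeless?\nfollow-up"],
))
# → Do you often feel hopeless?
```

This keeps the selection rule testable in isolation; the real class wires it to Chroma retrieval.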