import os

from pymongo import MongoClient
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate

# Configuration is read from environment variables.
config = {
    'MONGODB_CONN_STRING': os.getenv('MONGODB_CONN_STRING'),
    'HUGGINGFACEHUB_API_TOKEN': os.getenv('HUGGINGFACEHUB_API_TOKEN'),
    'DB_NAME': os.getenv('DB_NAME'),
    'VECTOR_SEARCH_INDEX': os.getenv('VECTOR_SEARCH_INDEX'),
    'PASSWORD_DB': os.getenv('PASSWORD_DB'),  # loaded but not used below
}
# MongoDB client and the embedding model used for vector search.
client = MongoClient(config['MONGODB_CONN_STRING'])
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")

# Mistral-7B-Instruct served through the Hugging Face Inference API.
llm_model = HuggingFaceEndpoint(repo_id='mistralai/Mistral-7B-Instruct-v0.2',
                                huggingfacehub_api_token=config['HUGGINGFACEHUB_API_TOKEN'],
                                temperature=0.3)

# Prompt template in the Mistral instruction format ([INST] ... [/INST]).
template = """
        <s>[INST] Instruction: You are a helpful chatbot who can answer data science, anime, and manga questions.
        Follow these rules strictly when answering the question from the context:
        1. Do not use the word "context" or the phrase "based on the context" in your answers.
        2. If no context is provided, answer in at most 128 words.
        3. The context is a series of passages, so combine them into the best answer you can.

        Context:
        {context}

        ### QUESTION:
        {question} [/INST]
        """
prompt = ChatPromptTemplate.from_template(template=template)
parser = StrOutputParser()


def get_all_collections():
    """Map each collection name in the database to a human-readable display name."""
    database = client[config['DB_NAME']]
    names = database.list_collection_names()
    coll_dict = {}
    for name in names:
        # e.g. 'data_science' -> 'Data science'
        coll_dict[name] = ' '.join(name.capitalize().split('_'))
    return coll_dict

class VECTORDB_STORE:
    """Wraps a MongoDB Atlas vector store for a single collection."""

    def __init__(self, coll_name):
        # coll_name is the human-readable display name produced by get_all_collections().
        collection_name = self.get_collection_name(coll_name)
        collection = client[config['DB_NAME']][collection_name]
        self.vectordb_store = MongoDBAtlasVectorSearch(collection=collection,
                                                       embedding=embeddings,
                                                       index_name=config['VECTOR_SEARCH_INDEX'])
    @staticmethod
    def get_collection_name(coll_name):
        # Reverse lookup: display name -> actual MongoDB collection name.
        for key, value in get_all_collections().items():
            if coll_name == value:
                return key
        return None

    def chain(self):
        # Retrieval-augmented generation chain: retrieve the top 10 documents,
        # fill the prompt, call the LLM, and parse the output to a string.
        retriever = self.vectordb_store.as_retriever(search_kwargs={"k": 10})
        chain = {'context': retriever, 'question': RunnablePassthrough()} | prompt | llm_model | parser
        return chain
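

# Example usage: a minimal sketch. The display name 'Anime' below is an
# assumption and must match one of the values returned by get_all_collections();
# the MONGODB_CONN_STRING, HUGGINGFACEHUB_API_TOKEN, DB_NAME, and
# VECTOR_SEARCH_INDEX environment variables must also be set.
if __name__ == '__main__':
    print('Available collections:', list(get_all_collections().values()))

    store = VECTORDB_STORE(coll_name='Anime')  # hypothetical display name
    rag_chain = store.chain()
    answer = rag_chain.invoke('Who wrote the manga One Piece?')
    print(answer)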