File size: 5,242 Bytes
b18c318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be6ea65
b18c318
 
 
be6ea65
b18c318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187

#load & split data
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embed data
from langchain_mistralai.embeddings import MistralAIEmbeddings
# vector store
from langchain_community.vectorstores import FAISS
# prompt
from langchain.prompts import PromptTemplate
# memory
from langchain.memory import ConversationBufferMemory
#llm
from langchain_mistralai.chat_models import ChatMistralAI

#chain modules
from langchain.chains import RetrievalQA



# import PyPDF2
import os
import re
from dotenv import load_dotenv
load_dotenv()
from collections import defaultdict

# Mistral API key, read once at import time (after load_dotenv() has
# populated the environment from a .env file); None when the variable is unset.
api_key = os.getenv("MISTRAL_API_KEY")

class RagModule:
    """Retrieval-augmented generation helpers built on Mistral AI and FAISS.

    Holds the text splitter, the Mistral embedding model, the path of the
    local FAISS vector store and the chat-model generation parameters used
    by the retrieval Q&A chains.
    """

    def __init__(self):
        # API key taken from the module-level value read from MISTRAL_API_KEY.
        self.mistral_api_key = api_key
        self.model_name_embedding = "mistral-embed"
        self.embedding_model = MistralAIEmbeddings(
            model=self.model_name_embedding,
            mistral_api_key=self.mistral_api_key,
        )
        # Chunking parameters for the text splitter.
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        self.db_faiss_path = "data/vector_store"
        # LLM generation parameters (see load_mistral).
        self.llm_model = "mistral-small"
        self.max_new_tokens = 512
        self.top_p = 0.5
        self.temperature = 0.1

    def split_text(self, text: str) -> list:
        """Split ``text`` into chunks with the configured text splitter.

        Args:
            text (str): raw text to split.

        Returns:
            list: the chunks returned by ``self.text_splitter.split_text``.
        """
        return self.text_splitter.split_text(text)

    def get_metadata(self, texts: list) -> list:
        """Build one metadata dict per chunk, labelled by its position.

        Args:
            texts (list): the chunks, e.g. as returned by ``split_text``.

        Returns:
            list: ``[{"source": "Paragraphe: 0"}, {"source": "Paragraphe: 1"}, ...]``,
            one entry per element of ``texts``.
        """
        return [{"source": f"Paragraphe: {i}"} for i, _ in enumerate(texts)]

    def get_faiss_db(self, **load_kwargs):
        """Load the local FAISS vector store containing the document embeddings.

        Args:
            **load_kwargs: optional extra keyword arguments forwarded unchanged
                to ``FAISS.load_local`` (for example
                ``allow_dangerous_deserialization=True`` on langchain_community
                versions that require it).

        Returns:
            FAISS: the store loaded from ``self.db_faiss_path`` with
            ``self.embedding_model``.
        """
        # NOTE(review): FAISS.load_local deserializes the stored docstore with
        # pickle; only opt into deserialization for a store you created yourself.
        return FAISS.load_local(self.db_faiss_path, self.embedding_model, **load_kwargs)

    def set_custom_prompt(self, prompt_template: str):
        """Build the prompt template used by the Q&A retrieval chains.

        Args:
            prompt_template (str): template text with ``{placeholder}`` variables.

        Returns:
            PromptTemplate: the result of ``PromptTemplate.from_template``.
        """
        return PromptTemplate.from_template(template=prompt_template)

    def load_mistral(self):
        """Instantiate the Mistral chat model with this module's generation settings.

        Returns:
            ChatMistralAI: chat model configured from ``self.llm_model``,
            ``self.max_new_tokens``, ``self.top_p`` and ``self.temperature``.
        """
        model_kwargs = {
            "mistral_api_key": self.mistral_api_key,
            "model": self.llm_model,
            # ChatMistralAI's output-length field is ``max_tokens``; the former
            # ``max_new_tokens`` key is not a ChatMistralAI field and did not
            # set the limit.
            "max_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }
        return ChatMistralAI(**model_kwargs)

    def _build_qa_chain(self, db, prompt_template: str, k: int, memory=None):
        """Assemble a 'stuff' RetrievalQA chain over the top-``k`` hits of ``db``."""
        llm = self.load_mistral()
        chain_type_kwargs = {"prompt": self.set_custom_prompt(prompt_template)}
        if memory is not None:
            chain_type_kwargs["memory"] = memory
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def retrieval_qa_memory_chain(self, db, prompt_template):
        """Build a retrieval Q&A chain with conversation memory (top 5 documents).

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text; presumably it must reference the
                ``{history}`` and ``{question}`` variables used by the memory
                below (plus the retrieved-context variable) -- TODO confirm.

        Returns:
            RetrievalQA: chain that also returns its source documents.
        """
        memory = ConversationBufferMemory(
            memory_key="history",
            input_key="question",
        )
        return self._build_qa_chain(db, prompt_template, k=5, memory=memory)

    def retrieval_qa_chain(self, db, prompt_template):
        """Build a retrieval Q&A chain without memory (top 3 documents).

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text passed to ``set_custom_prompt``.

        Returns:
            RetrievalQA: chain that also returns its source documents.
        """
        # Previously called the nonexistent ``self.load_llm()`` (AttributeError).
        return self._build_qa_chain(db, prompt_template, k=3)

    def get_sources_document(self, source_documents: list) -> dict:
        """Group the pages of the RAG source documents by source path.

        Args:
            source_documents (list): documents whose ``metadata`` holds a
                ``"source"`` key and, optionally, a ``"page"`` key.

        Returns:
            dict: a ``defaultdict(list)`` such as
                {
                path/to/file1 : [0, 1, 3],
                path/to/file2 : [5, 2]
                }
            Sources without a ``"page"`` key map to an empty list.
        """
        sources = defaultdict(list)
        for doc in source_documents:
            pages = sources[doc.metadata["source"]]
            # Chunks labelled by get_metadata carry no "page" key: keep the
            # source but do not raise KeyError.
            if "page" in doc.metadata:
                pages.append(doc.metadata["page"])
        return sources

    def shape_answer_with_source(self, answer: str, sources: dict) -> str:
        """Append one "Fichier: <name> - Page: <pages>" line per source to ``answer``.

        Args:
            answer (str): the model's answer text.
            sources (dict): mapping of source path -> list of pages, as built
                by ``get_sources_document``.

        Returns:
            str: ``answer`` + "\\n" + one "\\nFichier: ..." line per source.
        """
        # Keep only the last "/"-separated component. Unlike the former regex,
        # this also handles sources with no directory part (e.g. the
        # "Paragraphe: N" labels from get_metadata) instead of raising IndexError.
        source_msg = "".join(
            f"\nFichier: {path.rsplit('/', 1)[-1]} - Page: {pages}"
            for path, pages in sources.items()
        )
        return answer + f"\n{source_msg}"