from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.retrievers import BM25Retriever
import tiktoken

# Only needed when running inside Docker: swap the system sqlite3 for pysqlite3 (comment out otherwise)
import pysqlite3
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import chromadb
import chainlit as cl
from langdetect import detect
from langchain_community.vectorstores import Chroma
from typing import List
from typing import Tuple
import os
import json
from langchain.docstore.document import Document
from prompts import get_system_prompt, get_human_prompt, get_system_prompt_template, get_full_prompt

class HelperMethods:
    """
    Helper class bundling the core methods of the RAG pipeline (retrieval, re-ranking and prompt construction).
    """

    def __init__(self):
        pass

    def _get_embedding_model(self):
        """
        Loads the fine-tuned embedding model based on bge-large-en-v1.5.
        """
        path = "Basti8499/bge-large-en-v1.5-ISO-27001"
        model = HuggingFaceEmbeddings(model_name=path)
        return model


    async def get_LLM(self):
        """
        Initializes the gpt-3.5-turbo 16k LLM (gpt-3.5-turbo-0125).
        """
        llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0, max_tokens=680, streaming=True, api_key=cl.user_session.get("env")["OPENAI_API_KEY"])
        max_context_size = 16385
        return llm, max_context_size


    def get_index_vector_db(self, collection_name: str):
        """
        Gets the vector database index for the given collection name, if it exists.
        """
        new_client = chromadb.PersistentClient(path=os.environ.get("CHROMA_PATH"))

        # Check if collection already exists
        collection_exists = True
        try:
            new_client.get_collection(collection_name)
        except ValueError:
            collection_exists = False

        if not collection_exists:
            raise Exception("Error, raised exception: Collection does not exist.")
        else:
            embedding_model = self._get_embedding_model()
            vectordb = Chroma(client=new_client, collection_name=collection_name, embedding_function=embedding_model)

            return vectordb


    def _load_documents(self, file_path: str) -> List[Document]:
        """
        Loads the documents for the sparse index from a JSONL file in which each line is a serialized Document.
        """
        documents = []
        with open(file_path, "r") as jsonl_file:
            for line in jsonl_file:
                data = json.loads(line)
                obj = Document(**data)
                documents.append(obj)
        return documents

    def check_if_english(self, query: str) -> bool:
        """
        Uses the langdetect library (a port of Google's language-detection library) to check which language the query is in.
        Returns True if it is English.
        """
        language = detect(query)
        return language == "en"
        
    def check_if_relevant(self, docs: List[Document]) -> bool:
        """
        Checks whether the retrieved documents are relevant to the query by averaging their Cohere
        relevance scores and comparing the average against a threshold of 0.75.
        """
        relevance_scores = [doc.metadata["relevance_score"] for doc in docs]
        avg_score = sum(relevance_scores) / len(relevance_scores)
        return avg_score > 0.75

    def retrieve_contexts(self, vectordb, query: str, k: int = 8, rerank_k: int = 50, dense_percent: float = 0.5) -> List[Document]:
        """
        Retrieves the documents from the vector database by using a hybrid approach (dense (similarity search) + sparse (BM25)) and the Cohere re-ranking endpoint.
        """
        dense_k = int(rerank_k * dense_percent)
        sparse_k = rerank_k - dense_k

        # Sparse Retrieval
        sparse_documents = self._load_documents("./../sparse_index/sparse_1536_264")
        bm25_retriever = BM25Retriever.from_documents(sparse_documents)
        bm25_retriever.k = sparse_k
        result_documents_BM25 = bm25_retriever.get_relevant_documents(query)

        # Dense Retrieval
        result_documents_Dense = vectordb.similarity_search(query, k=dense_k)

        result_documents_all = []
        result_documents_all.extend(result_documents_BM25)
        result_documents_all.extend(result_documents_Dense)

        # Only get unique documents and remove duplicates that were retrieved in both sparse and dense
        unique_documents_dict = {}
        for doc in result_documents_all:
            if doc.page_content not in unique_documents_dict:
                unique_documents_dict[doc.page_content] = doc
        result_documents_unique = list(unique_documents_dict.values())

        # Re-ranking with Cohere
        compressor = CohereRerank(top_n=k, user_agent="langchain", cohere_api_key=cl.user_session.get("env")["COHERE_API_KEY"])
        result_documents = compressor.compress_documents(documents=result_documents_unique, query=query)

        return result_documents
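
    # Worked example (assuming the default arguments): with rerank_k=50 and dense_percent=0.5,
    # dense_k = int(50 * 0.5) = 25 and sparse_k = 50 - 25 = 25, so 25 BM25 hits and 25 similarity-search
    # hits are merged, deduplicated on page_content, and re-ranked by Cohere down to the final k=8 documents.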


    def is_context_size_valid(self, contexts: List[Document], query: str, max_context_size: int) -> bool:
        """
        Checks if the context size is valid with the cl100k_base tokenizer, which is used for OpenAI LLMs.
        """
        # Concatenate the original document texts into a single context string
        concatenated_contexts = ""
        for index, document in enumerate(contexts):
            original_text = document.metadata.get("original_text", "")
            # Replace curly brackets, as otherwise problems can be encountered with formatting the prompt
            original_text = original_text.replace("{", "").replace("}", "")
            concatenated_contexts += f"{index+1}. {original_text}\n\n"

        if not query.endswith("?"):
            query = query + "?"

        # Get the prompts
        system_str, system_prompt = get_system_prompt()
        human_str, human_prompt = get_human_prompt(concatenated_contexts, query)
        full_prompt = system_str + "\n" + human_str

        # Count token length
        tokenizer = tiktoken.get_encoding("cl100k_base")
        token_length = len(tokenizer.encode(full_prompt))

        return token_length <= max_context_size
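
    # Worked example (illustrative numbers): a full prompt that encodes to ~3,000 cl100k_base tokens fits
    # well within the 16,385-token window of gpt-3.5-turbo-0125, so the check passes; a longer prompt that
    # exceeds the limit fails the check and would need to be shortened (e.g. by dropping contexts) first.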

    def get_full_prompt_sources_and_template(self, contexts: List[Document], llm, prompt: str) -> Tuple[str, str, str, str]:
        """
        Builds the full prompt from the retrieved contexts and returns it together with the concatenated
        sources and, for template questions, the path and source of the retrieved template.
        """
        # Check whether the query asks for a template and whether the context documents actually contain one.
        # If it is a template question, the query and the system prompt have to be altered.
        # Only the first two documents are checked, because for lower-ranked documents the re-ranked score is
        # not high enough to assume that the retrieved template is valid for the question.
        is_template_question = False
        template_path = ""
        template_source = ""
        if "template" in prompt.lower():
            for context in contexts[:2]:
                if "template_path" in context.metadata:
                    is_template_question = True
                    template_path = context.metadata["template_path"]
                    template_source = context.metadata["source"]
                    break

        # Concatenate all document texts and sources.
        # For template questions only the two top-ranked contexts are used.
        relevant_contexts = contexts[:2] if is_template_question else contexts

        concatenated_contexts = ""
        concatenated_sources = ""
        seen_sources = set()
        for index, document in enumerate(relevant_contexts):
            original_text = document.metadata.get("original_text", "")
            # Replace curly brackets, as otherwise problems can be encountered when formatting the prompt
            original_text = original_text.replace("{", "").replace("}", "")
            concatenated_contexts += f"{index+1}. {original_text}\n\n"

            source = document.metadata.get("source", "")
            if source not in seen_sources:
                concatenated_sources += f"{len(seen_sources) + 1}. {source}\n"
                seen_sources.add(source)

        # Check if question mark is at the end of the prompt
        if not prompt.endswith("?"):
            prompt = prompt + "?"

        if is_template_question:
            system_str, system_prompt = get_system_prompt_template()
            human_str, human_prompt = get_human_prompt(concatenated_contexts, prompt)
            full_prompt = get_full_prompt(system_prompt, human_prompt)
            #answer = llm.invoke(full_prompt).content
        else:
            system_str, system_prompt = get_system_prompt()
            human_str, human_prompt = get_human_prompt(concatenated_contexts, prompt)
            full_prompt = get_full_prompt(system_prompt, human_prompt)
            #answer = llm.invoke(full_prompt).content
        
        return full_prompt, concatenated_sources, template_path, template_source
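
# Example usage (a minimal sketch of how the helpers are chained in the Chainlit app). It assumes that
# CHROMA_PATH points at a populated Chroma store, that the collection name below exists (the name here is
# hypothetical), and that the OpenAI/Cohere API keys are available via cl.user_session, i.e. that the code
# runs inside a Chainlit session rather than as a standalone script.
#
#     helpers = HelperMethods()
#     vectordb = helpers.get_index_vector_db("iso-27001-collection")  # hypothetical collection name
#     llm, max_context_size = await helpers.get_LLM()
#
#     query = "Which controls does ISO 27001 require for access management?"
#     if helpers.check_if_english(query):
#         contexts = helpers.retrieve_contexts(vectordb, query)
#         if helpers.check_if_relevant(contexts) and helpers.is_context_size_valid(contexts, query, max_context_size):
#             full_prompt, sources, template_path, template_source = helpers.get_full_prompt_sources_and_template(contexts, llm, query)
#             # full_prompt can then be sent/streamed to the LLM, e.g. via llm.astream(full_prompt)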