from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.retrievers import BM25Retriever
import tiktoken

# ONLY USE WITH DOCKER, then uncomment
import pysqlite3
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import chromadb
import chainlit as cl
from langdetect import detect
from langchain_community.vectorstores import Chroma
from typing import List, Tuple
import os
import json
from langchain.docstore.document import Document
from prompts import get_system_prompt, get_human_prompt, get_system_prompt_template, get_full_prompt


class HelperMethods:
    """
    Helper class with all important methods for the RAG pipeline.
    """

    def __init__(self):
        pass

    def _get_embedding_model(self):
        """
        Loads the fine-tuned embedding model based on bge-large-en-v1.5.
        """
        path = "Basti8499/bge-large-en-v1.5-ISO-27001"
        model = HuggingFaceEmbeddings(model_name=path)
        return model

    async def get_LLM(self):
        """
        Initializes the gpt-3.5-turbo LLM with a 16k context window.
        """
        llm = ChatOpenAI(
            model_name="gpt-3.5-turbo-0125",
            temperature=0,
            max_tokens=680,
            streaming=True,
            api_key=cl.user_session.get("env")["OPENAI_API_KEY"],
        )
        max_context_size = 16385
        return llm, max_context_size

    def get_index_vector_db(self, collection_name: str):
        """
        Gets the index vector database for the given collection name, if it exists.
        """
        new_client = chromadb.PersistentClient(path=os.environ.get("CHROMA_PATH"))

        # Check if the collection already exists
        collection_exists = True
        try:
            new_client.get_collection(collection_name)
        except ValueError:
            collection_exists = False

        if not collection_exists:
            raise Exception("Error, raised exception: Collection does not exist.")
        else:
            embedding_model = self._get_embedding_model()
            vectordb = Chroma(client=new_client, collection_name=collection_name, embedding_function=embedding_model)
            return vectordb

    def _load_documents(self, file_path: str) -> List[Document]:
        """
        Loads documents from a JSONL file in which each line is a serialized Document.
        """
        documents = []
        with open(file_path, "r") as jsonl_file:
            for line in jsonl_file:
                data = json.loads(line)
                obj = Document(**data)
                documents.append(obj)
        return documents

    def check_if_english(self, query: str) -> bool:
        """
        Uses the langdetect library (a port of Google's language-detection library) to check which language the query is in.
        Returns True if it is English.
        """
        language = detect(query)
        return language == "en"

    def check_if_relevant(self, docs: List[Document]) -> bool:
        """
        Checks whether the re-ranked documents are relevant on average (Cohere relevance_score > 0.75).
        """
        relevance_scores = [doc.metadata["relevance_score"] for doc in docs]
        avg_score = sum(relevance_scores) / len(relevance_scores)
        return avg_score > 0.75
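
    # Illustrative example for check_if_relevant: Cohere relevance scores lie in [0, 1],
    # so e.g. scores of [0.9, 0.8, 0.7] (average 0.8) pass the 0.75 cut-off,
    # while [0.5, 0.4, 0.3] (average 0.4) do not.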

    def retrieve_contexts(self, vectordb, query: str, k: int = 8, rerank_k: int = 50, dense_percent: float = 0.5) -> List[Document]:
        """
        Retrieves the documents from the vector database using a hybrid approach (dense similarity search + sparse BM25) and the Cohere re-ranking endpoint.
        """
        dense_k = int(rerank_k * dense_percent)
        sparse_k = rerank_k - dense_k

        # Sparse retrieval
        sparse_documents = self._load_documents("./../sparse_index/sparse_1536_264")
        bm25_retriever = BM25Retriever.from_documents(sparse_documents)
        bm25_retriever.k = sparse_k
        result_documents_BM25 = bm25_retriever.get_relevant_documents(query)

        # Dense retrieval
        result_documents_Dense = vectordb.similarity_search(query, k=dense_k)

        result_documents_all = []
        result_documents_all.extend(result_documents_BM25)
        result_documents_all.extend(result_documents_Dense)

        # Only keep unique documents and remove duplicates that were retrieved by both sparse and dense retrieval
        unique_documents_dict = {}
        for doc in result_documents_all:
            if doc.page_content not in unique_documents_dict:
                unique_documents_dict[doc.page_content] = doc
        result_documents_unique = list(unique_documents_dict.values())

        # Re-ranking with Cohere
        compressor = CohereRerank(top_n=k, user_agent="langchain", cohere_api_key=cl.user_session.get("env")["COHERE_API_KEY"])
        result_documents = compressor.compress_documents(documents=result_documents_unique, query=query)

        return result_documents
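
    # Illustrative defaults for retrieve_contexts: with rerank_k=50 and dense_percent=0.5,
    # 25 candidates come from dense similarity search and 25 from BM25; after
    # deduplication, Cohere re-ranks the pool and keeps only the top k=8 documents.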

    def is_context_size_valid(self, contexts: List[Document], query: str, max_context_size: int) -> bool:
        """
        Checks if the context size is valid with the cl100k tokenizer, which is used for OpenAI LLMs.
        """
        # Transform List[Document] into a single concatenated string
        concatenated_contexts = ""
        for index, document in enumerate(contexts):
            original_text = document.metadata.get("original_text", "")
            # Replace curly brackets, as otherwise problems can be encountered when formatting the prompt
            original_text = original_text.replace("{", "").replace("}", "")
            concatenated_contexts += f"{index+1}. {original_text}\n\n"

        if not query.endswith("?"):
            query = query + "?"

        # Get the prompts
        system_str, system_prompt = get_system_prompt()
        human_str, human_prompt = get_human_prompt(concatenated_contexts, query)
        full_prompt = system_str + "\n" + human_str

        # Count the token length
        tokenizer = tiktoken.get_encoding("cl100k_base")
        token_length = len(tokenizer.encode(full_prompt))

        return token_length <= max_context_size

    def get_full_prompt_sources_and_template(self, contexts: List[Document], llm, prompt: str) -> Tuple[str, str, str, str]:
        """
        Builds the full prompt and the concatenated sources, and returns the template path and source if the question asks for a template.
        """
        # Check if the query is aimed at a template and whether the context documents also contain a template.
        # If it is a template question, the query and the system prompt have to be altered.
        # Only check the first two documents, because otherwise the re-ranked score is not high enough to assume
        # that the retrieved template is valid for that question.
        is_template_question = False
        template_path = ""
        template_source = ""
        if "template" in prompt.lower():
            for context in contexts[:2]:
                if "template_path" in context.metadata:
                    is_template_question = True
                    template_path = context.metadata["template_path"]
                    template_source = context.metadata["source"]
                    break

        # Concatenate all document texts and sources (only the first two documents for template questions)
        relevant_contexts = contexts[:2] if is_template_question else contexts
        concatenated_contexts = ""
        concatenated_sources = ""
        seen_sources = set()
        for index, document in enumerate(relevant_contexts):
            original_text = document.metadata.get("original_text", "")
            # Replace curly brackets, as otherwise problems can be encountered when formatting the prompt
            original_text = original_text.replace("{", "").replace("}", "")
            concatenated_contexts += f"{index+1}. {original_text}\n\n"

            source = document.metadata.get("source", "")
            if source not in seen_sources:
                concatenated_sources += f"{len(seen_sources) + 1}. {source}\n"
                seen_sources.add(source)

        # Check if a question mark is at the end of the prompt
        if not prompt.endswith("?"):
            prompt = prompt + "?"

        if is_template_question:
            system_str, system_prompt = get_system_prompt_template()
        else:
            system_str, system_prompt = get_system_prompt()
        human_str, human_prompt = get_human_prompt(concatenated_contexts, prompt)
        full_prompt = get_full_prompt(system_prompt, human_prompt)
        # answer = llm.invoke(full_prompt).content

        return full_prompt, concatenated_sources, template_path, template_source
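

# Illustrative end-to-end sketch of how these helpers are meant to be chained together.
# Assumptions: this coroutine runs inside a Chainlit session (so the OpenAI and Cohere
# API keys are available via cl.user_session), CHROMA_PATH points to a populated Chroma
# directory, and the collection name "iso_27001_collection" is only a placeholder.
async def _example_pipeline(query: str):
    helper = HelperMethods()
    vectordb = helper.get_index_vector_db("iso_27001_collection")

    # Only English queries are supported by the prompts
    if not helper.check_if_english(query):
        return None

    # Hybrid retrieval followed by Cohere re-ranking
    contexts = helper.retrieve_contexts(vectordb, query, k=8, rerank_k=50, dense_percent=0.5)
    if not helper.check_if_relevant(contexts):
        return None

    # Make sure the assembled prompt fits into the model's context window
    llm, max_context_size = await helper.get_LLM()
    if not helper.is_context_size_valid(contexts, query, max_context_size):
        return None

    full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(contexts, llm, query)
    answer = llm.invoke(full_prompt).content
    return answer, sources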