from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.retrievers import BM25Retriever
import tiktoken
# ONLY USE WITH DOCKER (uncomment there, comment out otherwise): swaps the stdlib sqlite3 for pysqlite3, as Chroma requires a newer SQLite version than many base images provide
import pysqlite3
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
import chromadb
import chainlit as cl
from langdetect import detect
from langchain_community.vectorstores import Chroma
from typing import List, Tuple
import os
import json
from langchain.docstore.document import Document
from prompts import get_system_prompt, get_human_prompt, get_system_prompt_template, get_full_prompt
class HelperMethods:
"""
Helper class with all important methods for the RAG pipeline.
"""
def __init__(self):
pass
def _get_embedding_model(self):
"""
Gets the finetuned embedding model based on bge-large-en-v1.5
"""
path = "Basti8499/bge-large-en-v1.5-ISO-27001"
model = HuggingFaceEmbeddings(model_name=path)
return model
async def get_LLM(self):
"""
        Initializes the gpt-3.5-turbo LLM (16,385-token context window)
"""
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0, max_tokens=680, streaming=True, api_key=cl.user_session.get("env")["OPENAI_API_KEY"])
max_context_size = 16385
return llm, max_context_size
def get_index_vector_db(self, collection_name: str):
"""
        Gets the vector database index for the given collection name, if it exists.
"""
new_client = chromadb.PersistentClient(path=os.environ.get("CHROMA_PATH"))
# Check if collection already exists
collection_exists = True
try:
new_client.get_collection(collection_name)
except ValueError as e:
collection_exists = False
if not collection_exists:
raise Exception("Error, raised exception: Collection does not exist.")
else:
embedding_model = self._get_embedding_model()
vectordb = Chroma(client=new_client, collection_name=collection_name, embedding_function=embedding_model)
return vectordb
    def _load_documents(self, file_path: str) -> List[Document]:
        """
        Loads the documents for the sparse index from a JSONL file, one Document per line.
        """
documents = []
with open(file_path, "r") as jsonl_file:
for line in jsonl_file:
data = json.loads(line)
obj = Document(**data)
documents.append(obj)
return documents
def check_if_english(self, query: str) -> bool:
"""
Uses the langdetect library based on Google's language-detection library to check which language the query is in.
Returns True if it is English.
"""
language = detect(query)
return language == "en"
    def check_if_relevant(self, docs: List[Document]) -> bool:
        """
        Checks whether the retrieved documents are relevant by averaging the Cohere relevance scores and requiring the mean to exceed 0.75.
        """
relevance_scores = [doc.metadata["relevance_score"] for doc in docs]
avg_score = sum(relevance_scores) / len(relevance_scores)
return avg_score > 0.75
def retrieve_contexts(self, vectordb, query: str, k: int = 8, rerank_k: int = 50, dense_percent: float = 0.5) -> List[Document]:
"""
Retrieves the documents from the vector database by using a hybrid approach (dense (similarity search) + sparse (BM25)) and the Cohere re-ranking endpoint.
"""
dense_k = int(rerank_k * dense_percent)
sparse_k = rerank_k - dense_k
# Sparse Retrieval
        sparse_documents = self._load_documents("./../sparse_index/sparse_1536_264")
bm25_retriever = BM25Retriever.from_documents(sparse_documents)
bm25_retriever.k = sparse_k
result_documents_BM25 = bm25_retriever.get_relevant_documents(query)
# Dense Retrieval
result_documents_Dense = vectordb.similarity_search(query, k=dense_k)
result_documents_all = []
result_documents_all.extend(result_documents_BM25)
result_documents_all.extend(result_documents_Dense)
# Only get unique documents and remove duplicates that were retrieved in both sparse and dense
unique_documents_dict = {}
for doc in result_documents_all:
if doc.page_content not in unique_documents_dict:
unique_documents_dict[doc.page_content] = doc
result_documents_unique = list(unique_documents_dict.values())
# Re-ranking with Cohere
compressor = CohereRerank(top_n=k, user_agent="langchain", cohere_api_key=cl.user_session.get("env")["COHERE_API_KEY"])
result_documents = compressor.compress_documents(documents=result_documents_unique, query=query)
return result_documents
def is_context_size_valid(self, contexts: List[Document], query: str, max_context_size: int) -> bool:
"""
        Checks if the context size is valid with the cl100k tokenizer, which is used for OpenAI LLMs.
"""
# Transform List[Document] into List[str]
concatenated_contexts = ""
for index, document in enumerate(contexts):
original_text = document.metadata.get("original_text", "")
# Replace curly brackets, as otherwise problems can be encountered with formatting the prompt
original_text = original_text.replace("{", "").replace("}", "")
concatenated_contexts += f"{index+1}. {original_text}\n\n"
if not query.endswith("?"):
query = query + "?"
# Get the prompts
system_str, system_prompt = get_system_prompt()
human_str, human_prompt = get_human_prompt(concatenated_contexts, query)
full_prompt = system_str + "\n" + human_str
# Count token length
tokenizer = tiktoken.get_encoding("cl100k_base")
token_length = len(tokenizer.encode(full_prompt))
        return token_length <= max_context_size
def get_full_prompt_sources_and_template(self, contexts: List[Document], llm, prompt: str) -> Tuple[str, str, str, str]:
        # Check if the query is aimed at a template and whether the context documents also contain a template.
        # If it is a template question, the query and the system prompt have to be altered.
        # Only check the first two documents, because beyond that the re-ranked score is not high enough to assume the retrieved template is valid for the question.
is_template_question = False
template_path = ""
template_source = ""
if "template" in prompt.lower():
for context in contexts[:2]:
if "template_path" in context.metadata:
is_template_question = True
template_path = context.metadata["template_path"]
template_source = context.metadata["source"]
break
# Concatenate all document texts and sources
concatenated_contexts = ""
concatenated_sources = ""
seen_sources = set()
        # For template questions only the top two contexts are used
        docs_for_context = contexts[:2] if is_template_question else contexts
        for index, document in enumerate(docs_for_context):
            original_text = document.metadata.get("original_text", "")
            # Replace curly brackets, as otherwise problems can be encountered with formatting the prompt
            original_text = original_text.replace("{", "").replace("}", "")
            concatenated_contexts += f"{index+1}. {original_text}\n\n"
            source = document.metadata.get("source", "")
            if source not in seen_sources:
                concatenated_sources += f"{len(seen_sources) + 1}. {source}\n"
                seen_sources.add(source)
# Check if question mark is at the end of the prompt
if not prompt.endswith("?"):
prompt = prompt + "?"
if is_template_question:
system_str, system_prompt = get_system_prompt_template()
human_str, human_prompt = get_human_prompt(concatenated_contexts, prompt)
full_prompt = get_full_prompt(system_prompt, human_prompt)
#answer = llm.invoke(full_prompt).content
else:
system_str, system_prompt = get_system_prompt()
human_str, human_prompt = get_human_prompt(concatenated_contexts, prompt)
full_prompt = get_full_prompt(system_prompt, human_prompt)
#answer = llm.invoke(full_prompt).content
return full_prompt, concatenated_sources, template_path, template_source
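

# --- Illustrative usage sketch (not part of the original pipeline) ---
# A minimal sketch of how the helper methods above are intended to be chained. It assumes
# the CHROMA_PATH environment variable points at an existing Chroma store, that a collection
# named "ISO-27001" exists (hypothetical name), and that the OpenAI/Cohere API keys are
# reachable through the Chainlit user session, which is what retrieve_contexts() expects;
# outside a running Chainlit app those lookups would need to be adapted.
if __name__ == "__main__":
    helper = HelperMethods()
    query = "Which controls does ISO 27001 require for access management?"
    if helper.check_if_english(query):
        vectordb = helper.get_index_vector_db("ISO-27001")  # hypothetical collection name
        contexts = helper.retrieve_contexts(vectordb, query, k=8)
        if helper.check_if_relevant(contexts):
            # llm is currently unused inside get_full_prompt_sources_and_template, so None is passed here
            full_prompt, sources, template_path, template_source = helper.get_full_prompt_sources_and_template(contexts, llm=None, prompt=query)
            print(full_prompt)
            print("Sources:\n" + sources)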