|
from typing import Any, AsyncGenerator |
|
from app.core.models import LocalLLM, Embedder, Reranker, GeminiLLM, GeminiEmbed, Wrapper |
|
from app.core.processor import DocumentProcessor |
|
from app.core.database import VectorDatabase |
|
import time |
|
import os |
|
from app.settings import settings, BASE_DIR |
|
|
|
|
|
class RagSystem: |
|
def __init__(self): |
|
self.embedder = ( |
|
GeminiEmbed() |
|
if settings.use_gemini |
|
else Embedder(model=settings.models.embedder_model) |
|
) |
|
self.reranker = Reranker(model=settings.models.reranker_model) |
|
self.processor = DocumentProcessor(self.embedder) |
|
self.db = VectorDatabase(embedder=self.embedder) |
|
self.llm = GeminiLLM() if settings.use_gemini else LocalLLM() |
|
self.wrapper = Wrapper() |
|
|
|
""" |
|
Provides a prompt with substituted context from chunks |
|
|
|
TODO: add template to prompt without docs |
|
""" |
|
|
|
def get_general_prompt(self, user_prompt: str, collection_name: str) -> str: |
|
enhanced_prompt = self.enhance_prompt(user_prompt.strip()) |
|
|
|
relevant_chunks = self.db.search(collection_name, query=enhanced_prompt, top_k=30) |
|
if relevant_chunks is not None and len(relevant_chunks) > 0: |
|
ranks = self.reranker.rank(query=enhanced_prompt, chunks=relevant_chunks)[ |
|
: min(5, len(relevant_chunks)) |
|
] |
|
relevant_chunks = [relevant_chunks[rank["corpus_id"]] for rank in ranks] |
|
else: |
|
relevant_chunks = [] |
|
|
|
sources = "" |
|
prompt = "" |
|
|
|
for chunk in relevant_chunks: |
|
citation = ( |
|
f"[Source: {chunk.filename}, " |
|
f"Page: {chunk.page_number}, " |
|
f"Lines: {chunk.start_line}-{chunk.end_line}, " |
|
f"Start: {chunk.start_index}]\n\n" |
|
) |
|
sources += f"Original text:\n{chunk.get_raw_text()}\nCitation:{citation}" |
|
|
|
with open( |
|
os.path.join(BASE_DIR, "app", "prompt_templates", "test2.txt") |
|
) as prompt_file: |
|
prompt = prompt_file.read() |
|
|
|
prompt += ( |
|
"**QUESTION**: " |
|
f"{enhanced_prompt}\n" |
|
"**CONTEXT DOCUMENTS**:\n" |
|
f"{sources}\n" |
|
) |
|
print(prompt) |
|
return prompt |
|
|
|
def enhance_prompt(self, original_prompt: str) -> str: |
|
path_to_wrapping_prompt = os.path.join(BASE_DIR, "app", "prompt_templates", "wrapper.txt") |
|
enhanced_prompt = "" |
|
with open(path_to_wrapping_prompt, "r") as f: |
|
enhanced_prompt = f.read().replace("[USERS_PROMPT]", original_prompt) |
|
return self.wrapper.wrap(enhanced_prompt) |
|
|
|
""" |
|
Splits the list of documents into groups with 'split_by' docs (done to avoid qdrant_client connection error handling), loads them, |
|
splits into chunks, and saves to db |
|
""" |
|
|
|
def upload_documents( |
|
self, |
|
collection_name: str, |
|
documents: list[str], |
|
split_by: int = 3, |
|
debug_mode: bool = True, |
|
) -> None: |
|
|
|
for i in range(0, len(documents), split_by): |
|
|
|
if debug_mode: |
|
print( |
|
"<" |
|
+ "-" * 10 |
|
+ "New document group is taken into processing" |
|
+ "-" * 10 |
|
+ ">" |
|
) |
|
|
|
docs = documents[i : i + split_by] |
|
|
|
loading_time = 0 |
|
chunk_generating_time = 0 |
|
db_saving_time = 0 |
|
|
|
print("Start loading the documents") |
|
start = time.time() |
|
self.processor.load_documents(documents=docs, add_to_unprocessed=True) |
|
loading_time = time.time() - start |
|
|
|
print("Start loading chunk generation") |
|
start = time.time() |
|
self.processor.generate_chunks() |
|
chunk_generating_time = time.time() - start |
|
|
|
print("Start saving to db") |
|
start = time.time() |
|
self.db.store(collection_name, self.processor.get_and_save_unsaved_chunks()) |
|
db_saving_time = time.time() - start |
|
|
|
if debug_mode: |
|
print( |
|
f"loading time = {loading_time}, chunk generation time = {chunk_generating_time}, saving time = {db_saving_time}\n" |
|
) |
|
|
|
def extract_text(self, response) -> str: |
|
text = "" |
|
try: |
|
text = response.candidates[0].content.parts[0].text |
|
except Exception as e: |
|
print(e) |
|
return text |
|
|
|
""" |
|
Produces answer to user's request. First, finds the most relevant chunks, generates prompt with them, and asks llm |
|
""" |
|
|
|
async def generate_response( |
|
self, collection_name: str, user_prompt: str, stream: bool = True |
|
) -> str: |
|
general_prompt = self.get_general_prompt( |
|
user_prompt=user_prompt, collection_name=collection_name |
|
) |
|
|
|
return self.llm.get_response(prompt=general_prompt) |
|
|
|
async def generate_response_stream( |
|
self, collection_name: str, user_prompt: str, stream: bool = True |
|
) -> AsyncGenerator[Any, Any]: |
|
general_prompt = self.get_general_prompt( |
|
user_prompt=user_prompt, collection_name=collection_name |
|
) |
|
|
|
async for chunk in self.llm.get_streaming_response( |
|
prompt=general_prompt, stream=True |
|
): |
|
yield self.extract_text(chunk) |
|
|
|
""" |
|
Produces the list of the most relevant chunkВs |
|
""" |
|
|
|
def get_relevant_chunks(self, collection_name: str, query): |
|
relevant_chunks = self.db.search(collection_name, query=query, top_k=15) |
|
relevant_chunks = [ |
|
relevant_chunks[ranked["corpus_id"]] |
|
for ranked in self.reranker.rank(query=query, chunks=relevant_chunks) |
|
] |
|
return relevant_chunks |
|
|
|
def create_new_collection(self, collection_name: str) -> None: |
|
self.db.create_collection(collection_name) |
|
|
|
def get_collections_names(self) -> list[str]: |
|
return self.db.get_collections() |
|
|