from typing import List, Optional, Tuple
import os

from langchain.docstore.document import Document as LangchainDocument
from langchain_community.vectorstores import FAISS
from litellm import completion
from ragatouille import RAGPretrainedModel
from rank_bm25 import BM25Okapi

import config
import retriver


class RAGAnswerGenerator:
    """Answers questions by retrieving, reranking, and passing context to an LLM."""

    def __init__(
        self,
        docs: List[LangchainDocument],
        bm25: BM25Okapi,
        knowledge_index: FAISS,
        reranker: Optional[RAGPretrainedModel] = None,
    ):
        self.bm25 = bm25
        self.knowledge_index = knowledge_index
        self.docs = docs
        self.reranker = reranker
        self.llm_key = os.environ["GROQ_API_KEY"]

    def retrieve_documents(
        self,
        question: str,
        num_retrieved_docs: int,
        bm_25_flag: bool,
        semantic_flag: bool,
    ) -> List[str]:
        """Retrieve candidate documents via BM25 and/or semantic search."""
        print("=> Retrieving documents...")
        if not (bm_25_flag or semantic_flag):
            # No retrieval requested: return an empty list instead of
            # falling through and implicitly returning None.
            return []

        result = retriver.search(
            self.docs,
            self.bm25,
            self.knowledge_index,
            question,
            use_bm25=bm_25_flag,
            use_semantic_search=semantic_flag,
            top_k=num_retrieved_docs,
        )
        if bm_25_flag and not semantic_flag:
            # BM25-only search already returns raw strings.
            return result
        # Semantic and hybrid search return LangchainDocument objects.
        return [doc.page_content for doc in result]

    def rerank_documents(self, question: str, documents: List[str], num_docs_final: int) -> List[str]:
        """Rerank candidates with the ColBERT reranker when one is configured."""
        if self.reranker and documents:
            print("=> Reranking documents...")
            reranked_docs = self.reranker.rerank(question, documents, k=num_docs_final)
            return [doc["content"] for doc in reranked_docs]
        # Without a reranker, simply truncate to the requested count.
        return documents[:num_docs_final]

    def format_context(self, documents: List[str]) -> str:
        """Number the documents so the LLM can cite them by index."""
        if not documents:
            return "No retrieved documents available."
        return "\n".join(f"[{i + 1}] {doc}" for i, doc in enumerate(documents))

    def generate_answer(self, question: str, context: str, temperature: float) -> str:
        """Call the LLM, with retrieved context when available."""
        print("=> Generating answer...")
        if context.strip() == "No retrieved documents available.":
            # Fall back to a plain LLM answer when retrieval produced nothing.
            messages = [
                {"role": "system", "content": config.LLM_ONLY_PROMPT},
                {"role": "user", "content": f"Question: {question}"},
            ]
        else:
            messages = [
                {"role": "system", "content": config.RAG_PROMPT},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"},
            ]
        response = completion(
            model="groq/llama3-8b-8192",
            messages=messages,
            api_key=self.llm_key,
            temperature=temperature,
        )
        try:
            return response.choices[0].message.content
        except (AttributeError, IndexError):
            return "No response content found"

    def answer(
        self,
        question: str,
        temperature: float,
        num_retrieved_docs: int = 30,
        num_docs_final: int = 5,
        bm_25_flag: bool = True,
        semantic_flag: bool = True,
    ) -> Tuple[str, List[str]]:
        """Full pipeline: retrieve, rerank, format context, generate an answer."""
        relevant_docs = self.retrieve_documents(question, num_retrieved_docs, bm_25_flag, semantic_flag)
        print(f"=> Retrieved {len(relevant_docs)} documents")
        relevant_docs = self.rerank_documents(question, relevant_docs, num_docs_final)
        print(f"=> Kept {len(relevant_docs)} documents after reranking")
        context = self.format_context(relevant_docs)
        answer = self.generate_answer(question, context, temperature)
        document_list = [f"[{i + 1}] {doc}" for i, doc in enumerate(relevant_docs)]
        return answer, document_list
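

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal way to exercise
# RAGAnswerGenerator end to end. The tiny corpus, the MiniLM embedding model,
# and the ColBERT checkpoint below are illustrative choices only; the sketch
# also assumes GROQ_API_KEY is set in the environment and that this project's
# retriver.search accepts the objects built here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_community.embeddings import HuggingFaceEmbeddings

    texts = [
        "FAISS is a library for efficient similarity search over dense vectors.",
        "BM25 is a bag-of-words ranking function based on term frequencies.",
        "ColBERT reranks candidates with late interaction over token embeddings.",
    ]
    docs = [LangchainDocument(page_content=t) for t in texts]

    # BM25Okapi expects a pre-tokenized corpus.
    bm25 = BM25Okapi([t.split() for t in texts])

    # Any embedding model works here; MiniLM is a small, common default.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    knowledge_index = FAISS.from_documents(docs, embeddings)

    # Optional ColBERT reranker; answer() degrades gracefully without one.
    reranker = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

    generator = RAGAnswerGenerator(docs, bm25, knowledge_index, reranker=reranker)
    answer_text, sources = generator.answer("What is BM25?", temperature=0.2)
    print(answer_text)
    for source in sources:
        print(source)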