TalkToYourDocument / retriever /llm_manager.py
gourisankar85's picture
Upload 4 files
010d51d verified
raw
history blame contribute delete
13.8 kB
import logging
import os
from typing import List, Dict, Any, Tuple
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
class LLMManager:
DEFAULT_MODEL = "gemma2-9b-it" # Set the default model name
def __init__(self):
self.generation_llm = None
logging.info("LLMManager initialized")
# Initialize the default model during construction
try:
self.initialize_generation_llm(self.DEFAULT_MODEL)
logging.info(f"Initialized default LLM model: {self.DEFAULT_MODEL}")
except ValueError as e:
logging.error(f"Failed to initialize default LLM model: {str(e)}")
def initialize_generation_llm(self, model_name: str) -> None:
"""
Initialize the generation LLM using the Groq API.
Args:
model_name (str): The name of the model to use for generation.
Raises:
ValueError: If GROQ_API_KEY is not set.
"""
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
raise ValueError("GROQ_API_KEY is not set. Please add it in your environment variables.")
os.environ["GROQ_API_KEY"] = api_key
self.generation_llm = ChatGroq(model=model_name, temperature=0.7)
self.generation_llm.name = model_name
logging.info(f"Generation LLM {model_name} initialized")
def reinitialize_llm(self, model_name: str) -> str:
"""
Reinitialize the LLM with a new model name.
Args:
model_name (str): The name of the new model to initialize.
Returns:
str: Status message indicating success or failure.
"""
try:
self.initialize_generation_llm(model_name)
return f"LLM model changed to {model_name}"
except ValueError as e:
logging.error(f"Failed to reinitialize LLM with model {model_name}: {str(e)}")
return f"Error: Failed to change LLM model: {str(e)}"
def generate_response(self, question: str, relevant_docs: List[Dict[str, Any]]) -> Tuple[str, List[Document]]:
"""
Generate a response using the generation LLM based on the question and relevant documents.
Args:
question (str): The user's query.
relevant_docs (List[Dict[str, Any]]): List of relevant document chunks with text, metadata, and scores.
Returns:
Tuple[str, List[Document]]: The LLM's response and the source documents used.
Raises:
ValueError: If the generation LLM is not initialized.
Exception: If there's an error during the QA chain invocation.
"""
if not self.generation_llm:
raise ValueError("Generation LLM is not initialized. Call initialize_generation_llm first.")
# Convert the relevant documents into LangChain Document objects
documents = [
Document(page_content=doc['text'], metadata=doc['metadata'])
for doc in relevant_docs
]
# Create a proper retriever by subclassing BaseRetriever
class SimpleRetriever(BaseRetriever):
def __init__(self, docs: List[Document], **kwargs):
super().__init__(**kwargs) # Pass kwargs to BaseRetriever
self._docs = docs # Use a private attribute to store docs
logging.debug(f"SimpleRetriever initialized with {len(docs)} documents")
def _get_relevant_documents(self, query: str) -> List[Document]:
logging.debug(f"SimpleRetriever._get_relevant_documents called with query: {query}")
return self._docs
async def _aget_relevant_documents(self, query: str) -> List[Document]:
logging.debug(f"SimpleRetriever._aget_relevant_documents called with query: {query}")
return self._docs
# Instantiate the retriever
retriever = SimpleRetriever(docs=documents)
# Create a retrieval-based question-answering chain
qa_chain = RetrievalQA.from_chain_type(
llm=self.generation_llm,
retriever=retriever,
return_source_documents=True
)
try:
result = qa_chain.invoke({"query": question})
response = result['result']
source_docs = result['source_documents']
#logging.info(f"Generated response for question: {question} : {response}")
return response, source_docs
except Exception as e:
logging.error(f"Error during QA chain invocation: {str(e)}")
raise e
def generate_summary_v0(self, chunks: any):
logging.info("Generating summary ...")
# Limit the number of chunks (for example, top 30 chunks)
limited_chunks = chunks[:30]
# Combine text from the selected chunks
full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
text_length = len(full_text)
logging.info(f"Total text length (characters): {text_length}")
# Define a maximum character limit to fit in a 1024-token context.
# For many models, roughly 3200 characters is a safe limit.
MAX_CHAR_LIMIT = 3200
if text_length > MAX_CHAR_LIMIT:
logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
full_text = full_text[:MAX_CHAR_LIMIT]
# Define a custom prompt to instruct concise summarization in bullet points.
custom_prompt_template = """
You are an expert summarizer. Summarize the following text into a concise summary using bullet points.
Ensure that the final summary is no longer than 20-30 bullet points and fits within 15-20 lines.
Focus only on the most critical points.
Text to summarize:
{text}
Summary:
"""
prompt = PromptTemplate(input_variables=["text"], template=custom_prompt_template)
# Use the 'stuff' chain type to send a single LLM request with our custom prompt.
chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)
# Wrap the full text in a single Document object (chain expects a list of Documents)
docs = [Document(page_content=full_text)]
# Generate the summary
summary = chain.invoke(docs)
return summary['output_text']
def generate_questions(self, chunks: any):
logging.info("Generating sample questions ...")
# Use the top 30 chunks or fewer
limited_chunks = chunks[:30]
# Combine text from chunks
full_text = "\n".join(chunk['text'] for chunk in limited_chunks)
text_length = len(full_text)
logging.info(f"Total text length for questions: {text_length}")
MAX_CHAR_LIMIT = 3200
if text_length > MAX_CHAR_LIMIT:
logging.warning(f"Input text too long ({text_length} chars), truncating to {MAX_CHAR_LIMIT} chars.")
full_text = full_text[:MAX_CHAR_LIMIT]
# Prompt template for generating questions
question_prompt_template = """
You are an AI expert at creating questions from documents.
Based on the text below, generate not less than 20 insightful and highly relevant sample questions that a user might ask to better understand the content.
**Instructions:**
- Questions must be specific to the document's content and context.
- Avoid generic questions like 'What is this document about?'
- Do not include numbers, prefixes (e.g., '1.', '2.'), or explanations (e.g., '(Clarifies...)').
- Each question should be a single, clear sentence ending with a question mark.
- Focus on key concepts, processes, components, or use cases mentioned in the text.
Text:
{text}
Output format:
What is the purpose of the Communication Server in Collateral Management?
How does the system handle data encryption for secure communication?
...
"""
prompt = PromptTemplate(input_variables=["text"], template=question_prompt_template)
chain = load_summarize_chain(self.generation_llm, chain_type="stuff", prompt=prompt)
docs = [Document(page_content=full_text)]
try:
result = chain.invoke(docs)
question_output = result.get("output_text", "").strip()
# Clean and parse the output into a list of questions
questions = []
for line in question_output.split("\n"):
# Remove any leading/trailing whitespace, numbers, or bullet points
cleaned_line = line.strip().strip("-*1234567890. ").rstrip(".")
# Remove any explanation in parentheses
cleaned_line = cleaned_line.split("(")[0].strip()
# Ensure the line is a valid question (ends with '?' and is not empty)
if cleaned_line and cleaned_line.endswith("?"):
questions.append(cleaned_line)
# Limit to 10 questions
questions = questions[:10]
logging.info(f"Generated questions: {questions}")
return questions
except Exception as e:
logging.error(f"Error generating questions: {e}")
return []
def generate_summary(self, chunks: Any, toc_text: Any, summary_type: str = "medium") -> str:
"""
Generate a summary of the document using LangChain's summarization chains.
Args:
vector_store_manager: Instance of VectorStoreManager with a FAISS vector store.
summary_type (str): Type of summary ("small", "medium", "detailed").
k (int): Number of chunks to retrieve from the vector store.
include_toc (bool): Whether to include the table of contents (if available).
Returns:
str: Generated summary.
Raises:
ValueError: If summary_type is invalid or vector store is not initialized.
"""
# Define chunk retrieval parameters based on summary type
if summary_type == "small":
k = min(k, 3) # Fewer chunks for small summary
chain_type = "stuff" # Use stuff for small summaries
word_count = "50-100"
elif summary_type == "medium":
k = min(k, 10)
chain_type = "map_reduce" # Use map-reduce for medium summaries
word_count = "200-400"
else: # detailed
k = min(k, 20)
chain_type = "map_reduce" # Use map-reduce for detailed summaries
word_count = "500-1000"
# Define prompts
if chain_type == "stuff":
prompt = PromptTemplate(
input_variables=["text"],
template=(
"Generate a {summary_type} summary ({word_count} words) of the following document excerpts. "
"Focus on key points and ensure clarity. Stick strictly to the provided text:\n\n"
"{toc_prompt}{text}"
).format(
summary_type=summary_type,
word_count=word_count,
toc_prompt="Table of Contents:\n{toc_text}\n\n" if toc_text else ""
)
)
chain = load_summarize_chain(
llm=self.generation_llm,
chain_type="stuff",
prompt=prompt
)
else: # map_reduce
map_prompt = PromptTemplate(
input_variables=["text"],
template=(
"Summarize the following document excerpt in 1-2 sentences, focusing on key points. "
"Consider the document's structure from this table of contents:\n\n"
"Table of Contents:\n{toc_text}\n\nExcerpt:\n{text}"
).format(toc_text=toc_text if toc_text else "Not provided")
)
combine_prompt = PromptTemplate(
input_variables=["text"],
template=(
"Combine the following summaries into a cohesive {summary_type} summary "
"({word_count} words) of the document. Ensure clarity, avoid redundancy, and "
"organize by key themes or sections if applicable:\n\n{text}"
).format(summary_type=summary_type, word_count=word_count)
)
chain = load_summarize_chain(
llm=self.generation_llm,
chain_type="map_reduce",
map_prompt=map_prompt,
combine_prompt=combine_prompt,
return_intermediate_steps=False
)
# Run the chain
try:
logging.info(f"Generating {summary_type} summary with {len(chunks)} chunks")
summary = chain.run(chunks)
logging.info(f"{summary_type.capitalize()} summary generated successfully")
return summary
except Exception as e:
logging.error(f"Error generating summary: {str(e)}")
return f"Error generating summary: {str(e)}"