from typing import List, Optional, Union

from dotenv import find_dotenv, load_dotenv
from langchain.chains import RetrievalQA
from langchain.chat_models import init_chat_model
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings


def get_default_splitter() -> RecursiveCharacterTextSplitter:
    """Return a pre-configured text splitter.

    Splits preferentially on markdown headers, then paragraphs, lines,
    and words, so chunks tend to align with document structure.
    """
    return RecursiveCharacterTextSplitter(
        # Using markdown headers as separators is a good strategy
        separators=["\n### ", "\n## ", "\n# ", "\n\n", "\n", " "],
        chunk_size=1000,
        chunk_overlap=200,
    )


def get_default_embeddings() -> HuggingFaceEmbeddings:
    """Return a pre-configured embedding model (CPU-only MiniLM)."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )


def build_retriever(
    data: Union[str, List[Document]],
    splitter: Optional[RecursiveCharacterTextSplitter] = None,
    embeddings: Optional[HuggingFaceEmbeddings] = None,
    top_k: int = 5,
):
    """Build a retriever from either a raw text string or a list of documents.

    Args:
        data (Union[str, List[Document]]): The source data to build the
            retriever from.
        splitter (RecursiveCharacterTextSplitter, optional): The text splitter
            to use. Defaults to get_default_splitter().
        embeddings (HuggingFaceEmbeddings, optional): The embedding model to
            use. Defaults to get_default_embeddings().
        top_k (int, optional): The number of top results to return.
            Defaults to 5.

    Returns:
        A FAISS-backed retriever configured to return ``top_k`` results per
        query.

    Raises:
        ValueError: If ``data`` is neither a string nor a list of documents.
    """
    splitter = splitter or get_default_splitter()
    embeddings = embeddings or get_default_embeddings()
    if isinstance(data, str):
        # If the input is a raw string, split it into chunks first
        chunks = splitter.split_text(data)
        # Then convert those chunks into Document objects
        docs = [Document(page_content=chunk) for chunk in chunks]
    elif isinstance(data, list):
        # If the input is already a list of documents, split them directly
        docs = splitter.split_documents(data)
    else:
        raise ValueError(f"Unsupported data type: {type(data)}. Must be str or List[Document].")
    index = FAISS.from_documents(docs, embeddings)
    return index.as_retriever(search_kwargs={"k": top_k})


def create_retrieval_qa(
    retriever,
    llm=None
) -> RetrievalQA:
    """Create a RetrievalQA instance from a given retriever and LLM.

    Args:
        retriever (BaseRetriever): The retriever to be used by the QA chain.
        llm (LLM, optional): The language model to use. If not provided, a
            default Groq-hosted model will be initialized; its credentials
            are loaded from the environment via dotenv.
    """
    if llm is None:
        # Load API credentials lazily, only when we must build the default LLM
        load_dotenv(find_dotenv())
        llm = init_chat_model("groq:meta-llama/llama-4-scout-17b-16e-instruct")
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )