Spaces:
Sleeping
Sleeping
| """RAG (Retrieval-Augmented Generation) chain implementation""" | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnableParallel, RunnablePassthrough | |
| from legisqa_local.core.llm import get_llm | |
| from legisqa_local.core.vectorstore import get_vectorstore, get_vectorstore_filter | |
| from legisqa_local.utils.formatting import format_docs | |
def create_rag_chain(llm, retriever):
    """Assemble the retrieval-augmented generation pipeline.

    The returned runnable is invoked with the raw query. It runs
    ``retriever`` on that input while passing the input through unchanged,
    then adds two further keys in order: ``context`` (``format_docs``
    applied to the retrieved docs) and ``aimessage`` (the result of piping
    the rendered prompt into ``llm``).

    Args:
        llm: Model runnable that consumes the rendered prompt.
        retriever: Runnable invoked with the query to fetch documents.

    Returns:
        A runnable whose output mapping carries the keys ``docs``,
        ``query``, ``context`` and ``aimessage``.
    """
    template_text = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""

    # Single human turn; {context} and {query} are filled from the chain state.
    chat_prompt = ChatPromptTemplate.from_messages([("human", template_text)])

    # Step 1: fan the incoming query out to the retriever and a pass-through.
    fetch_step = RunnableParallel(
        {
            "docs": retriever,
            "query": RunnablePassthrough(),
        }
    )

    # Step 2: render the retrieved documents into the prompt's context slot.
    with_context = fetch_step.assign(
        context=lambda state: format_docs(state["docs"]),
    )

    # Step 3: run the prompt through the model, keeping all earlier keys.
    return with_context.assign(aimessage=chat_prompt | llm)
def process_query(gen_config: dict, ret_config: dict, query: str):
    """Answer ``query`` by running it through the RAG chain.

    Args:
        gen_config: Generation settings forwarded to ``get_llm``.
        ret_config: Retrieval settings. ``ret_config["n_ret_docs"]`` is
            used as the retriever's ``k``, and the whole mapping is passed
            to ``get_vectorstore_filter``.
        query: The user's question.

    Returns:
        The mapping produced by the RAG chain, or, when no vectorstore is
        available yet, a placeholder mapping with the keys ``aimessage``,
        ``docs`` (empty list) and ``query``.

    Raises:
        KeyError: If ``ret_config`` has no ``"n_ret_docs"`` entry.
    """
    store = get_vectorstore()
    if store is None:
        # Nothing to search yet: hand back a placeholder instead of failing.
        # NOTE(review): "aimessage" is a plain string here, whereas the chain
        # path stores whatever `prompt | llm` returns — confirm callers
        # handle both.
        return {
            "aimessage": "⏳ Vectorstore is still loading. Please wait a moment and try again.",
            "docs": [],
            "query": query,
        }

    model = get_llm(gen_config)
    metadata_filter = get_vectorstore_filter(ret_config)

    # Per the original note, ChromaDB takes the metadata filter under the
    # 'filter' key of search_kwargs; it is only set when a filter exists.
    retrieval_kwargs = {"k": ret_config["n_ret_docs"]}
    if metadata_filter:
        retrieval_kwargs["filter"] = metadata_filter

    doc_retriever = store.as_retriever(search_kwargs=retrieval_kwargs)
    chain = create_rag_chain(model, doc_retriever)
    return chain.invoke(query)