from dataclasses import dataclass
import os
import sys

from langchain_core.documents import Document
from omegaconf import DictConfig

from .rag_pipeline.rag_validation import (
    get_embedding_model,
    load_faiss_store,
    attach_reranker,
    SparseIndex,
    get_hyde_model,
    get_reader_model,
    get_prompt,
    get_rag_chain,
    retrieve_docs_batched,
    parse_regex,
)
@dataclass
class Source:
    """A retrieved source chunk returned alongside a generated answer.

    Fix: the original class had bare annotations but no ``@dataclass``
    decorator, so the keyword-argument construction used by
    ``RagQA._docs_to_sources`` (``Source(name=..., text=..., index_id=...)``)
    would raise ``TypeError``.

    Attributes:
        name: Source document name (from ``doc.metadata["source_name"]``).
        text: Original page content of the chunk.
        index_id: Chunk identifier within the index.
    """

    name: str
    text: str
    index_id: int
class RagQA:
    """Question answering over a FAISS-backed RAG pipeline.

    Call :meth:`load` once to build the retriever, the optional HyDE
    pipeline, and the generation chain; then call :meth:`answer` per
    question.
    """

    def __init__(self, conf: DictConfig):
        """Store the configuration; heavy resources are created in load().

        Args:
            conf: Hydra/OmegaConf configuration for the whole pipeline.
        """
        self.rag_chain = None
        self.hyde_pipeline = None
        self.retriever = None
        self.conf = conf

    def load(self) -> None:
        """Build the retriever, optional HyDE pipeline, and RAG chain."""
        embedding_model = get_embedding_model(self.conf)
        faiss_store = load_faiss_store(self.conf, embedding_model)
        self.retriever = faiss_store.as_retriever()
        if self.conf.rag.reranking.enabled:
            self.retriever = attach_reranker(self.conf, self.retriever)
        self.hyde_pipeline = None
        # HyDE and summary expansion share the same generation model.
        if self.conf.rag.hyde.enabled or self.conf.rag.summary.enabled:
            self.hyde_pipeline = get_hyde_model(self.conf)
        reader_model = get_reader_model(self.conf)
        prompt = get_prompt(self.conf)
        self.rag_chain = get_rag_chain(self.conf, reader_model, prompt)

    @staticmethod
    def _docs_to_sources(docs: list[Document]) -> list[Source]:
        """Convert retrieved documents into :class:`Source` records.

        Fix: the original definition had neither ``self`` nor
        ``@staticmethod``, so the instance call
        ``self._docs_to_sources(...)`` in :meth:`answer` would have bound
        the instance to ``docs`` and crashed.
        """
        return [
            Source(
                name=doc.metadata["source_name"],
                text=doc.metadata["original_page_content"],
                index_id=doc.metadata["chunk_id"],
            )
            for doc in docs
        ]

    def answer(self, question: str) -> tuple[str, list[Source]]:
        """Answer a single question.

        Args:
            question: The user question.

        Returns:
            Tuple of (whitespace-normalized answer string, list of the
            sources backing it).
        """
        docs = retrieve_docs_batched(
            self.conf,
            self.retriever,
            None,  # NOTE(review): presumably the SparseIndex slot — confirm signature
            self.hyde_pipeline,
            self.hyde_pipeline,  # NOTE(review): passed twice (hyde + summary?) — confirm
            [question],
        )
        sources = self._docs_to_sources(docs[0]["docs"])
        chain_output = self.rag_chain.batch(docs)
        batch_answers = [
            parse_regex(row["raw_output"])["answer"] for row in chain_output
        ]
        # Collapse newlines and runs of whitespace into single spaces.
        answer = " ".join(batch_answers[0].strip().split())
        return answer, sources