from dataclasses import dataclass

from langchain_core.documents import Document
from omegaconf import DictConfig

from .rag_pipeline.rag_validation import (
    get_embedding_model,
    load_faiss_store,
    attach_reranker,
    get_hyde_model,
    get_reader_model,
    get_prompt,
    get_rag_chain,
    retrieve_docs_batched,
    parse_regex,
)


@dataclass
class Source:
    """One retrieved chunk, exposed as a citable source."""

    name: str
    text: str
    index_id: int


class RagQA:
    """Retrieval-augmented QA over a FAISS index, with optional reranking and HyDE."""

    def __init__(self, conf: DictConfig):
        self.rag_chain = None
        self.hyde_pipeline = None
        self.retriever = None
        self.conf = conf

    def load(self):
        """Build the retriever, the optional HyDE pipeline, and the RAG chain."""
        # Dense retrieval over the FAISS vector store.
        embedding_model = get_embedding_model(self.conf)
        faiss_store = load_faiss_store(self.conf, embedding_model)
        self.retriever = faiss_store.as_retriever()
        if self.conf.rag.reranking.enabled:
            self.retriever = attach_reranker(self.conf, self.retriever)

        # A single generative pipeline serves both HyDE and summary rewriting.
        self.hyde_pipeline = None
        if self.conf.rag.hyde.enabled or self.conf.rag.summary.enabled:
            self.hyde_pipeline = get_hyde_model(self.conf)

        reader_model = get_reader_model(self.conf)
        prompt = get_prompt(self.conf)
        self.rag_chain = get_rag_chain(self.conf, reader_model, prompt)

    @staticmethod
    def _docs_to_sources(docs: list[Document]) -> list[Source]:
        # Map retrieved Documents onto the lightweight Source dataclass,
        # pulling the display name, original text, and chunk id from metadata.
        return [
            Source(
                name=doc.metadata["source_name"],
                text=doc.metadata["original_page_content"],
                index_id=doc.metadata["chunk_id"],
            )
            for doc in docs
        ]

    def answer(self, question: str) -> tuple[str, list[Source]]:
        """Answer a single question and return it with its supporting sources."""
        docs = retrieve_docs_batched(
            self.conf,
            self.retriever,
            None,  # presumably the sparse-index slot; sparse retrieval is unused here
            self.hyde_pipeline,
            self.hyde_pipeline,  # the same pipeline is passed for the HyDE and summary slots
            [question],
        )
        sources = self._docs_to_sources(docs[0]["docs"])
        chain_output = self.rag_chain.batch(docs)
        batch_answers = [
            parse_regex(row["raw_output"])["answer"] for row in chain_output
        ]
        # Collapse internal whitespace in the single-question answer.
        answer = " ".join(batch_answers[0].strip().split())
        return answer, sources
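

# Example usage: a minimal sketch, assuming the OmegaConf config lives in a
# YAML file ("conf/config.yaml") and exposes the rag.* options used above.
# The path and the question are illustrative placeholders.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    conf = OmegaConf.load("conf/config.yaml")
    qa = RagQA(conf)
    qa.load()  # builds the retriever, optional reranker/HyDE, and the chain
    answer, sources = qa.answer("What is retrieval-augmented generation?")
    print(answer)
    for src in sources:
        print(f"[{src.index_id}] {src.name}")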